( Reference : Fastcampus lecture )

[ 31. Actor-Critic Hands-on ]


1. TD Actor Critic Review

Advantage function \(A(s,a)\) :

  • Estimated using the state-value function \(V_{\psi}(s)\)
  • \(A(s,a) \approx \delta_\psi(s,a) = r+\gamma V_\psi(s')-V_\psi(s)\) ( the one-step TD error; the full update rules are summarized below )

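One TD Actor-Critic step uses the same TD error \(\delta\) for both networks. With learning rate \(\alpha\) ( lr in the code below ), the updates behind the two loss terms of section 4 are :

  • \(\delta = r + \gamma V_\psi(s') - V_\psi(s)\)
  • actor ( \(\theta\) ) : \(\theta \leftarrow \theta + \alpha \, \delta \, \nabla_\theta \ln \pi_\theta(a \mid s)\)
  • critic ( \(\psi\) ) : gradient descent on \(\left( r + \gamma V_\psi(s') - V_\psi(s) \right)^2\), treating the target \(r + \gamma V_\psi(s')\) as a constant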


2. Import Packages

import sys; sys.path.append('..')

import gym
import torch
import torch.nn as nn
from torch.distributions import Categorical

from src.part3.MLP import MultiLayerPerceptron as MLP
from src.part4.ActorCritic import TDActorCritic
from src.common.train_utils import EMAMeter, to_tensor


3. Setting

(1) Environment

env = gym.make('CartPole-v1')
s_dim = env.observation_space.shape[0]
a_dim = env.action_space.n
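
For CartPole-v1 the observation is a 4-dimensional vector ( cart position, cart velocity, pole angle, pole angular velocity ) and there are 2 discrete actions ( push left, push right ), so the problem size can be checked directly :

print("state dim :", s_dim)   # 4
print("action dim:", a_dim)   # 2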


4. TD Actor Critic

class TDActorCritic(nn.Module):


(1) __init__

def __init__(self,
             policy_net,
             value_net,
             gamma: float = 1.0,
             lr: float = 0.0002):
    super(TDActorCritic, self).__init__()
    self.policy_net = policy_net # (1) Policy network ( parameter : theta )
    self.value_net = value_net # (2) Value network ( parameter : psi )
    self.gamma = gamma
    self.lr = lr
	
    # one optimizer over both parameter groups : (1) theta & (2) psi
    total_param = list(policy_net.parameters()) + list(value_net.parameters())
    self.optimizer = torch.optim.Adam(params=total_param, lr=lr)

    self._eps = 1e-25  # small constant to keep torch.log() finite
    self._mse = torch.nn.MSELoss()


(2) get_action

def get_action(self, state):
    with torch.no_grad():
        logits = self.policy_net(state)
        action_dist = Categorical(logits=logits)
        action = action_dist.sample()
        return action
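
A quick sanity check of get_action. The nn.Linear networks below are hypothetical stand-ins just to exercise the shapes; the actual run in section 5 uses the MLPs from src :

toy_agent = TDActorCritic(policy_net=nn.Linear(4, 2), value_net=nn.Linear(4, 1))
a = toy_agent.get_action(torch.zeros(1, 4))  # one dummy CartPole state
print(a.item())  # 0 or 1, sampled from the softmax over the two logits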


(3) update

  • td_target : \(r+\gamma V_{\psi}(s')\)
  • td_error : td_target - self.value_net(s) = \(\left(r+\gamma V_{\psi}(s')\right) - V_{\psi}(s)\)
  • loss :
    • loss 1 : -torch.log(prob + self._eps) * td_error :
      • \(-\delta \ln \pi_{\theta}\left(A_{t} \mid S_{t}\right)\) ; minimizing it moves \(\theta\) along the policy-gradient term \(\delta \nabla_{\theta} \ln \pi_{\theta}\left(A_{t} \mid S_{t}\right)\)
    • loss 2 : self._mse(v, td_target) :
      • \(\left\|r+\gamma V_{\psi}(s')-V_{\psi}(s)\right\|_{2}^{2}\), where the target is computed under torch.no_grad() and therefore treated as a constant
def update(self, s, a, r, s_, d):
    # (1) compute TD target & TD error ( no gradient flows through them )
    with torch.no_grad():
        td_target = r + self.gamma * self.value_net(s_) * (1 - d)
        td_error = td_target - self.value_net(s)

    # (2) probability of the taken action ( for the policy loss )
    dist = Categorical(logits=self.policy_net(s))
    prob = dist.probs.gather(1, a)

    # (3) value of the current state
    v = self.value_net(s)

    # (4) compute loss : policy loss + value loss
    loss1 = -torch.log(prob + self._eps) * td_error
    loss2 = self._mse(v, td_target)
    loss = (loss1 + loss2).mean()

    # (5) gradient descent step
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
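
Continuing the toy agent from the get_action check above, a single update call on one made-up transition looks like this ( the transition values are arbitrary ) :

s, s_ = torch.zeros(1, 4), torch.zeros(1, 4)   # current / next state
a = torch.tensor([[0]])                        # action index, shaped (1, 1) for gather
toy_agent.update(s, a, 1.0, s_, False)         # reward 1.0, episode not done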


5. Run Iteration

(1) Network & Metric

  • policy_net : \(\pi_{\theta}( a \mid s)\)
  • value_net : \(V_{\psi}(s)\)
policy_net = MLP(s_dim, a_dim, [128])
value_net = MLP(s_dim, 1, [128])

agent = TDActorCritic(policy_net, value_net)
ema = EMAMeter()
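
EMAMeter comes from the course repo and just tracks an exponential moving average of episode returns, read back through ema.s. A minimal sketch of what it presumably does, assuming a standard EMA with a smoothing factor alpha ( the constructor argument is an assumption ) :

class SimpleEMAMeter:  # hypothetical stand-in for EMAMeter
    def __init__(self, alpha: float = 0.5):
        self.alpha = alpha  # assumed smoothing factor
        self.s = None       # current EMA value, printed in the logging below
    def update(self, y):
        self.s = y if self.s is None else self.alpha * y + (1 - self.alpha) * self.s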


(2) Run

n_episode = 10000
print_log = 500

( The loop below is the same as in the previous exercises )

for ep in range(n_episode):
    s = env.reset()
    cum_r = 0

    while True:
        s = to_tensor(s, size=(1, 4))
        a = agent.get_action(s)
        s_, r, d, info = env.step(a.item())
        s_ = to_tensor(s_, size=(1,4))
        agent.update(s, a.view(-1,1), r, s_, d)
        
        s = s_.numpy()
        cum_r += r
        if d:
            break

    ema.update(cum_r)
    if ep % print_log == 0:
        print("Episode {} || EMA: {} ".format(ep, ema.s))