( Reference : Fastcampus lecture )
[ 31. Actor-Critic Practice ]
1. TD Actor Critic Review
Advantage function \(A(s,a)\) :
- estimated using the learned value function \(V_{\psi}(s)\)
- \(A(s,a) \approx \delta_\psi(s,a) = r+\gamma V_\psi(s')-V_\psi(s)\).
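Written out, the two per-step losses implemented in the update code below are the policy (actor) loss and the value (critic) loss:
\[
L_{\text{actor}}(\theta) = -\,\delta_\psi \,\ln \pi_{\theta}(a \mid s), \qquad
L_{\text{critic}}(\psi) = \bigl( r + \gamma V_{\psi}(s') - V_{\psi}(s) \bigr)^{2},
\]
where the TD error \(\delta_\psi\) ( in the actor loss ) and the TD target ( in the critic loss ) are treated as constants, matching the torch.no_grad() block in the code.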
2. Import Packages
import sys; sys.path.append('..')
import gym
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from src.part3.MLP import MultiLayerPerceptron as MLP
from src.part4.ActorCritic import TDActorCritic
from src.common.train_utils import EMAMeter, to_tensor
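EMAMeter and to_tensor are small helpers from the course repository. Based only on how they are used below ( ema.update(cum_r), ema.s, to_tensor(s, size=(1, 4)) ), they might look roughly like the sketch here; the names match the imports, but the smoothing factor and implementation details are assumptions:

import torch

class EMAMeter:
    # Exponential moving average of a scalar metric; the smoothed value is read from .s
    def __init__(self, alpha: float = 0.5):  # default smoothing factor is an assumption
        self.alpha = alpha
        self.s = None

    def update(self, y):
        self.s = y if self.s is None else self.alpha * y + (1 - self.alpha) * self.s

def to_tensor(np_array, size=None):
    # numpy array -> float tensor, optionally reshaped to 'size'
    t = torch.tensor(np_array, dtype=torch.float)
    return t.view(size) if size is not None else t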
3. Setting
(1) Environment
env = gym.make('CartPole-v1')
s_dim = env.observation_space.shape[0]
a_dim = env.action_space.n
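For CartPole-v1 the observation is a 4-dimensional vector ( cart position, cart velocity, pole angle, pole angular velocity ) and the action space is Discrete(2), so:

print(s_dim, a_dim)  # 4 2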
4. TD Actor Critic
class TDActorCritic(nn.Module):
(1) __init__
    def __init__(self,
                 policy_net,
                 value_net,
                 gamma: float = 1.0,
                 lr: float = 0.0002):
        super(TDActorCritic, self).__init__()
        self.policy_net = policy_net  # (1) Policy network ( parameter : theta )
        self.value_net = value_net    # (2) Value network ( parameter : psi )
        self.gamma = gamma
        self.lr = lr

        # two parameter groups : (1) theta & (2) psi, optimized jointly
        total_param = list(policy_net.parameters()) + list(value_net.parameters())
        self.optimizer = torch.optim.Adam(params=total_param, lr=lr)

        self._eps = 1e-25
        self._mse = torch.nn.MSELoss()
(2) get_action
    def get_action(self, state):
        with torch.no_grad():
            logits = self.policy_net(state)
            action_dist = Categorical(logits=logits)
            action = action_dist.sample()
        return action
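Categorical(logits=...) turns the raw policy-network outputs into a softmax distribution, so actions are sampled in proportion to softmax(logits). A small standalone check:

import torch
from torch.distributions.categorical import Categorical

logits = torch.tensor([[1.0, 3.0]])   # example logits for a single state
dist = Categorical(logits=logits)
print(dist.probs)                     # softmax(logits) -> tensor([[0.1192, 0.8808]])
print(dist.sample())                  # e.g. tensor([1]) ( sampled ~88% of the time )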
(3) update
- td_target : \(r+\gamma V_{\psi}(s')\)
- td_error : td_target - self.value_net(s) \(= \bigl(r+\gamma V_{\psi}(s')\bigr) - V_{\psi}(s)\)
- loss
  - loss 1 : -torch.log(prob + self._eps) * td_error, i.e. \(-\delta \ln \pi_{\theta}\left(A_{t} \mid S_{t}\right)\)
  - loss 2 : self._mse(v, td_target), i.e. \(\bigl(r+\gamma V_{\psi}(s')-V_{\psi}(s)\bigr)^{2}\)
    def update(self, s, a, r, s_, d):
        # (1) compute TD target & TD error ( no gradient flows through them )
        with torch.no_grad():
            td_target = r + self.gamma * self.value_net(s_) * (1 - d)
            td_error = td_target - self.value_net(s)

        # (2) probability of the taken action ( for the policy loss )
        dist = Categorical(logits=self.policy_net(s))
        prob = dist.probs.gather(1, a)

        # (3) value of the current state
        v = self.value_net(s)

        # (4) compute loss
        loss1 = -torch.log(prob + self._eps) * td_error
        loss2 = self._mse(v, td_target)
        loss = (loss1 + loss2).mean()

        # (5) gradient descent
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
5. Run Iteration
(1) Network & Metric
- policy_net : \(\pi_{\theta}( a \mid s)\)
- value_net : \(V_{\psi}(s)\)
policy_net = MLP(s_dim, a_dim, [128])
value_net = MLP(s_dim, 1, [128])
agent = TDActorCritic(policy_net, value_net)
ema = EMAMeter()
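MultiLayerPerceptron is provided in src.part3.MLP. A minimal sketch consistent with the MLP(input_dim, output_dim, [hidden_dims]) calls above could look like this; the exact architecture and activation are assumptions:

import torch.nn as nn

class MultiLayerPerceptron(nn.Module):
    # fully connected layers with ReLU in between; no activation on the output
    # ( the output is used directly as policy logits or as a scalar value )
    def __init__(self, input_dim, output_dim, hidden_dims):
        super().__init__()
        dims = [input_dim] + list(hidden_dims) + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)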
(2) Run
n_episode = 10000
print_log = 500
( The training loop below is the same as in the previous sections )
for ep in range(n_episode):
    s = env.reset()
    cum_r = 0

    while True:
        s = to_tensor(s, size=(1, 4))
        a = agent.get_action(s)
        s_, r, d, info = env.step(a.item())
        s_ = to_tensor(s_, size=(1, 4))

        agent.update(s, a.view(-1, 1), r, s_, d)

        s = s_.numpy()
        cum_r += r
        if d:
            break

    ema.update(cum_r)
    if ep % print_log == 0:
        print("Episode {} || EMA: {} ".format(ep, ema.s))