( Reference : Fastcampus lecture )

[ 19. SARSA ( = TD Control ) Practice ]


1. Review

(1) N-step TD

( figure: N-step TD )
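For reference, the \(n\)-step TD target bootstraps from the value estimate \(n\) steps ahead,

\(G_t^{(n)} = r_{t+1} + \gamma r_{t+2} + \cdots + \gamma^{n-1} r_{t+n} + \gamma^{n} V(s_{t+n})\),

and the value is updated as \(V(s_t) \leftarrow V(s_t) + \alpha\left(G_t^{(n)} - V(s_t)\right)\).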


(2) N-step SARSA

( figure: N-step SARSA )
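\(n\)-step SARSA is the control analogue: the tail of the return is estimated with the action value, \(G_t^{(n)} = r_{t+1} + \gamma r_{t+2} + \cdots + \gamma^{n-1} r_{t+n} + \gamma^{n} Q(s_{t+n}, a_{t+n})\), and the update is \(Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha\left(G_t^{(n)} - Q(s_t, a_t)\right)\). The code in this practice is the special case \(n = 1\).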


(3) SARSA(\(\lambda\))

( figure: SARSA(\(\lambda\)) )
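SARSA(\(\lambda\)) mixes all \(n\)-step SARSA targets with geometrically decaying weights, \(G_t^{\lambda} = (1-\lambda)\sum_{n=1}^{\infty} \lambda^{n-1} G_t^{(n)}\) (truncated at episode end), and in practice this is computed incrementally with eligibility traces rather than by evaluating the sum explicitly.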


2. Creating the Agent
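The environment env is created outside this excerpt. Any discrete environment that exposes nS, nA, and the current state s (as used below) works; a minimal sketch, assuming an older gym release (pre-0.26, 4-tuple step() API) and a toy-text environment whose id is my own choice here, not from the original:

import gym

# Assumption: an old-API toy-text environment; FrozenLake-v0 is only an example.
env = gym.make("FrozenLake-v0")
env.reset()
print(env.nS, env.nA)  # number of discrete states / actions
print(env.s)           # index of the current state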

sarsa_agent = SARSA(gamma=1.0,
                    lr=1e-1,
                    num_states=env.nS,
                    num_actions=env.nA,
                    epsilon=1.0)
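The SARSA class itself is not listed in these notes. Below is a minimal sketch of the constructor and the epsilon-greedy action selection; apart from the names that appear elsewhere in these notes (q, gamma, lr, epsilon, get_action, update_sample), the details are my own assumptions. update_sample is filled in in the next section. Note that epsilon=1.0 means fully random action selection, so epsilon is presumably decayed somewhere during training.

import numpy as np

class SARSA:
    def __init__(self, gamma, lr, num_states, num_actions, epsilon):
        self.gamma = gamma         # discount factor
        self.lr = lr               # learning rate (alpha)
        self.epsilon = epsilon     # exploration probability
        self.num_actions = num_actions
        self.q = np.zeros((num_states, num_actions))   # Q-table

    def get_action(self, state):
        # epsilon-greedy: random action with prob. epsilon, greedy otherwise
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.num_actions)
        return int(np.argmax(self.q[state]))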


3. SARSA Update Equation

\(Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\).

def update_sample(self, s, a, r, s_, a_, done):
    # TD target bootstraps from Q(s', a'); (1 - done) masks it at terminal states
    td_target = r + self.gamma * self.q[s_, a_] * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])
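As a quick numeric check of the update (the numbers are chosen only for illustration): with \(\alpha = 0.1\), \(\gamma = 1.0\), \(Q(s, a) = 0.5\), \(r = -1\), and \(Q(s^{\prime}, a^{\prime}) = 0.2\), the TD target is \(-1 + 1.0 \times 0.2 = -0.8\), and the updated value is \(0.5 + 0.1 \times (-0.8 - 0.5) = 0.37\).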


4. Run Iteration

Number of episodes : 10,000 ( log print interval : 1,000 )

num_episode = 10000
print_log = 1000
sarsa_qs = []
iter_idx = []
sarsa_rewards = []


for i in range(num_episode):
    total_reward = 0
    env.reset()    
    while True:
        # (1) observe the current state -> (2) take an action -> (3) receive the reward and next state
        s = env.s
        a = sarsa_agent.get_action(s)
        s_, r, done, info = env.step(a)
        # (4) choose the action for the next state
        a_ = sarsa_agent.get_action(s_)
        # (5) update with the (s, a, r, s_, a_) tuple collected above
        sarsa_agent.update_sample(s=s,
                                  a=a,
                                  r=r,
                                  s_=s_,
                                  a_=a_,
                                  done=done)
        total_reward += r
        if done:
            break
    
    sarsa_rewards.append(total_reward)
    if i % print_log == 0:
        print("Running {} th episode".format(i))
        print("Reward sum : {}".format(total_reward))
        sarsa_qs.append(sarsa_agent.q.copy())
        iter_idx.append(i)
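To check whether training makes progress, the collected sarsa_rewards can be plotted with a moving average, and the greedy policy can be read off the final Q-table. A minimal sketch, assuming numpy and matplotlib are available (the window size of 100 is arbitrary):

import numpy as np
import matplotlib.pyplot as plt

window = 100  # arbitrary smoothing window
smoothed = np.convolve(sarsa_rewards, np.ones(window) / window, mode="valid")

plt.plot(smoothed)
plt.xlabel("episode")
plt.ylabel("reward (moving average over {} episodes)".format(window))
plt.title("SARSA learning curve")
plt.show()

# Greedy policy implied by the learned Q-table (one action index per state)
greedy_policy = np.argmax(sarsa_agent.q, axis=1)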