(Reference: Fastcampus lecture)
[ 23. Q-learning Practice ]
1. Review
2. Import Packages
import numpy as np
import matplotlib.pyplot as plt
from src.part2.temporal_difference import QLearner
from src.common.gridworld import GridworldEnv
from src.common.grid_visualization import visualize_value_function, visualize_policy
np.random.seed(0)
3. Creating the Agent
QLearner inherits from the TDAgent class we saw earlier.
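The env referenced below is not created in this snippet; a minimal setup sketch, assuming the course's GridworldEnv accepts a grid shape (the 4x4 size is an assumption):

env = GridworldEnv(shape=[4, 4])  # assumption: 4x4 grid; use the shape from the earlier gridworld lectures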
qlearning_agent = QLearner(gamma=1.0,
                           lr=1e-1,
                           num_states=env.nS,
                           num_actions=env.nA,
                           epsilon=1.0)
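For reference, a minimal sketch of what this inheritance might look like (illustrative only; the course's actual TDAgent keeps more state, and the epsilon-greedy get_action here is an assumption):

import numpy as np

class TDAgent:
    def __init__(self, gamma, lr, num_states, num_actions, epsilon):
        self.gamma = gamma
        self.lr = lr
        self.epsilon = epsilon
        self.num_actions = num_actions
        self.q = np.zeros(shape=(num_states, num_actions))  # Q-table

    def get_action(self, s):
        # epsilon-greedy: explore with probability epsilon, otherwise act greedily
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.num_actions)
        return self.q[s, :].argmax()

class QLearner(TDAgent):
    def update_sample(self, s, a, r, s_, done):
        # off-policy TD target: max over next-state actions
        td_target = r + self.gamma * self.q[s_, :].max() * (1 - done)
        self.q[s, a] += self.lr * (td_target - self.q[s, a])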
4. SARSA vs Q-Learning
(1) SARSA
\(Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\).
def update_sample(self, s, a, r, s_, a_, done):
    # on-policy TD target: uses the next action a_ actually sampled by the behavior policy
    td_target = r + self.gamma * self.q[s_, a_] * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])
(2) Q-learning
\(Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\).
def update_sample(self, s, a, r, s_, done):
    # off-policy TD target: uses the greedy max over next-state actions
    td_target = r + self.gamma * self.q[s_, :].max() * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])
The only difference is the TD target:
- SARSA : self.q[s_, a_]
- Q-Learning : self.q[s_, :].max()
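A quick numerical check of this difference, using a hypothetical Q(s', ·) row (all values below are made up for illustration):

import numpy as np

q_next = np.array([0.1, 0.5, -0.2, 0.3])  # hypothetical Q-values of the next state s_
r, gamma = 0.0, 1.0
a_ = 0  # suppose the epsilon-greedy behavior policy sampled this next action

sarsa_target = r + gamma * q_next[a_]        # 0.1 : follows the sampled a_
qlearning_target = r + gamma * q_next.max()  # 0.5 : follows the greedy max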
5. Run Iteration
Number of episodes: 10,000 (log print interval: 1,000)
num_episode = 10000
print_log = 1000
qlearning_qs = []
iter_idx = []
qlearning_rewards = []
for i in range(num_episode):
    reward_sum = 0
    env.reset()
    while True:
        # (1) observe the state -> (2) choose an action -> (3) receive the reward and next state
        s = env.s
        a = qlearning_agent.get_action(s)
        s_, r, done, info = env.step(a)
        ##### [Unlike SARSA, this step is absent] (4) choose the action for the next state #####
        # (5) update with the s, a, r, s_ obtained above
        qlearning_agent.update_sample(s, a, r, s_, done)
        reward_sum += r
        if done:
            break
    qlearning_rewards.append(reward_sum)
    if i % print_log == 0:
        print("Running {} th episode".format(i))
        print("Reward sum : {}".format(reward_sum))
        qlearning_qs.append(qlearning_agent.q.copy())
        iter_idx.append(i)
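To see whether learning progressed, the collected qlearning_rewards can be plotted with the matplotlib already imported above; a minimal sketch (the 100-episode smoothing window is an arbitrary choice):

window = 100  # moving-average window (arbitrary)
smoothed = np.convolve(qlearning_rewards, np.ones(window) / window, mode='valid')

fig, ax = plt.subplots()
ax.plot(qlearning_rewards, alpha=0.3, label='raw reward sum')
ax.plot(np.arange(window - 1, num_episode), smoothed, label='moving average (100)')
ax.set_xlabel('episode')
ax.set_ylabel('reward sum')
ax.legend()
plt.show()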