每天学一点强化学习(一)
- 多步sarsa的实现方法与原理
- 核心代码
- 强化学习资料
多步sarsa的实现方法与原理
原理与实现来自于《动手学强化学习》:TD算法的实现: sarsa
核心代码
代码同样来自于上述的《动手学强化学习》,这里只说一下思路。通过列表储存历史的状态,动作和奖励。再进行更新时,使用update进行实现,需要先判断列表中的数据是否储存了t到t+n步的历史数据,当储存够n步的历史数据之后,进行sarsa的迭代过程,在这个迭代过程中,s0, a0, r 只是被储存,真正的被更新的动作价值其实是t时刻的,也就是列表的第一个(因为使用了pop将列表第一个值进行了弹出的操作),如代码注释所说,核心更新的代码为以下行,G是通过t:t+n的列表进行计算的,这是符合多步sarsa的原理实现 的:
self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])
class nstep_Sarsa:""" n步Sarsa算法 """def __init__(self, n, ncol, nrow, epsilon, alpha, gamma, n_action=4):self.Q_table = np.zeros([nrow * ncol, n_action])self.n_action = n_actionself.alpha = alphaself.gamma = gammaself.epsilon = epsilonself.n = n # 采用n步Sarsa算法self.state_list = [] # 保存之前的状态self.action_list = [] # 保存之前的动作self.reward_list = [] # 保存之前的奖励def take_action(self, state):if np.random.random() < self.epsilon:action = np.random.randint(self.n_action)else:action = np.argmax(self.Q_table[state])return actiondef best_action(self, state): # 用于打印策略Q_max = np.max(self.Q_table[state])a = [0 for _ in range(self.n_action)]for i in range(self.n_action):if self.Q_table[state, i] == Q_max:a[i] = 1return adef update(self, s0, a0, r, s1, a1, done):self.state_list.append(s0)self.action_list.append(a0)self.reward_list.append(r)if len(self.state_list) == self.n: # 若保存的数据可以进行n步更新G = self.Q_table[s1, a1] # 得到Q(s_{t+n}, a_{t+n})for i in reversed(range(self.n)):G = self.gamma * G + self.reward_list[i] # 不断向前计算每一步的回报# 如果到达终止状态,最后几步虽然长度不够n步,也将其进行更新if done and i > 0:s = self.state_list[i]a = self.action_list[i]self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])s = self.state_list.pop(0) # 将需要更新的状态动作从列表中删除,下次不必更新a = self.action_list.pop(0)self.reward_list.pop(0)# n步Sarsa的主要更新步骤self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])if done: # 如果到达终止状态,即将开始下一条序列,则将列表全清空self.state_list = []self.action_list = []self.reward_list = []
强化学习资料
1: 动手学强化学习:代码比较简单易懂
2: 强化学习的数学原理:讲述强化学习原理,常看常新