PyTorch强化学习实战(14)——优先经验回放机制

📅 2026/6/17 14:30:14
PyTorch强化学习实战(14)——优先经验回放机制
PyTorch强化学习实战14——优先经验回放机制0. 前言1. 优先经验回放缓冲区2. 实现优先经验回放缓冲区3. 运行结果4. 超参数调优小结系列链接0. 前言经验回放 (Experience Replay) 通过打破样本间的时序相关性极大地稳定了训练过程使深度Q网络 (Deep Q-Network, DQN) 能够从非平稳分布中高效学习。然而传统经验回放采用均匀采样策略对所有经验样本一视同仁这引发了一个关键问题是否所有经验都具有同等价值2015年DeepMind的研究团队发表了《Prioritized Experience Replay》提出了一种全新的采样机制——优先级经验回放。该方法的核心是强化学习算法可以从更重要、更有价值的经验中学习得更快、更好。论文通过将优先采样机制引入DQN不仅在Atari 2600游戏中取得了显著超越基准的效果更重要的是它为后续几乎所有深度强化学习算法提供了重要的性能加速组件。本节将深入探讨优先经验回放机制解析 优先经验回放机制的原理、实现方案并与经典DQN进行性能对比。1. 优先经验回放缓冲区2015年论文《Prioritized experience replay》提出了提升深度Q网络 (Deep Q-Network, DQN) 训练效率的优先经验回放 (Prioritized experience replay)。该方法通过根据训练损失为缓存样本分配优先级显著提高了回放缓冲区中样本的利用效率。基本DQN使用回放缓冲区来旨在消除回合中连续转移样本之间的相关性。由于环境具有平滑特性(即智能体行为通常不会引发环境剧烈变化)同一回合中的经验样本往往存在高度关联性。然而随机梯度下降 (Stochastic Gradient Descent,SGD) 方法要求训练数据满足独立同分布 (independent and identically distributed,i.i.d) 特性。为了处理这个问题经典DQN的解决方案是使用一个大的转移缓冲区并通过均匀随机采样获取训练批次。论文作者对这种均匀随机采样策略进行改进并证实根据训练损失为缓存样本设置优先级再按优先级比例进行采样能显著提升DQN的收敛速度与策略质量。该方法的核心思想可概括为重点学习非常规的数据。关键在于平衡异常样本与普通样本的训练强度——若过度聚焦缓存中的少数样本不仅会破坏i.i.d特性还可能导致模型在该子集上过拟合。从数学角度来看缓存中每个样本的优先级计算公式为P ( i ) p i α ∑ k p k α P(i)\frac {p_i^\alpha}{\sum_kp_k^\alpha}P(i)∑k​pkα​piα​​其中p i p_ipi​表示第i ii个样本的优先级α αα为优先级权重系数。当α 0 α0α0时采样方式退化为经典DQN的均匀采样α αα值越大高优先级样本的选取概率就越高。该系数是需要调节的超参数论文建议的初始值为0.6。论文提出了多种优先级定义方案其中最常用的是使优先级与贝尔曼更新中的样本损失值成正比。新增至缓存的样本会被赋予最高优先级以确保其能被快速采样。通过调整样本优先级我们会引入数据分布偏差(某些转移样本会被更频繁地采样)这需要通过补偿机制来维持随机梯度下降 (Stochastic Gradient Descent,SGD) 的有效性。为此引入了样本权重系数将每个样本的损失值乘以对应权重w i ( 1 N ⋅ 1 P ( i ) ) β w_i(\frac 1N·\frac {1}{P(i)})^βwi​(N1​⋅P(i)1​)β。其中β ββ是取值0到1之间的另一超参数当β 1 β1β1时可完全抵消采样偏差但实验表明最佳做法是从0-1之间的初始值开始在训练过程中逐步增大至1这样更有利于模型收敛。2. 实现优先经验回放缓冲区要实现优先经验回放缓冲区我们需要对经典深度Q网络 (Deep Q-Network, DQN) 代码进行以下修改首先我们需要一个新的经验回放缓冲区来跟踪优先级根据优先级采样批次数据计算权重并在损失值已知后更新优先级其次需要修改损失函数。不仅需要为每个样本加入权重还需将损失值传回经验回放缓冲区以调整已采样转移的优先级我们在dqn_prio_replay.py中实现以上修改。为保持简洁新的优先级回放缓冲区类采用了与先前回放缓冲区非常相似的存储方案。但遗憾的是优先级的新需求使得无法实现O ( 1 ) O(1)O(1)时间复杂度的采样(即采样时间会随缓冲区容量增加而增长)。若使用简单列表存储每次采样新批次时都需要处理全部优先级数据这使得采样时间复杂度达到O ( N ) O(N)O(N)与缓冲区大小成正比。对于10万样本量级的小型缓冲区影响不大但对于现实应用中百万级转移数据的大型缓冲区可能成为问题。存在其他支持O ( l o g N ) O(log N)O(logN)高效采样的存储方案例如采用线段树数据结构。TorchRL等库提供了不同版本的优化缓冲区实现我们也在lib.experience.PrioritizedReplayBuffer类中提供了高效优先级回放缓冲区可以改用该高效版本并观察其对训练性能的影响。(1)接下来我们以基础版本为例首先定义β ββ参数的增长率BETA_START0.4BETA_FRAMES100_000(2)β ββ值将在前10万帧训练过程中从0.4线性增长至1.0。接下来实现优先级回放缓冲区类classPrioReplayBuffer(ExperienceReplayBuffer):def__init__(self,exp_source:ExperienceSource,buf_size:int,prob_alpha:float0.6):super().__init__(exp_source,buf_size)self.experience_source_iteriter(exp_source)self.capacitybuf_size self.pos0self.buffer[]self.prob_alphaprob_alpha self.prioritiesnp.zeros((buf_size,),dtypenp.float32)self.betaBETA_START优先级回放缓冲区类继承自简易回放缓冲区ExperienceReplayBuffer(后者采用环形缓冲区存储样本可在不重新分配列表空间的情况下保持固定容量)。我们的子类额外使用NumPy数组来维护优先级数据。(3)update_beta()方法需要定期调用以便根据调度增加beta值。populate()方法则负责从ExperienceSource对象提取指定数量的转移数据并存入缓冲区defupdate_beta(self,idx:int)-float:vBETA_STARTidx*(1.0-BETA_START)/BETA_FRAMES self.betamin(1.0,v)returnself.betadefpopulate(self,count):max_prioself.priorities.max()ifself.bufferelse1.0for_inrange(count):samplenext(self.exp_source_iter)iflen(self.buffer)self.capacity:self.buffer.append(sample)else:self.buffer[self.pos]sample self.priorities[self.pos]max_prio self.pos(self.pos1)%self.capacity由于我们采用环形缓冲区存储状态转移数据在采样时会遇到两种不同情况当缓冲区未达最大容量时只需追加新转移数据。如果缓冲区已满则需覆写由pos类字段追踪的最旧转移数据并通过取模运算循环调整写入位置。(4)在sample()方法中需利用超参数α αα将优先级转换为概率分布defsample(self,batch_size,beta0.4):iflen(self.buffer)self.capacity:priosself.prioritieselse:priosself.priorities[:self.pos]probsnp.array(prios,dtypenp.float32)**self.prob_alpha probs/probs.sum()随后根据该概率分布从缓冲区抽取批次样本indicesnp.random.choice(len(self.buffer),batch_size,pprobs,replaceTrue)samples[self.buffer[idx]foridxinindices]最后计算批次样本的权重系数totallen(self.buffer)weights(total*probs[indices])**(-beta)weights/weights.max()returnsamples,indices,np.array(weights,dtypenp.float32)该方法返回三个对象批数据、索引及权重。其中索引用于后续更新已采样数据的优先级。。(5)优先级回放缓冲区的最后一个函数是允许我们更新已处理批次的新优先级defupdate_priorities(self,batch_indices,batch_priorities):foridx,prioinzip(batch_indices,batch_priorities):self.priorities[idx]prio调用者需负责在批处理计算损失时使用此函数。(6)在本节中下一个自定义函数是损失计算。由于PyTorch的MSELoss类不支持加权(因为MSE通常用于回归问题而样本加权常见于分类损失)我们需要手动计算MSE并显式地将结果与权重相乘defcalc_loss(batch:tt.List[ExperienceFirstLast],batch_weights:np.ndarray,net:nn.Module,tgt_net:nn.Module,gamma:float,device:torch.device)-tt.Tuple[torch.Tensor,np.ndarray]:states,actions,rewards,dones,next_statescommon.unpack_batch(batch)states_vtorch.as_tensor(states).to(device)actions_vtorch.tensor(actions).to(device)rewards_vtorch.tensor(rewards).to(device)done_masktorch.BoolTensor(dones).to(device)batch_weights_vtorch.tensor(batch_weights).to(device)actions_vactions_v.unsqueeze(-1)state_action_valsnet(states_v).gather(1,actions_v)state_action_valsstate_action_vals.squeeze(-1)withtorch.no_grad():next_states_vtorch.as_tensor(next_states).to(device)next_s_valstgt_net(next_states_v).max(1)[0]next_s_vals[done_mask]0.0exp_sa_valsnext_s_vals.detach()*gammarewards_v l(state_action_vals-exp_sa_vals)**2losses_vbatch_weights_v*lreturnlosses_v.mean(),(losses_v1e-5).data.cpu().numpy()在损失计算的最后部分我们实现了均方误差损失函数但采用显式表达式而非调用库函数。这使得我们可以纳入样本权重系数并保留每个样本的独立损失值。这些损失值将被回传至优先级回放缓冲区用于更新优先级。为避免零损失值导致缓冲区元素优先级归零的情况我们为每个损失值添加了一个小常量值。(7)在程序的主逻辑中仅需两处修改回放缓冲区的初始化和数据处理函数。由于缓冲区初始化过程直观明了我们将重点分析新的数据处理函数实现defprocess_batch(engine,batch_data):batch,batch_indices,batch_weightsbatch_data optimizer.zero_grad()loss_v,sample_prioscalc_loss(batch,batch_weights,net,tgt_net.target_model,gammaparams.gamma,devicedevice)loss_v.backward()optimizer.step()buffer.update_priorities(batch_indices,sample_prios)epsilon_tracker.frame(engine.state.iteration)ifengine.state.iteration%params.target_net_sync0:tgt_net.sync()return{loss:loss_v.item(),epsilon:selector.epsilon,beta:buffer.update_beta(engine.state.iteration),}主要变化如下批次现在包含三个实体批数据、采样项的索引以及样本权重调用新的损失函数该函数接收权重并返回附加项的优先级。这些优先级会被传递至buffer.update_priorities()函数以重新调整已采样项的优先级调用缓冲区的update_beta()方法根据调度策略调整beta参数3. 运行结果训练过程与经典 DQN 相同。根据实验数据优先回放缓冲区在解决环境问题上的耗时与经典DQN几乎相同但所需的训练迭代次数和训练回合数更少。实际耗时相近的主要原因在于当前回放缓冲区的实现效率较低——这一问题完全可以通过采用O ( l o g N ) O(log N)O(logN)复杂度的缓冲区实现方案来解决。下图展示了基线方法与优先回放缓冲区的奖励动态对比。横坐标表示游戏回合数另外需要注意的是在TensorBoard中可以观察到优先回放缓冲区的损失明显较低。下图展示了具体对比更低的损失值符合预期也表明我们的实现是有效的。优先采样的核心思想是通过重点训练高损失值的样本来提升训练效率。但这里存在一个潜在风险训练过程中的损失值并非首要优化目标——我们可能获得极低的损失值却因探索不足导致最终学得的策略远非最优。4. 超参数调优针对优先级回放缓冲区的超参数调优新增了α αα参数(取值范围0.3至0.9步长0.1)。最佳参数组合 (α 0.6 α0.6α0.6) 仅用330个训练回合就解决了Pong游戏learning_rate8.839010139505506e-05gamma0.99基准DQN与调优后的优先级回放缓冲区的对比图如下所示从图中可见优先级回放缓冲区的游戏表现提升更快但达到21分所需的游戏回合数几乎相同。在图(以游戏步数为单位)中可以看出优先级回放缓冲区的表现也略胜一筹。小结本节深入介绍了优先经验回放机制它通过根据样本损失值分配优先级打破了经典DQN的均匀采样策略从而提升训练效率与策略质量。详细阐述了优先级的计算公式、采样与权重补偿机制并给出了具体代码实现包括缓冲区设计、损失函数修改及超参数β ββ的调度策略。实验结果显示该方法在减少训练迭代次数的同时能够获得更低的损失值。系列链接PyTorch强化学习实战1——强化学习Reinforcement LearningRL详解PyTorch强化学习实战2——强化学习环境库GymnasiumPyTorch强化学习实战3——Gymnasium API扩展功能PyTorch强化学习实战4——PyTorch基础PyTorch强化学习实战5——PyTorch Ignite 事件驱动机制与实践PyTorch强化学习实战6——交叉熵方法详解与实现PyTorch强化学习实战7——表格学习与贝尔曼方程PyTorch强化学习实战8——Q学习详解与实现PyTorch强化学习实战9——深度Q学习PyTorch强化学习实战10——强化学习高级组件PyTorch强化学习实战11——N步DQNN-step DQNPyTorch强化学习实战12——Double DQNDDQNPyTorch强化学习实战13——噪声网络NoisyNet-DQN