Source: https://blog.csdn.net/m0_60414444/article/details/146481364

I. Principles of Multi-Task Reinforcement Learning

1. Core Idea of Multi-Task Learning

Multi-task reinforcement learning (Multi-Task RL) trains a single agent on several tasks at once, sharing knowledge across them to improve learning efficiency and generalization. It differs from single-task reinforcement learning as follows:

Dimension            Single-task RL                        Multi-task RL
Objective            Optimize a policy for one task        Jointly optimize a shared policy across tasks
Training             Independent training per task         Joint training over all tasks
Knowledge transfer   None                                  Shared representations/parameters enable cross-task transfer
Use case             Task-specific scenarios               General-purpose agents in complex environments
2. Multi-Task Framework Based on Shared Representations

Shared network layers learn what the tasks have in common, while task-specific layers handle their differences. The algorithm proceeds as follows:

  1. Task sampling: randomly select a task from the task distribution

  2. Policy execution: generate an action from the shared network (plus the task-specific head)

  3. Gradient update: jointly optimize the shared and task-specific parameters

Mathematical formulation: let \theta_s denote the shared parameters and \theta_i the task-specific parameters of task i. With N tasks, the joint objective is to maximize the summed expected discounted return

    J(\theta_s, \theta_1, \dots, \theta_N) = \sum_{i=1}^{N} \mathbb{E}_{\tau \sim \pi_{\theta_s, \theta_i}} \Big[ \sum_{t} \gamma^t r_i(s_t, a_t) \Big],

where \pi_{\theta_s, \theta_i} is the policy formed by the shared trunk and the head of task i, and r_i is task i's reward function.
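A minimal sketch of this loop is shown below. It assumes a policy(state_tensor, task_id_tensor) interface like the SharedPolicy defined in Section III, plus placeholder envs and optimizer objects, and uses a toy REINFORCE-style loss purely for illustration (the actual implementation in Section III uses PPO):

import numpy as np
import torch
from torch.distributions import Normal

def multi_task_step(envs, policy, optimizer, num_tasks):
    """One iteration of the shared-representation loop:
    sample a task, act with the shared policy, update jointly."""
    # 1. Task sampling: draw a task index from a uniform task distribution
    task_id = int(np.random.randint(num_tasks))
    env = envs[task_id]

    # 2. Policy execution: shared trunk + task-specific head produce the action distribution
    state, _ = env.reset()
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    task_t = torch.tensor([task_id], dtype=torch.long)
    action_mean, _value = policy(state_t, task_t)
    dist = Normal(action_mean, torch.ones_like(action_mean))
    action = dist.sample()
    _next_state, reward, _done, _trunc, _info = env.step(action.squeeze(0).numpy())

    # 3. Gradient update: a joint loss backpropagates through the shared parameters
    #    and the selected task head (REINFORCE-style placeholder objective)
    loss = -(dist.log_prob(action).sum() * float(reward))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(reward)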
II. Multi-Task PPO Implementation (Based on Gymnasium)

Using the Meta-World multi-task robotic-arm environments as an example, we implement multi-task reinforcement learning with PPO:

  1. Define the task set: reach, push, and pick-place (environment construction is shown in the snippet after this list)

  2. Build a shared policy network: a shared fully connected trunk plus task-specific heads

  3. Implement multi-task sampling: switch tasks dynamically during training

  4. Perform joint gradient updates: balance the losses across tasks

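Before the full implementation, the short snippet below shows how a single goal-observable Meta-World environment is constructed and inspected. It assumes a Meta-World version that exposes the ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE registry used throughout Section III:

from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

# Construct one task environment from the goal-observable registry
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['reach-v2-goal-observable']()
obs, info = env.reset()

# For the v2 tasks used here, observations are flat vectors and actions are continuous
print(obs.shape)               # e.g. (39,)
print(env.action_space.shape)  # e.g. (4,)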

III. Code Implementation

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
from torch.cuda.amp import autocast, GradScaler
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
from collections import deque

# ================== Configuration ==================
class MultiTaskPPOConfig:
    task_names = [
        'reach-v2-goal-observable',
        'push-v2-goal-observable',
        'pick-place-v2-goal-observable'
    ]
    num_tasks = 3
    hidden_dim = 512            # width of the shared trunk
    task_specific_dim = 128     # width of each task-specific head
    lr = 3e-4
    gamma = 0.99                # discount factor
    gae_lambda = 0.95           # GAE lambda
    clip_epsilon = 0.2          # PPO clipping range
    ppo_epochs = 4
    batch_size = 512
    max_episodes = 2000
    max_steps = 500             # steps collected per training iteration
    grad_clip = 0.5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Shared policy network ==================
class SharedPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim

        # Shared trunk: learns features common to all tasks
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, MultiTaskPPOConfig.hidden_dim),
            nn.LayerNorm(MultiTaskPPOConfig.hidden_dim),
            nn.GELU(),
            nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.hidden_dim),
            nn.GELU()
        )

        # Task-specific heads: one policy head and one value head per task
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, action_dim)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, 1)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])

    def forward(self, states, task_ids):
        shared_features = self.shared_net(states)
        batch_size = states.size(0)

        # Output tensors with the same dtype/device as the input
        action_means = torch.zeros(batch_size, self.action_dim,
                                   dtype=states.dtype, device=states.device)
        values = torch.zeros(batch_size, 1, dtype=states.dtype, device=states.device)

        # Route each sample through the head that matches its task id
        unique_task_ids = torch.unique(task_ids)
        for task_id_tensor in unique_task_ids:
            task_id = task_id_tensor.item()
            mask = (task_ids == task_id_tensor)
            if not mask.any():
                continue
            selected_features = shared_features[mask]
            # Cast head outputs back to the input dtype (relevant under AMP)
            task_action = self.task_heads[task_id](selected_features).to(dtype=states.dtype)
            task_value = self.value_heads[task_id](selected_features).to(dtype=states.dtype)
            action_means[mask] = task_action
            values[mask] = task_value

        return action_means, values

# ================== Training system ==================
class MultiTaskPPOTrainer:
    def __init__(self):
        # Build one environment per task and check that observation dimensions match
        self.envs = []
        self.state_dim = None
        self.action_dim = None
        for task_name in MultiTaskPPOConfig.task_names:
            env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task_name]()
            obs, _ = env.reset()
            if self.state_dim is None:
                self.state_dim = obs.shape[0]
                self.action_dim = env.action_space.shape[0]
            else:
                assert obs.shape[0] == self.state_dim, f"Inconsistent state dimension: {task_name}"
            self.envs.append(env)

        # Shared policy, optimizer, and AMP gradient scaler
        self.policy = SharedPolicy(self.state_dim, self.action_dim).to(MultiTaskPPOConfig.device)
        self.optimizer = optim.AdamW(self.policy.parameters(), lr=MultiTaskPPOConfig.lr)
        self.scaler = GradScaler()

        # Rollout buffer plus per-task bookkeeping of the latest observation
        self.buffer = deque(maxlen=MultiTaskPPOConfig.max_steps)
        self.last_obs = {}

    def collect_experience(self, num_steps):
        """Collect experience by sampling a task at random for every step."""
        for _ in range(num_steps):
            task_id = int(np.random.randint(MultiTaskPPOConfig.num_tasks))
            env = self.envs[task_id]

            # Resume from this task's last observation, or reset it on first use
            if task_id not in self.last_obs:
                state, _ = env.reset()
            else:
                state = self.last_obs[task_id]

            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(MultiTaskPPOConfig.device)
                # The task id is passed as a long tensor so the policy can route to the right head
                task_id_tensor = torch.tensor([task_id], dtype=torch.long, device=MultiTaskPPOConfig.device)
                action_mean, value = self.policy(state_tensor, task_id_tensor)
                dist = Normal(action_mean, torch.ones_like(action_mean))
                action_tensor = dist.sample()
                # Log-probability of the sampled action (not of the mean)
                log_prob = dist.log_prob(action_tensor)
                action = action_tensor.squeeze(0).cpu().numpy()

            next_state, reward, done, trunc, _ = env.step(action)

            self.buffer.append({
                'state': state,
                'action': action,
                'log_prob': log_prob.cpu(),
                'reward': float(reward),
                'done': bool(done),
                'task_id': task_id,
                'value': float(value.item())
            })

            # Remember where this task left off for the next time it is sampled
            self.last_obs[task_id] = next_state if not (done or trunc) else env.reset()[0]

    def compute_gae(self, values, rewards, dones):
        """Compute Generalized Advantage Estimation (GAE)."""
        advantages = []
        last_advantage = 0
        next_value = 0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MultiTaskPPOConfig.gamma * next_value * (1 - dones[t]) - values[t]
            last_advantage = delta + MultiTaskPPOConfig.gamma * MultiTaskPPOConfig.gae_lambda * (1 - dones[t]) * last_advantage
            advantages.append(last_advantage)
            next_value = values[t]

        advantages = torch.tensor(advantages[::-1], dtype=torch.float32).to(MultiTaskPPOConfig.device)
        returns = advantages + torch.tensor(values, dtype=torch.float32).to(MultiTaskPPOConfig.device)
        # Normalize advantages for more stable updates
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8), returns

    def update_policy(self):
        """Run PPO optimization on the collected buffer."""
        if not self.buffer:
            return 0, 0

        # Convert the buffer into batched tensors
        batch = list(self.buffer)
        states = torch.tensor([x['state'] for x in batch],
                              dtype=torch.float32,
                              device=MultiTaskPPOConfig.device)
        actions = torch.FloatTensor(np.array([x['action'] for x in batch])).to(MultiTaskPPOConfig.device)
        old_log_probs = torch.cat([x['log_prob'] for x in batch]).to(MultiTaskPPOConfig.device)
        rewards = torch.FloatTensor([x['reward'] for x in batch]).to(MultiTaskPPOConfig.device)
        dones = torch.FloatTensor([x['done'] for x in batch]).to(MultiTaskPPOConfig.device)
        task_ids = torch.tensor([x['task_id'] for x in batch],
                                dtype=torch.long,  # must be long to index the task heads
                                device=MultiTaskPPOConfig.device)
        values = torch.FloatTensor([x['value'] for x in batch]).to(MultiTaskPPOConfig.device)

        # Compute GAE advantages and returns
        advantages, returns = self.compute_gae(values.cpu().numpy(),
                                               rewards.cpu().numpy(),
                                               dones.cpu().numpy())

        total_policy_loss = 0
        total_value_loss = 0
        num_updates = 0

        for _ in range(MultiTaskPPOConfig.ppo_epochs):
            # Shuffle the samples every epoch
            perm = torch.randperm(len(batch))
            for i in range(0, len(batch), MultiTaskPPOConfig.batch_size):
                idx = perm[i:i + MultiTaskPPOConfig.batch_size]
                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_returns = returns[idx]
                batch_advantages = advantages[idx]
                batch_task_ids = task_ids[idx]

                # Forward pass on the mini-batch under automatic mixed precision
                with autocast():
                    action_means, new_values = self.policy(batch_states, batch_task_ids)
                    dist = Normal(action_means, torch.ones_like(action_means))
                    new_log_probs = dist.log_prob(batch_actions)

                    # Importance-sampling ratio between the new and old policies
                    ratio = (new_log_probs - batch_old_log_probs).exp()

                    # Clipped surrogate policy loss
                    surr1 = ratio * batch_advantages.unsqueeze(-1)
                    surr2 = torch.clamp(ratio,
                                        1 - MultiTaskPPOConfig.clip_epsilon,
                                        1 + MultiTaskPPOConfig.clip_epsilon) * batch_advantages.unsqueeze(-1)
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value-function loss
                    value_loss = 0.5 * (new_values.squeeze(-1) - batch_returns).pow(2).mean()

                    loss = policy_loss + value_loss

                # Backward pass with gradient scaling, then clip and step per mini-batch
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MultiTaskPPOConfig.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                num_updates += 1

        num_updates = max(num_updates, 1)
        return total_policy_loss / num_updates, total_value_loss / num_updates

    def train(self):
        print(f"Starting training on device: {MultiTaskPPOConfig.device}")
        start_time = time.time()
        episode_rewards = {i: deque(maxlen=100) for i in range(MultiTaskPPOConfig.num_tasks)}

        for episode in range(MultiTaskPPOConfig.max_episodes):
            # Experience collection followed by a PPO update
            self.collect_experience(MultiTaskPPOConfig.max_steps)
            policy_loss, value_loss = self.update_policy()

            # Track the reward accumulated in the buffer for one randomly chosen task
            task_id = np.random.randint(MultiTaskPPOConfig.num_tasks)
            episode_reward = sum(x['reward'] for x in self.buffer if x['task_id'] == task_id)
            episode_rewards[task_id].append(episode_reward)

            # Periodic logging
            if (episode + 1) % 100 == 0:
                avg_rewards = {k: np.mean(v) if v else 0 for k, v in episode_rewards.items()}
                time_cost = time.time() - start_time
                print(f"Episode {episode+1:5d} | Time: {time_cost:6.1f}s")
                for task_id in range(MultiTaskPPOConfig.num_tasks):
                    task_name = MultiTaskPPOConfig.task_names[task_id]
                    print(f"  {task_name:25s} | Avg Reward: {avg_rewards[task_id]:7.2f}")
                print(f"  Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f}\n")
                start_time = time.time()

if __name__ == "__main__":
    trainer = MultiTaskPPOTrainer()
    print(f"State dim: {trainer.state_dim}, Action dim: {trainer.action_dim}")
    trainer.train()

IV. Key Code Walkthrough

  1. Shared policy network

    • SharedPolicy consists of a shared trunk and task-specific heads.

    • task_heads and value_heads produce per-task action means and value estimates, respectively.

  2. Multi-task sampling mechanism

    • A task is sampled at random for every collection step.

    • The environment instance is switched dynamically via env = self.envs[task_id].

  3. Joint gradient update

    • Policy and value losses are computed over samples from all tasks in the buffer.

    • task_id indexing routes each sample to the parameters of the matching task head (see the usage sketch below).

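As a quick, runnable illustration of the head routing described in item 3, the sketch below feeds a mixed-task batch through SharedPolicy on the CPU. The dimensions match the training output in the next section, and it assumes the classes from Section III are already defined:

import torch

policy = SharedPolicy(state_dim=39, action_dim=4)

# A batch of 6 states belonging to three different tasks
states = torch.randn(6, 39)
task_ids = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.long)

# Each sample passes through the shared trunk and its own task head
action_means, values = policy(states, task_ids)
print(action_means.shape)  # torch.Size([6, 4])
print(values.shape)        # torch.Size([6, 1])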

V. Sample Training Output

State dim: 39, Action dim: 4
Starting training on device: cuda
/workspace/e23.py:184: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)
  states = torch.tensor(
/workspace/e23.py:204: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast():
Episode   100 | Time:  931.2s
  reach-v2-goal-observable  | Avg Reward:  226.83
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.31
  Policy Loss: 0.0386 | Value Loss: 13.2587

Episode   200 | Time:  935.3s
  reach-v2-goal-observable  | Avg Reward:  227.12
  push-v2-goal-observable   | Avg Reward:    8.83
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0434 | Value Loss: 14.9413

Episode   300 | Time:  939.4s
  reach-v2-goal-observable  | Avg Reward:  226.78
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0429 | Value Loss: 13.9076

Episode   400 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.74
  push-v2-goal-observable   | Avg Reward:    8.84
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0378 | Value Loss: 14.7157

Episode   500 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.45
  push-v2-goal-observable   | Avg Reward:    8.81
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0381 | Value Loss: 11.7940

Episode   600 | Time:  928.5s
  reach-v2-goal-observable  | Avg Reward:  225.39
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0462 | Value Loss: 14.5566

Episode   700 | Time:  926.6s
  reach-v2-goal-observable  | Avg Reward:  226.37
  push-v2-goal-observable   | Avg Reward:    8.65
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0394 | Value Loss: 15.5556

Episode   800 | Time:  943.8s
  reach-v2-goal-observable  | Avg Reward:  224.72
  push-v2-goal-observable   | Avg Reward:    8.64
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0361 | Value Loss: 16.0126

Episode   900 | Time:  937.2s
  reach-v2-goal-observable  | Avg Reward:  224.15
  push-v2-goal-observable   | Avg Reward:    8.72
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0417 | Value Loss: 14.1907

Episode  1000 | Time:  940.7s
  reach-v2-goal-observable  | Avg Reward:  223.77
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0399 | Value Loss: 16.0540

Episode  1100 | Time:  937.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0409 | Value Loss: 15.5525

Episode  1200 | Time:  933.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0388 | Value Loss: 17.4549

Episode  1300 | Time:  942.1s
  reach-v2-goal-observable  | Avg Reward:  224.35
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0447 | Value Loss: 14.6700

Episode  1400 | Time:  966.6s
  reach-v2-goal-observable  | Avg Reward:  224.27
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0434 | Value Loss: 13.3487

Episode  1500 | Time:  943.0s
  reach-v2-goal-observable  | Avg Reward:  223.03
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0438 | Value Loss: 14.7557

Episode  1600 | Time:  929.1s
  reach-v2-goal-observable  | Avg Reward:  224.01
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 12.2506

Episode  1700 | Time:  937.9s
  reach-v2-goal-observable  | Avg Reward:  222.88
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 11.8954

Episode  1800 | Time:  930.1s
  reach-v2-goal-observable  | Avg Reward:  224.42
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0437 | Value Loss: 13.6396

Episode  1900 | Time:  927.0s
  reach-v2-goal-observable  | Avg Reward:  224.66
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0360 | Value Loss: 14.3216

Episode  2000 | Time:  934.3s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.63
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0475 | Value Loss: 14.0712

VI. Summary and Extensions

This article implemented a core paradigm of multi-task reinforcement learning, PPO with a shared policy, and showed how knowledge sharing across tasks is wired into the training loop. Readers may want to explore the following extensions:

  1. Dynamic task weighting: adaptively adjust the loss weights according to task difficulty (a hypothetical helper is sketched after the snippet):

    # Add task weights inside update_policy()
    task_weights = calculate_task_difficulty()
    loss = sum([weight * loss_i for weight, loss_i in zip(task_weights, losses)])
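    The helper calculate_task_difficulty() is not defined in this article; one hypothetical implementation weights each task inversely to its recent average reward (using the episode_rewards deques maintained in train()), so harder tasks receive larger loss weights:

    import numpy as np

    def calculate_task_difficulty(episode_rewards, num_tasks):
        """Hypothetical helper: larger weight for tasks with lower recent average reward."""
        avg = np.array([np.mean(episode_rewards[i]) if len(episode_rewards[i]) > 0 else 0.0
                        for i in range(num_tasks)])
        # Shift scores to be positive, invert, and normalize so the weights sum to 1
        score = 1.0 / (avg - avg.min() + 1.0)
        return (score / score.sum()).tolist()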

  2. Hierarchical reinforcement learning: introduce a high-level policy that schedules tasks:

    class MetaController(nn.Module):
        def __init__(self, state_dim, num_tasks):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, num_tasks)
            )

        def forward(self, state):
            return self.net(state)
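    One possible (hypothetical) way to use such a controller is to sample the next task from its output distribution at the start of each episode; the sketch below assumes torch and the class above are in scope, and reuses the 39-dimensional state and 3 tasks from this article:

    meta = MetaController(state_dim=39, num_tasks=3)
    state = torch.randn(1, 39)

    # Sample a task id from the controller's categorical distribution over tasks
    task_logits = meta(state)
    task_id = torch.distributions.Categorical(logits=task_logits).sample().item()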

  3. Curriculum learning: progress from simple tasks to harder ones:

    def schedule_task(episode):
        if episode < 1000:
            return 'reach-v2-goal-observable'
        elif episode < 2000:
            return 'push-v2-goal-observable'
        else:
            return 'pick-place-v2-goal-observable'

In the next article we will explore hierarchical reinforcement learning (HRL) and implement the Option-Critic algorithm!


Notes

1. Install dependencies:

pip install metaworld gymnasium torch

2. Meta-World issues:

If the stable release causes problems, try installing the latest version from GitHub:

pip install git+https://github.com/rlworkgroup/metaworld.git@master
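After installation, a quick sanity check (using the same import path as the code above) confirms that the goal-observable environment registry is available:

from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

# The registry maps task names to environment constructors
print(len(ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE))
print('reach-v2-goal-observable' in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE)  # expected: True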

