# 强化学习实战：从 DQN 到 PPO 的完整指南

## 1. 引言

强化学习（RL）是让智能体通过与环境交互来学习最优策略的方法。从 Atari 游戏到机器人控制，从 RLHF 到代码生成，RL 的应用越来越广泛。本文将从基础概念到前沿算法，系统讲解强化学习。

### 核心概念

```
智能体 (Agent) ←→ 环境 (Environment)
  - 状态 (State): s_t
  - 动作 (Action): a_t
  - 奖励 (Reward): r_t
  - 策略 (Policy): π(a|s)
  - 价值函数 (Value): V(s) 或 Q(s,a)

目标：最大化累积奖励 E[Σ γ^t · r_t]
```

## 2. DQN（Deep Q-Network）

### 2.1 原理

Q-learning 更新规则：

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

DQN 改进：

1. 用神经网络近似 Q 函数
2. 经验回放（Experience Replay）
3. 目标网络（Target Network）

### 2.2 DQN 实现

```python
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random


class DQN(nn.Module):
    """DQN 网络"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, x):
        return self.net(x)


class ReplayBuffer:
    """经验回放缓冲区"""

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
        )

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    """DQN 智能体"""

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        # Q 网络和目标网络
        self.q_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer()

    def select_action(self, state):
        """ε-贪婪策略"""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_net(state_t)
            return q_values.argmax(dim=1).item()

    def train_step(self, batch_size=64):
        """训练一步"""
        if len(self.buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = self.buffer.sample(batch_size)

        # 当前 Q 值
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # 目标 Q 值
        with torch.no_grad():
            next_q = self.target_net(next_states).max(dim=1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)

        # 更新
        loss = nn.MSELoss()(q_values, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target(self):
        """更新目标网络"""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        """衰减探索率"""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
```

### 2.3 训练循环

```python
import gymnasium as gym


def train_dqn(env_name="CartPole-v1", episodes=500):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = DQNAgent(state_dim, action_dim)
    rewards_history = []

    for episode in range(episodes):
        state, _ = env.reset()
        total_reward = 0

        while True:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.buffer.push(state, action, reward, next_state, float(done))
            agent.train_step()

            state = next_state
            total_reward += reward

            if done:
                break

        agent.decay_epsilon()

        if episode % 10 == 0:
            agent.update_target()

        rewards_history.append(total_reward)

        if episode % 50 == 0:
            avg = np.mean(rewards_history[-50:])
            print(f"Episode {episode}, Avg Reward: {avg:.1f}, ε: {agent.epsilon:.3f}")

    return agent
```

## 3. 策略梯度（Policy Gradient）

### 3.1 REINFORCE 算法

```python
class PolicyNetwork(nn.Module):
    """策略网络"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.net(x)


def train_reinforce(env_name="CartPole-v1", episodes=1000, lr=1e-3, gamma=0.99):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    policy = PolicyNetwork(state_dim, action_dim)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    for episode in range(episodes):
        states, actions, rewards = [], [], []
        state, _ = env.reset()

        while True:
            state_t = torch.FloatTensor(state).unsqueeze(0)
            probs = policy(state_t)
            action = torch.multinomial(probs, 1).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            state = next_state

            if done:
                break

        # 计算折扣回报
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

        returns = torch.FloatTensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # 策略梯度更新
        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.LongTensor(actions)

        probs = policy(states_t)
        log_probs = torch.log(probs.gather(1, actions_t.unsqueeze(1)).squeeze(1))

        loss = -(log_probs * returns).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 50 == 0:
            print(f"Episode {episode}, Total Reward: {sum(rewards):.1f}")
```

## 4. PPO（Proximal Policy Optimization）

### 4.1 核心思想

PPO 通过裁剪目标函数限制策略更新幅度：

$$L^{CLIP} = \mathbb{E}\left[\min\left(r_t(\theta)\,A_t,\ \mathrm{clip}\left(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon\right) A_t\right)\right]$$

其中：

- $r_t(\theta) = \pi_\theta(a|s) / \pi_{\theta_{old}}(a|s)$（概率比）
- $A_t$：优势函数估计
- $\epsilon$：裁剪参数（通常 0.1-0.2）

### 4.2 PPO 实现

```python
class PPO:
    """PPO 算法"""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.epochs = epochs

        # Actor-Critic 网络
        self.actor = PolicyNetwork(state_dim, action_dim)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

        self.optimizer = torch.optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters()),
            lr=lr,
        )

    def compute_gae(self, rewards, values, dones):
        """计算 GAE（广义优势估计）"""
        advantages = []
        gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages.insert(0, gae)

        returns = [adv + val for adv, val in zip(advantages, values)]
        return torch.FloatTensor(advantages), torch.FloatTensor(returns)

    def update(self, trajectories):
        """PPO 更新"""
        states = torch.FloatTensor(np.array(trajectories["states"]))
        actions = torch.LongTensor(trajectories["actions"])
        old_log_probs = torch.FloatTensor(trajectories["log_probs"])
        rewards = trajectories["rewards"]
        values = trajectories["values"]
        dones = trajectories["dones"]

        advantages, returns = self.compute_gae(rewards, values, dones)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.epochs):
            # 当前策略
            probs = self.actor(states)
            log_probs = torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1))
            current_values = self.critic(states).squeeze(1)

            # 概率比
            ratio = torch.exp(log_probs - old_log_probs)

            # 裁剪目标
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # Critic 损失
            critic_loss = nn.MSELoss()(current_values, returns)

            # 熵正则化
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()

            # 总损失
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(self.actor.parameters()) + list(self.critic.parameters()),
                0.5,
            )
            self.optimizer.step()
```

## 5. 强化学习在 LLM 中的应用

```python
# PPO 在 RLHF 中的应用（简化版）
from trl import PPOTrainer, PPOConfig

config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=64,
    ppo_epochs=4,
    kl_penalty="kl",
    init_kl_coef=0.2,
)

# 奖励来自人类偏好训练的奖励模型
# 策略是 LLM 本身
# 状态是 prompt，动作是生成的 token
```

## 6. 算法对比

| 算法 | 类型 | 动作空间 | 样本效率 | 稳定性 |
| --- | --- | --- | --- | --- |
| DQN | 值函数 | 离散 | 高 | 中 |
| REINFORCE | 策略梯度 | 连续/离散 | 低 | 低 |
| A2C | Actor-Critic | 连续/离散 | 中 | 中 |
| PPO | Actor-Critic | 连续/离散 | 高 | 高 |
| SAC | Actor-Critic | 连续 | 高 | 高 |

## 7. 总结

强化学习的核心算法：

- DQN：值函数方法，适合离散动作空间
- 策略梯度：直接优化策略，适合连续动作
- PPO：当前最流行的通用 RL 算法，稳定高效
- RLHF：PPO 在 LLM 对齐中的成功应用
强化学习实战:从 DQN 到 PPO 的完整指南
发布时间:2026/6/21 3:36:34
# 强化学习实战：从 DQN 到 PPO 的完整指南

## 1. 引言

强化学习（RL）是让智能体通过与环境交互来学习最优策略的方法。从 Atari 游戏到机器人控制，从 RLHF 到代码生成，RL 的应用越来越广泛。本文将从基础概念到前沿算法，系统讲解强化学习。

### 核心概念

```
智能体 (Agent) ←→ 环境 (Environment)
  - 状态 (State): s_t
  - 动作 (Action): a_t
  - 奖励 (Reward): r_t
  - 策略 (Policy): π(a|s)
  - 价值函数 (Value): V(s) 或 Q(s,a)

目标：最大化累积奖励 E[Σ γ^t · r_t]
```

## 2. DQN（Deep Q-Network）

### 2.1 原理

Q-learning 更新规则：

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

DQN 改进：

1. 用神经网络近似 Q 函数
2. 经验回放（Experience Replay）
3. 目标网络（Target Network）

### 2.2 DQN 实现

```python
import torch
import torch.nn as nn
import numpy as np
from collections import deque
import random


class DQN(nn.Module):
    """DQN 网络"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, x):
        return self.net(x)


class ReplayBuffer:
    """经验回放缓冲区"""

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
        )

    def __len__(self):
        return len(self.buffer)


class DQNAgent:
    """DQN 智能体"""

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        # Q 网络和目标网络
        self.q_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer()

    def select_action(self, state):
        """ε-贪婪策略"""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_net(state_t)
            return q_values.argmax(dim=1).item()

    def train_step(self, batch_size=64):
        """训练一步"""
        if len(self.buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = self.buffer.sample(batch_size)

        # 当前 Q 值
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # 目标 Q 值
        with torch.no_grad():
            next_q = self.target_net(next_states).max(dim=1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)

        # 更新
        loss = nn.MSELoss()(q_values, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target(self):
        """更新目标网络"""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        """衰减探索率"""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
```

### 2.3 训练循环

```python
import gymnasium as gym


def train_dqn(env_name="CartPole-v1", episodes=500):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = DQNAgent(state_dim, action_dim)
    rewards_history = []

    for episode in range(episodes):
        state, _ = env.reset()
        total_reward = 0

        while True:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.buffer.push(state, action, reward, next_state, float(done))
            agent.train_step()

            state = next_state
            total_reward += reward

            if done:
                break

        agent.decay_epsilon()

        if episode % 10 == 0:
            agent.update_target()

        rewards_history.append(total_reward)

        if episode % 50 == 0:
            avg = np.mean(rewards_history[-50:])
            print(f"Episode {episode}, Avg Reward: {avg:.1f}, ε: {agent.epsilon:.3f}")

    return agent
```

## 3. 策略梯度（Policy Gradient）

### 3.1 REINFORCE 算法

```python
class PolicyNetwork(nn.Module):
    """策略网络"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.net(x)


def train_reinforce(env_name="CartPole-v1", episodes=1000, lr=1e-3, gamma=0.99):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    policy = PolicyNetwork(state_dim, action_dim)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    for episode in range(episodes):
        states, actions, rewards = [], [], []
        state, _ = env.reset()

        while True:
            state_t = torch.FloatTensor(state).unsqueeze(0)
            probs = policy(state_t)
            action = torch.multinomial(probs, 1).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            state = next_state

            if done:
                break

        # 计算折扣回报
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)

        returns = torch.FloatTensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # 策略梯度更新
        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.LongTensor(actions)

        probs = policy(states_t)
        log_probs = torch.log(probs.gather(1, actions_t.unsqueeze(1)).squeeze(1))

        loss = -(log_probs * returns).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 50 == 0:
            print(f"Episode {episode}, Total Reward: {sum(rewards):.1f}")
```

## 4. PPO（Proximal Policy Optimization）

### 4.1 核心思想

PPO 通过裁剪目标函数限制策略更新幅度：

$$L^{CLIP} = \mathbb{E}\left[\min\left(r_t(\theta)\,A_t,\ \mathrm{clip}\left(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon\right) A_t\right)\right]$$

其中：

- $r_t(\theta) = \pi_\theta(a|s) / \pi_{\theta_{old}}(a|s)$（概率比）
- $A_t$：优势函数估计
- $\epsilon$：裁剪参数（通常 0.1-0.2）

### 4.2 PPO 实现

```python
class PPO:
    """PPO 算法"""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.epochs = epochs

        # Actor-Critic 网络
        self.actor = PolicyNetwork(state_dim, action_dim)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

        self.optimizer = torch.optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters()),
            lr=lr,
        )

    def compute_gae(self, rewards, values, dones):
        """计算 GAE（广义优势估计）"""
        advantages = []
        gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages.insert(0, gae)

        returns = [adv + val for adv, val in zip(advantages, values)]
        return torch.FloatTensor(advantages), torch.FloatTensor(returns)

    def update(self, trajectories):
        """PPO 更新"""
        states = torch.FloatTensor(np.array(trajectories["states"]))
        actions = torch.LongTensor(trajectories["actions"])
        old_log_probs = torch.FloatTensor(trajectories["log_probs"])
        rewards = trajectories["rewards"]
        values = trajectories["values"]
        dones = trajectories["dones"]

        advantages, returns = self.compute_gae(rewards, values, dones)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.epochs):
            # 当前策略
            probs = self.actor(states)
            log_probs = torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1))
            current_values = self.critic(states).squeeze(1)

            # 概率比
            ratio = torch.exp(log_probs - old_log_probs)

            # 裁剪目标
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # Critic 损失
            critic_loss = nn.MSELoss()(current_values, returns)

            # 熵正则化
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()

            # 总损失
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(self.actor.parameters()) + list(self.critic.parameters()),
                0.5,
            )
            self.optimizer.step()
```

## 5. 强化学习在 LLM 中的应用

```python
# PPO 在 RLHF 中的应用（简化版）
from trl import PPOTrainer, PPOConfig

config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=64,
    ppo_epochs=4,
    kl_penalty="kl",
    init_kl_coef=0.2,
)

# 奖励来自人类偏好训练的奖励模型
# 策略是 LLM 本身
# 状态是 prompt，动作是生成的 token
```

## 6. 算法对比

| 算法 | 类型 | 动作空间 | 样本效率 | 稳定性 |
| --- | --- | --- | --- | --- |
| DQN | 值函数 | 离散 | 高 | 中 |
| REINFORCE | 策略梯度 | 连续/离散 | 低 | 低 |
| A2C | Actor-Critic | 连续/离散 | 中 | 中 |
| PPO | Actor-Critic | 连续/离散 | 高 | 高 |
| SAC | Actor-Critic | 连续 | 高 | 高 |

## 7. 总结

强化学习的核心算法：

- DQN：值函数方法，适合离散动作空间
- 策略梯度：直接优化策略，适合连续动作
- PPO：当前最流行的通用 RL 算法，稳定高效
- RLHF：PPO 在 LLM 对齐中的成功应用