Cliff Walking in Practice: A Step-by-Step Guide to Implementing Sarsa and Q-Learning in Python (Complete Code Included)
Published: 2026/5/24 6:17:16
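Cliff Walking in Practice: An In-Depth Analysis of Sarsa and Q-Learning in Python

Introduction: When Reinforcement Learning Meets the Cliff

Imagine standing at the start of a 4×12 grid world. The bottom-right corner holds a tempting goal, but a deadly cliff lies in between. Every step costs energy (reward -1). Falling off the cliff is severely punished (reward -100) and sends you back to the start. This is the classic Cliff Walking environment, the "Hello World" of reinforcement learning, and a perfect illustration of the eternal trade-off between exploration and exploitation.

Unlike an ordinary maze, Cliff Walking is interesting because two routes compete:

- Safe path: a longer route that keeps well above the cliff. Along the top row it takes 17 moves, for a total reward of -17.
- Optimal path: the shortest route, hugging the cliff edge. It takes 13 moves, for a total reward of -13.

In this article we implement two classic algorithms from scratch in Python: the conservative Sarsa and the risk-taking Q-Learning. Through complete code examples and visual analysis, you will come to understand:

- The core implementation logic of tabular reinforcement learning
- The fundamental difference in how the two algorithms choose their policies
- How to design an efficient training loop
- How key hyperparameters affect each algorithm's performance

```python
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import spaces
```

1. Building the Environment: Your Own Cliff Walking

1.1 A Custom Gym Environment

We start by subclassing `gym.Env`. The key elements are:

- the grid shape
- the start and goal cells
- the list of cliff cells
- the action and observation spaces

The class also defines `reset` and `_get_state`, because the rest of the code relies on them:

```python
class CliffWalkingEnv(gym.Env):
    def __init__(self):
        self.shape = (4, 12)                          # 4 rows x 12 columns
        self.start_pos = (3, 0)                       # bottom-left corner
        self.goal_pos = (3, 11)                       # bottom-right corner
        self.cliff = [(3, i) for i in range(1, 11)]   # bottom row between start and goal
        self.action_space = spaces.Discrete(4)        # up: 0, right: 1, down: 2, left: 3
        self.observation_space = spaces.Discrete(self.shape[0] * self.shape[1])
        self.reset()

    def reset(self):
        """Put the agent back on the start cell and return its state index."""
        self.pos = self.start_pos
        return self._get_state()

    def _get_state(self):
        """Flatten (row, col) into a single integer in [0, 48)."""
        return self.pos[0] * self.shape[1] + self.pos[1]
```

1.2 State Transition Logic

The core `step` method, continuing the class body, handles movement and reward calculation:

```python
    def step(self, action):
        x, y = self.pos
        # Move, clipping at the grid border
        if action == 0:
            x = max(x - 1, 0)
        elif action == 1:
            y = min(y + 1, self.shape[1] - 1)
        elif action == 2:
            x = min(x + 1, self.shape[0] - 1)
        elif action == 3:
            y = max(y - 1, 0)
        self.pos = (x, y)

        done = False
        reward = -1          # every move costs -1, including the final move into the goal
        # Cliff and termination checks
        if self.pos in self.cliff:
            reward = -100    # fell off: large penalty, back to start, episode continues
            self.reset()
        elif self.pos == self.goal_pos:
            done = True
        return self._get_state(), reward, done, {}
```

This class follows the classic gym interface. `reset()` returns just the state, and `step()` returns a 4-tuple `(state, reward, done, info)`; the training loops below rely on both. Built-in environments in gym 0.26+ and Gymnasium return `(obs, info)` and a 5-tuple instead. Because this class defines its own methods, it works as written with either version.

1.3 Rendering

A `render` method lets us watch the agent move:

```python
    def render(self):
        grid = [["." for _ in range(self.shape[1])] for _ in range(self.shape[0])]
        grid[self.goal_pos[0]][self.goal_pos[1]] = "G"   # goal
        for c in self.cliff:
            grid[c[0]][c[1]] = "X"                       # cliff
        grid[self.pos[0]][self.pos[1]] = "A"             # agent
        for row in grid:
            print(" ".join(row))
        print()
```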
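The sketch below sanity-checks the path returns quoted in the introduction. It scripts both routes through a minimal re-implementation of the dynamics above.

- It is a stand-alone, gym-free stand-in. The names `step`, `path_return`, `START`, `GOAL` and `CLIFF` are illustrative and not part of any library.
- It assumes, as in `CliffWalkingEnv`, that every move (including the final one into the goal) costs -1.
- It also assumes that entering the cliff costs -100 and teleports the agent back to the start.

```python
# Minimal, gym-free re-implementation of the Cliff Walking dynamics (assumption:
# -1 per move including the goal step; cliff = -100 and back to start, episode continues).
SHAPE = (4, 12)
START, GOAL = (3, 0), (3, 11)
CLIFF = {(3, c) for c in range(1, 11)}
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3


def step(pos, action):
    """Return (next_pos, reward, done) for one move from pos."""
    x, y = pos
    if action == UP:
        x = max(x - 1, 0)
    elif action == RIGHT:
        y = min(y + 1, SHAPE[1] - 1)
    elif action == DOWN:
        x = min(x + 1, SHAPE[0] - 1)
    elif action == LEFT:
        y = max(y - 1, 0)
    if (x, y) in CLIFF:
        return START, -100, False  # fall: big penalty, teleport to start
    return (x, y), -1, (x, y) == GOAL


def path_return(actions):
    """Follow a fixed action sequence from START; return (total reward, final cell)."""
    pos, total = START, 0
    for a in actions:
        pos, reward, done = step(pos, a)
        total += reward
        if done:
            break
    return total, pos


# Optimal route: up once, right 11 times along the cliff edge, down once (13 moves).
optimal = [UP] + [RIGHT] * 11 + [DOWN]
# Safest route: climb to the top row, cross it, then come back down (17 moves).
safe = [UP] * 3 + [RIGHT] * 11 + [DOWN] * 3
# Stepping right from the start lands directly in the cliff.
fall = [RIGHT]

print(path_return(optimal))  # (-13, (3, 11))
print(path_return(safe))     # (-17, (3, 11))
print(path_return(fall))     # (-100, (3, 0))
```

Falling off the cliff does not end the episode. The `fall` route therefore shows the -100 penalty together with the reset to the start cell (3, 0).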
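2. Implementing Sarsa: The Safety-First Conservative

2.1 Core Principle

Sarsa is an on-policy algorithm. Its update rule is:

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma\, Q(s',a') - Q(s,a)\right]$$

Here $a'$ is the action the current (ε-greedy) policy selects in state $s'$. This embodies the consistency between acting and evaluating: the value being learned is that of the policy actually being followed, exploration included.

2.2 Python Implementation

We create a `SarsaAgent` class to encapsulate the core logic:

```python
class SarsaAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha      # learning rate
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration rate
        # One row per state, one column per action, initialized to zero
        self.Q = np.zeros((env.observation_space.n, env.action_space.n))

    def choose_action(self, state):
        # ε-greedy: explore with probability ε, otherwise act greedily
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return int(np.argmax(self.Q[state]))
```

2.3 The Training Loop

The full training loop shows Sarsa's online learning in action:

```python
def train(env, agent, episodes=500):
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        action = agent.choose_action(state)
        total_reward = 0
        done = False
        while not done:
            next_state, reward, done, _ = env.step(action)
            next_action = agent.choose_action(next_state)
            # Core Sarsa update: bootstrap from the action actually chosen next
            # (no bootstrapping past a terminal state)
            td_target = reward + (0 if done else agent.gamma * agent.Q[next_state][next_action])
            td_error = td_target - agent.Q[state][action]
            agent.Q[state][action] += agent.alpha * td_error
            state, action = next_state, next_action
            total_reward += reward
        rewards.append(total_reward)
    return rewards
```

2.4 Visualizing the Results

After training, we observe:

- Converged path: the agent learns to take the safe path through the upper part of the grid.
- Learning curve: it levels off after roughly 200 episodes.
- Policy character: it stays away from the cliff edge, even though the path is longer.

```python
def moving_average(x, window=10):
    return np.convolve(x, np.ones(window) / window, mode="valid")  # sliding-window mean
env = CliffWalkingEnv()
rewards = train(env, SarsaAgent(env), episodes=500)
plt.plot(moving_average(rewards, window=10))
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Sarsa Learning Curve")
```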
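Here is a compact sketch that runs the whole Sarsa loop end to end without installing gym. It uses the same assumptions as `CliffWalkingEnv`: -1 per move, -100 and back to the start for the cliff, and the classic 4-tuple `step`.

- `MiniCliffEnv`, `epsilon_greedy`, `train_sarsa` and `greedy_path` are hypothetical helper names.
- A seeded `numpy.random.Generator` replaces `np.random`, so runs are reproducible.
- The `max_steps` cap is an added safeguard and is not part of the loop above.

```python
import numpy as np


class MiniCliffEnv:
    """Hypothetical gym-free stand-in: 4x12 grid, -1 per move, cliff = -100 and back to start."""
    n_states, n_actions = 48, 4
    MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left

    def reset(self):
        self.pos = (3, 0)
        return 36  # flattened index of the start cell (row 3, column 0)

    def step(self, action):
        dx, dy = self.MOVES[action]
        x = min(max(self.pos[0] + dx, 0), 3)
        y = min(max(self.pos[1] + dy, 0), 11)
        if x == 3 and 1 <= y <= 10:  # stepped into the cliff
            return self.reset(), -100, False, {}
        self.pos = (x, y)
        return x * 12 + y, -1, (x, y) == (3, 11), {}


def epsilon_greedy(Q, state, epsilon, rng):
    """Random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(Q.shape[1]))
    return int(np.argmax(Q[state]))


def train_sarsa(env, episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1,
                max_steps=10_000, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((env.n_states, env.n_actions))
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        action = epsilon_greedy(Q, state, epsilon, rng)
        total, done, steps = 0, False, 0
        while not done and steps < max_steps:
            next_state, reward, done, _ = env.step(action)
            next_action = epsilon_greedy(Q, next_state, epsilon, rng)
            # On-policy target: the action that will actually be taken next
            target = reward if done else reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha * (target - Q[state, action])
            state, action = next_state, next_action
            total, steps = total + reward, steps + 1
        rewards.append(total)
    return Q, rewards


def greedy_path(env, Q, max_steps=100):
    """Follow argmax Q from the start; return the visited state indices."""
    state = env.reset()
    path = [state]
    for _ in range(max_steps):
        state, _, done, _ = env.step(int(np.argmax(Q[state])))
        path.append(state)
        if done:
            break
    return path


env = MiniCliffEnv()
Q, rewards = train_sarsa(env)
print("mean reward over the last 50 episodes:", np.mean(rewards[-50:]))
print("greedy path (row, col):", [divmod(s, 12) for s in greedy_path(env, Q)])
```

Two properties hold regardless of the seed:

- No episode can score better than -13, because the shortest route needs 13 moves.
- Every Q-value stays at or below zero, because the table starts at zero and every reward is negative.

Which route `greedy_path` prints depends on the run. Going by the observations in 2.4, Sarsa is expected to settle on a path away from the cliff edge.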
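3. Implementing Q-Learning: The Risk-Taker Chasing the Optimum

3.1 The Core Difference

Q-Learning is an off-policy algorithm. Its update rule is:

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

The key difference is that the target uses the maximum Q-value in $s'$ rather than the value of the action actually taken next. Q-Learning therefore learns the value of the greedy policy, regardless of which exploratory actions the agent really performs.

3.2 Python Implementation

Only the update logic changes, so we subclass the agent:

```python
class QLearningAgent(SarsaAgent):
    def update(self, state, action, reward, next_state, done):
        if done:
            td_target = reward  # no future value after a terminal state
        else:
            # Bootstrap from the best next action, not the one actually taken
            td_target = reward + self.gamma * np.max(self.Q[next_state])
        td_error = td_target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error
```

3.3 Adjusting the Training Loop

The training loop changes to reflect the off-policy nature. The action is chosen inside the loop, and the update no longer needs the next action:

```python
def qlearn_train(env, agent, episodes=500):
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
        rewards.append(total_reward)
    return rewards
```

3.4 Comparing the Results

Compared with Sarsa, Q-Learning shows:

- Path choice: it learns the risky shortest path along the cliff edge.
- Convergence speed: it usually finds the high-reward policy faster than Sarsa.
- Risk exposure: it walks right next to the cliff while still exploring ε-greedily, so it occasionally falls off and its training rewards fluctuate. Its per-episode rewards during training are therefore typically lower than Sarsa's, even though its greedy policy is the optimal one.

```python
# Compare the moving-average rewards of the two algorithms
env = CliffWalkingEnv()
sarsa_rewards = train(env, SarsaAgent(env), episodes=500)
qlearn_rewards = qlearn_train(env, QLearningAgent(env), episodes=500)
plt.plot(moving_average(sarsa_rewards, 10), label="Sarsa")
plt.plot(moving_average(qlearn_rewards, 10), label="Q-Learning")
plt.legend()
plt.show()
```

4. Deep Dive: Algorithm Differences and Engineering Practice

4.1 The Essence of the Policy Difference

A value heatmap makes the policy difference between the two algorithms easy to see:

| State feature | Sarsa | Q-Learning |
|---|---|---|
| States near the cliff | Lower value; avoids getting close | Higher value; willing to take the risk |
| States on the safe path | Evenly spread value gradient | Steep value gradient |

```python
def plot_values(agent, title):
    # Greedy state value V(s) = max_a Q(s, a), laid out on the 4x12 grid
    values = np.max(agent.Q, axis=1).reshape(4, 12)
    plt.imshow(values, cmap="hot")
    plt.colorbar()
    plt.title(title)
```

4.2 Hyperparameter Tuning Guide

Experimental observations on how the key parameters affect each algorithm:

| Parameter | Typical range | Effect on Sarsa | Effect on Q-Learning |
|---|---|---|---|
| Learning rate α | 0.01–0.5 | Too large causes oscillation | Can be set larger (e.g. 0.5) |
| Exploration rate ε | 0.05–0.3 | Needs sustained exploration | Can decay over time |
| Discount factor γ | 0.8–0.99 | Higher values (0.95) work better | Moderate values (0.9) work best |

4.3 Practical Tips and Pitfalls

Tips:

- For Q-Learning, decay ε over time, e.g. `epsilon = max(0.01, epsilon * 0.995)` after each episode.
- Initialize Q-values optimistically to encourage exploration. Here an initial value of 0 is already optimistic, because every true value in this task is negative.
- Monitor the size of Q-value changes to judge convergence.

Common pitfalls:

- A fixed ε makes Q-Learning keep falling off the cliff.
- Too large an α prevents Sarsa from converging stably.
- The greedy policy's real performance is never tested periodically.

```python
# ε-decay example: a stepwise schedule ε0/1, ε0/2, ε0/3, ... that drops every 100 episodes
class DecayEpsilonAgent(QLearningAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initial_epsilon = self.epsilon

    def choose_action(self, state, episode=None):
        # Pass the current episode index to refresh ε; omit it to keep the current ε
        if episode is not None:
            self.epsilon = self.initial_epsilon / (1 + episode // 100)
        return super().choose_action(state)
```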
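The ε-decay tip and the "test the greedy policy" pitfall from 4.3 combine naturally in one Q-Learning loop. The sketch below:

- uses a hypothetical gym-free `MiniCliffEnv` with the same reward assumptions as `CliffWalkingEnv`;
- decays ε multiplicatively with `epsilon = max(0.01, epsilon * 0.995)` after every episode;
- every 100 episodes, plays one purely greedy episode (ε = 0) to measure what the learned policy actually achieves.

`evaluate_greedy` and `train_q_learning` are illustrative names, not functions defined earlier in this article.

```python
import numpy as np


class MiniCliffEnv:
    """Hypothetical gym-free stand-in: 4x12 grid, -1 per move, cliff = -100 and back to start."""
    n_states, n_actions = 48, 4
    MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left

    def reset(self):
        self.pos = (3, 0)
        return 36  # flattened index of the start cell (row 3, column 0)

    def step(self, action):
        dx, dy = self.MOVES[action]
        x = min(max(self.pos[0] + dx, 0), 3)
        y = min(max(self.pos[1] + dy, 0), 11)
        if x == 3 and 1 <= y <= 10:  # stepped into the cliff
            return self.reset(), -100, False, {}
        self.pos = (x, y)
        return x * 12 + y, -1, (x, y) == (3, 11), {}


def evaluate_greedy(env, Q, max_steps=100):
    """Play one episode with ε = 0 and return its total reward (capped at max_steps moves)."""
    state, total = env.reset(), 0
    for _ in range(max_steps):
        state, reward, done, _ = env.step(int(np.argmax(Q[state])))
        total += reward
        if done:
            break
    return total


def train_q_learning(env, episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1,
                     eps_min=0.01, eps_decay=0.995, eval_every=100,
                     max_steps=10_000, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((env.n_states, env.n_actions))
    evals = []
    for episode in range(1, episodes + 1):
        state, done, steps = env.reset(), False, 0
        while not done and steps < max_steps:
            # ε-greedy behaviour policy
            if rng.random() < epsilon:
                action = int(rng.integers(env.n_actions))
            else:
                action = int(np.argmax(Q[state]))
            next_state, reward, done, _ = env.step(action)
            # Off-policy target: best next action, nothing after a terminal state
            target = reward if done else reward + gamma * np.max(Q[next_state])
            Q[state, action] += alpha * (target - Q[state, action])
            state, steps = next_state, steps + 1
        epsilon = max(eps_min, epsilon * eps_decay)  # multiplicative decay with a floor
        if episode % eval_every == 0:
            evals.append(evaluate_greedy(env, Q))  # periodic check of the greedy policy
    return Q, epsilon, evals


env = MiniCliffEnv()
Q, epsilon_final, evals = train_q_learning(env)
print("final epsilon:", epsilon_final)
print("greedy returns every 100 episodes:", evals)
```

With ε0 = 0.1, the schedule reaches the 0.01 floor after about 460 episodes, since 0.1 × 0.995^n < 0.01 once n > ln 0.1 / ln 0.995 ≈ 459.4. The last evaluations therefore reflect an almost fully greedy learner.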
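5. Advanced Extensions: Algorithm Variants and Performance Improvements

5.1 Expected Sarsa

Expected Sarsa is a variant that combines the strengths of Sarsa and Q-Learning. Its target uses neither the sampled $Q(s',a')$ nor the maximum. Instead it uses the expectation under the ε-greedy policy, $\sum_{a'} \pi(a' \mid s')\, Q(s',a')$. Because its `update` has the same signature as `QLearningAgent.update`, it plugs into `qlearn_train` unchanged:

```python
class ExpectedSarsaAgent(QLearningAgent):
    def update(self, state, action, reward, next_state, done):
        if done:
            td_target = reward
        else:
            n_actions = self.env.action_space.n
            # ε-greedy action probabilities: ε/|A| for every action ...
            policy = np.ones(n_actions) * self.epsilon / n_actions
            # ... plus the remaining 1 - ε on the greedy action
            policy[np.argmax(self.Q[next_state])] += 1 - self.epsilon
            td_target = reward + self.gamma * np.sum(policy * self.Q[next_state])
        self.Q[state][action] += self.alpha * (td_target - self.Q[state][action])
```

5.2 Experience Replay

Experience replay improves sample efficiency:

```python
import collections
import random

class ReplayBuffer:
    def __init__(self, capacity=1000):
        self.buffer = collections.deque(maxlen=capacity)  # oldest transitions drop out first

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

# In the training loop (replayed transitions come from older policies,
# so pair this with an off-policy update such as QLearningAgent.update)
buffer = ReplayBuffer()
for episode in range(episodes):
    state, done = env.reset(), False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        buffer.add((state, action, reward, next_state, done))
        state = next_state
        # Once enough transitions are stored, sample a batch and update on it
        if len(buffer) >= 32:
            for exp in buffer.sample(32):
                agent.update(*exp)
```

5.3 Multi-Step TD Learning

Multi-step TD is a compromise between Monte Carlo and TD methods:

```python
class NStepSarsaAgent(SarsaAgent):
    def __init__(self, env, n_steps=3, **kwargs):
        super().__init__(env, **kwargs)
        self.n_steps = n_steps
        self.trajectory = []  # pending (state, action, reward) tuples, oldest first

    def update(self, state, action, reward, next_state, next_action, done):
        """Call once per step from a Sarsa-style loop, passing the action actually chosen next."""
        self.trajectory.append((state, action, reward))
        if done:
            # Episode over: flush every pending pair with its truncated return
            while self.trajectory:
                self._update_oldest(bootstrap=0.0)
        elif len(self.trajectory) == self.n_steps:
            self._update_oldest(bootstrap=self.Q[next_state][next_action])

    def _update_oldest(self, bootstrap):
        # G = r_1 + γ r_2 + ... + γ^(k-1) r_k + γ^k * bootstrap, for k pending rewards
        k = len(self.trajectory)
        G = sum(r * self.gamma ** i for i, (_, _, r) in enumerate(self.trajectory))
        G += self.gamma ** k * bootstrap
        s, a, _ = self.trajectory.pop(0)
        self.Q[s][a] += self.alpha * (G - self.Q[s][a])
```

Conclusion: From Cliff Walking to Real-World Applications

This seemingly simple grid world has taken us through the core ideas of tabular reinforcement learning. With suitable adaptation, these algorithms can be applied in real projects to:

- Robot path planning
- Game AI strategy optimization
- Resource scheduling and decision systems

Remember that there is no one-size-fits-all algorithm. Sarsa's conservative robustness and Q-Learning's aggressive efficiency each have their own use cases. Real expertise lies in choosing the right tool for the problem at hand, and in finding the best parameter combination through systematic experimentation.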