强化学习算法实战:一个例子实现sarsa、dqn、ddqn、qac、a2c、trpo、ppo

简介

在学习强化学习算法:sarsa、dqn、ddqn、qac、a2c、trpo、ppo时,由于有大量数据公式的推导,觉得十分晦涩,且听过就忘记了。
但是当把算法应用于实战时,代码的实现要比数据推导要直观很多。
接下来通过不同的算法实现gym中的CartPole-v1游戏。

游戏介绍

CartPole(推车倒立摆) 是强化学习中经典的基准测试任务,因为其直观可视、方便调试、状态和动作空间小等特性,常用于入门教学和算法验证。它的目标是训练一个智能体(agent)通过左右移动小车,使车顶的杆子尽可能长时间保持竖直不倒。
在这里插入图片描述

  • 环境:小车(cart)可以在水平轨道上左右移动,顶部通过关节连接一根自由摆动的杆子(pole)。
  • 目标:通过左右移动小车,使杆子的倾斜角度不超出阈值(±12°或±15°),同时小车不超出轨道范围(如轨道长度的±2.4单位)。简单理解为,就是杆子不会倒下里,小车不会飞出屏幕。
  • 状态:状态空间包含4个连续变量,分别是小车位置(x),小车速度(v),杆子角度(θ),杆子角速度(ω)
  • 动作:动作空间只有2个离线动作,分别是0(向左移动)或1(向右移动)
    奖励机制:每成功保持杆子不倒+1分,目前是让奖励最大化,即杆子永远不倒

DQN&DDQN&Dueling DQN

DQN

构建价值网络

DQN需要一个价值网络和目标网络,用来评估执行的动作得到的动作价值。这两个网络使用的是同一个网络结构,通常使用神经网络来实现。
输入的维度是状态空间的4个变量,分别是小车位置(x),小车速度(v),杆子角度(θ),杆子角速度(ω)。输出的维度是动作空间的维度,分别表示向左、向右移动的动作价值。
代码如下:

class QNetWork(nn.Module):def __init__(self, state_dim, action_dim):super().__init__()self.fc = nn.Sequential(nn.Linear(state_dim, 64),nn.ReLU(),nn.Linear(64, 64),nn.ReLU(),nn.Linear(64, action_dim))def forward(self,x):return self.fc(x)

构建DQN智能体

DQN智能体具备如下功能:

  1. 选择动作
  2. 存储经验
  3. 训练上面的神经网络:目的是返回给定状态下尽可能接近真实的动作价值
  4. 模型保存
  5. 评估价值网络
初始化参数

初始化q网络、目标网络、设置优化器、设置经验回放使用的缓存大小、训练的batch_size
dqn的折扣因子、探索率、更新目标网络的频率、智能体步数计数器初始化、记录最佳网络分数值参数初始化、评估轮数设置

    def __init__(self,state_dim, action_dim):# 神经网络网络相关self.q_net = QNetWork(state_dim,action_dim)self.target_net = QNetWork(state_dim,action_dim)self.target_net.load_state_dict(self.q_net.state_dict())self.optimizer = optim.Adam(self.q_net.parameters(), lr=1e-3)self.replay_buffer = deque(maxlen=10000)self.batch_size = 64# DQN相关self.gamma = 0.99self.epsilon = 0.1self.update_target_freq = 100self.step_count=0self.best_avg_reward = 0self.eval_episodes=5
动作选择

在DQN算法中,会采用epsilon参数来增加智能体选择动作的探索性,因此,动作选择的代码逻辑为:

  • 以epsilon的概率随机选择一个动作
  • 以1-epsilon的概率来选择价值网络返回结果中动作价值更大的动作
    def choose_action(self, state):if np.random.rand() < self.epsilon:return np.random.randint(0,2) //随机选择动作else:state_tensor = torch.FloatTensor(state)q_values = self.q_net(state_tensor) //调用价值网络选择动作return q_values.cpu().detach().numpy().argmax()
存储经验

DQN中,使用经验回放,那么我们需要预留缓冲区来存放历史的轨迹数据,方便后续取出用来训练网络
存储当前的状态、选择的动作、获得的奖励、下一个状态、游戏是否结束(即杆子是不是倒下)

    def store_experience(self,state, action, reward, next_state, done):self.replay_buffer.append((state, action, reward, next_state, done))
训练神经网络⭐️

最终要的部分,不同的算法,基本上也就是这一部分存在差异。
代码流程如下:

  1. 缓冲区中采样一个batch的数据
  2. 计算当前q值和目标q值(不同的dqn算法,主要是这两步计算的方式不同)
  3. 计算损失,除了后面ppo需要自己构造损失函数,其他的基本都是用MSELoss,也就是当前q值和目标q值的平方差,
  4. 梯度下降&更新网络,这两部都有现成的库来完成,基本上也是固定代码,
    def train(self):# 判断是否有足够经验用例用来学习if len(self.replay_buffer) < self.batch_size:return# 从缓冲区随机采样batch = random.sample(self.replay_buffer, self.batch_size)states, actions, rewards, next_states, dones = zip(*batch)states = torch.FloatTensor(np.array(states))actions = torch.LongTensor(actions)rewards = torch.FloatTensor(rewards)next_states = torch.FloatTensor(np.array(next_states))dones = torch.FloatTensor(dones)# 计算当前q值,输入当前状态至q网络获得所有q值,使用历史经验中选择的动作的q值作为当前q值current_q = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze()# 计算目标q值with torch.no_grad():# 使用目标网络计算下一个状态的所有q值,直接选择最大的q值作为下一状态的q值# 这里其实隐含着两个步骤:使用目标网络选择动作+计算下一状态q值next_q = self.target_net(next_states).max(1)[0]# 计算目标q值,1-dones表示当前状态执行动作后如果杆子倒了,那么为1-dones=0,否则为1target_q = rewards + self.gamma * next_q * (1 - dones)# 计算损失、梯度下降、网络更新loss = nn.MSELoss()(current_q,target_q)self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 计算步数,每隔一定步数更新目标网络self.step_count += 1if self.step_count % self.update_target_freq == 0:self.target_net.load_state_dict({k: v.clone() for k, v in self.q_net.state_dict().items()})
保存模型
    def save_model(self,path="./output/best_model_bak.pth"):torch.save(self.q_net.state_dict(),path)print(f"Model saved to {path}")
评估模型

这一步用来评估当前网络的好坏,我们评估得分更高的网络进行保存。
评估的逻辑如下:

  1. 从初始状态开始,使用我们训练的q网络选择动作控制杆子
  2. 累加轨迹下每个步骤获得的reward作为当前网络的得分
  3. 一直到杆子倒下(说明网络效果不够好)
  4. 或者分数足够高(说明网络可以控制杆子很长时间保持平衡,网络效果效果很好),再跳出循环
  5. 评估一定轮数后,取分数均值作为评估结果,分数越高越好

代码如下:

    def evaluate(self, env):# 入参是游戏环境,需要一个新的环境进行评估# 由于模型评估不需要随机探索,因此记录当前的epsilon,并设置智能体epsilon=0origin_epsilon = self.epsilonself.epsilon = 0total_rewards = []# self.eval_episodes表示评估的轮数for _ in range(self.eval_episodes):state = env.reset()[0]episode_reward = 0while True:# 使用智能体选择动作action = self.choose_action(state)# 与环境交互next_state, reward, done, _, _ = env.step(action)# 得到reward和下一状态episode_reward += rewardstate = next_state# 判断杆子是否倒下、分数是否足够高if done or episode_reward>2e4:breaktotal_rewards.append(episode_reward)# 网络需要回到评估前的状态继续训练,因此恢复epsilon的值self.epsilon = origin_epsilon# 返回平均分数return np.mean(total_rewards)

主流程

  1. 初始化游戏环境:游戏环境用于选择动作后交互,并获取reward和下一状态,初始化后获得初始状态
  2. 初始化智能体:构建智能体对象,设置epsilon=1,因为初始q网络效果不好,所以设置很大的epsilon让智能体自由探索环境,设置训练的episode数目
  3. 对于每个episode:
    • 对于episode中的每一步
      • 使用智能体选择动作
      • 与环境交互,并获得下一状态next_state,该动作的奖励值reward,杆子是否倒下done,存储本次经验
      • 智能体训练
    • 当一个episode结束后,更新一次epsilon参数,使epsilon慢慢衰减,这是因为随着训练,q网络效果变好,这时我们开始慢慢相信q网络给我们做出的选择
    • 每10个episode我们对q网络做一次评估,存储最佳的q网络

代码如下:

if __name__ == '__main__':env = gym.make('CartPole-v1')state_dim = env.observation_space.shape[0]action_dim = env.action_space.nagent = DQNAgent(state_dim, action_dim)config = {"episode": 600,"epsilon_start": 1.0,"epsilon_end": 0.01,"epsilon_decay": 0.995,}agent.epsilon = config["epsilon_start"]for episode in range(config["episode"]):state = env.reset()[0]total_reward = 0while True:action = agent.choose_action(state)next_state, reward, done, _, _ = env.step(action)agent.store_experience(state, action, reward, next_state, done)agent.train()total_reward += rewardstate=next_stateif done or total_reward > 2e4:breakagent.epsilon = max(config["epsilon_end"],agent.epsilon*config["epsilon_decay"])if episode % 10 == 0:eval_env = gym.make('CartPole-v1')avg_reward = agent.evaluate(eval_env)eval_env.close()if avg_reward > agent.best_avg_reward:agent.best_avg_reward = avg_rewardagent.save_model(path=f"output/best_model.pth")print(f"new best model saved with average reward: {avg_reward}")print(f"Episode: {episode},Train Reward: {total_reward},Best Eval Avg Reward: {agent.best_avg_reward}")

DDQN

DDQN与DQN的区别是:

  • DQN使用目标网络选择动作并计算q值
  • DDQN使用主网络(q_net)选择动作,使用目标网络计算q值

因此DDQN和DQN的代码区别仅有一行代码:

        # DQNcurrent_q = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze()with torch.no_grad():# 根据目标网络计算出来的所有Q值,选择了Q值最大的动作next_q = self.target_net(next_states).max(1)[0]target_q = rewards + self.gamma * next_q * (1 - dones)# DDQNcurrent_q = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze()with torch.no_grad():# 使用主网络选择下一状态要执行的动作next_actions = torch.LongTensor(torch.argmax(self.q_net(next_states),dim=1))# 使用目标网络计算执行这个动作后的q值next_q=self.target_net(next_states).gather(1,next_actions.unsqueeze(1)).squeeze()target_q = rewards + self.gamma * next_q * (1 - dones)

Dueling DQN

Dueling DQN与DQN的区别主要是神经网络结构不同

  1. DQN简单的神经网络,输入状态,输出不同动作的价值
  2. Dueling DQN的网络内有两个分支
    • 价值流:记录当前状态的状态价值,为标量
    • 优势流:表示该动作的优势,与动作价值的作用相同,都是评估动作的好坏,是一个动作维度的响亮

直接看代码,Dueling DQN的网络结构如下:

class QNetwork(nn.Module):def __init__(self, state_dim, action_dim):super(QNetwork, self).__init__()self.feature = nn.Sequential(nn.Linear(state_dim, 128),nn.ReLU())self.advantage = nn.Sequential(nn.Linear(128, 128),nn.ReLU(),nn.Linear(128, action_dim))self.value = nn.Sequential(nn.Linear(128, 128),nn.ReLU(),nn.Linear(128, 1))def forward(self, x):x = self.feature(x)# 优势流advantage = self.advantage(x)# 价值流value = self.value(x)# advantage - advantage.mean()消除优势函数基线影响return value + advantage - advantage.mean()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/49389.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《数据库原理》部分习题解析1

《数据库原理》部分习题解析1 1. 名词解释 &#xff08;1&#xff09;关系&#xff08;2&#xff09;属性&#xff08;3&#xff09;域&#xff08;4&#xff09;元组&#xff08;5&#xff09;码&#xff08;6&#xff09;分量&#xff08;7&#xff09;关系模式 &#xff0…

20250515通过以太网让VLC拉取视熙科技的机芯的rtsp视频流的步骤

20250515通过以太网让VLC拉取视熙科技的机芯的rtsp视频流的步骤 2025/5/15 20:26 缘起&#xff1a;荣品的PRO-RK3566适配视熙科技 的4800W的机芯。 1080p出图预览的时候没图了。 通过105的机芯出图确认 荣品的PRO-RK3566 的硬件正常。 然后要确认 视熙科技 的4800W的机芯是否出…

AI 赋能防艾宣传:从创意到实践,我的 IP 形象设计之旅

在数字技术飞速发展的今天&#xff0c;如何让严肃的健康传播变得更有温度、更具吸引力&#xff1f;作为一名参与防艾宣传实践的学生&#xff0c;我尝试通过 AI 工具构建专属 IP 形象&#xff0c;让防艾知识从 "被动接受" 转化为 "主动探索"。这篇文章将分享…

机器学习笔记2

5 TfidfVectorizer TF-IDF文本特征词的重要程度特征提取 (1) 算法 词频(Term Frequency, TF), 表示一个词在当前篇文章中的重要性 逆文档频率(Inverse Document Frequency, IDF), 反映了词在整个文档集合中的稀有程度 (2) API sklearn.feature_extraction.text.TfidfVector…

MYSQL 子查询

标量子查询 #标量子查询 #1.查询“研发部”的所有员工信息 # a.查询“研发部”部门id select id from dept where name研发部;#b.根据销售部部门id&#xff0c;查询员工信息 select * from emp where dept_id(select id from dept where name研发部);#2.查询在王金彪入职之后…

EasyRTC嵌入式音视频通话SDK驱动智能硬件音视频应用新发展

一、引言 在数字化浪潮下&#xff0c;智能硬件蓬勃发展&#xff0c;从智能家居到工业物联网&#xff0c;深刻改变人们的生活与工作。音视频通讯作为智能硬件交互与协同的核心&#xff0c;重要性不言而喻。但嵌入式设备硬件资源受限&#xff0c;传统音视频方案集成困难。EasyRT…

c/c++中程序内存区域的划分

c/c程序内存分配的几个区域&#xff1a; 1.栈区&#xff1a;在执行函数时&#xff0c;函数内局部变量的存储单元都可以在栈上创建&#xff0c;函数执行结束时这些存储单元自动被释放&#xff0c;栈内存分配运算内置于处理器的指令集中&#xff0c;效率很高但是分配的内存容量有…

微信小程序 自定义图片分享-绘制数据图片以及信息文字

一 、需求 从数据库中读取头像&#xff0c;姓名电话等信息&#xff0c;当分享给女朋友时&#xff0c;每个信息不一样 二、实现方案 1、先将数据库中需要的头像姓名信息读取出来加载到data 数据项中 data:{firstName:, // 姓名img:, // 头像shareImage:,// 存储临时图片 } 2…

IPLOOK超轻量核心网,助力5G专网和MEC边缘快速落地

随着5G深入千行百业&#xff0c;行业客户对核心网的灵活性、可控性和部署效率提出了更高要求。IPLOOK面向数字化转型需求&#xff0c;推出了超轻量级核心网解决方案&#xff0c;具备体积小、资源占用少、部署灵活、易于维护等特性&#xff0c;广泛适用于专网、实验室、MEC边缘云…

搭建基于chrony+OpenSSL(NTS协议)多层级可信时间同步服务

1、时间同步服务的层级概念 在绝大多数IT工程师实际工作过程中&#xff0c;针对于局域网的时间同步&#xff0c;遇到最多的场景是根据实际的需求&#xff0c;搭建一个简单的NTP时间同步服务以时间对局域网中的服务器、网络设备、个人电脑等基础设施实现同步授时功能。虽然这样…

单片机-STM32部分:12、I2C

飞书文档https://x509p6c8to.feishu.cn/wiki/MsB7wLebki07eUkAZ1ec12W3nsh 一、简介 IIC协议&#xff0c;又称I2C协议&#xff0c;是由PHILP公司在80年代开发的两线式串行总线&#xff0c;用于连接微控制器及其外围设备&#xff0c;IIC属于半双工同步通信方式。 IIC是一种同步…

存算一体芯片对传统GPU架构的挑战:在GNN训练中的颠覆性实验

点击 “AladdinEdu&#xff0c;同学们用得起的【H卡】算力平台”&#xff0c;H卡级别算力&#xff0c;按量计费&#xff0c;灵活弹性&#xff0c;顶级配置&#xff0c;学生专属优惠。 一、冯诺依曼架构的"三座大山"与GNN算力困境 当前图神经网络&#xff08;GNN&…

AI大模型学习十八、利用Dify+deepseekR1 +本地部署Stable Diffusion搭建 AI 图片生成应用

一、说明 最近在学习Dify工作流的一些玩法&#xff0c;下面将介绍一下Dify Stable Diffusion实现文生图工作流的应用方法 Dify与Stable Diffusion的协同价值 Dify作为低代码AI开发平台的优势&#xff1a;可视化编排、API快速集成 Stable Diffusion的核心能力&#xff1a;高效…

初识Linux · IP分片

目录 前言&#xff1a; IP分片 分片vs不分片 如何分片 分片举例 三个字段 前言&#xff1a; 前文IP协议上和IP协议下我们已经把IP协议的报头的大多数字段介绍了&#xff0c;唯独有三个字段现在还有介绍&#xff0c;即16位标识&#xff0c;8位协议&#xff0c;13位片偏移…

FPGA: UltraScale+ bitslip实现(ISERDESE3)

收获 一晃五年~ 五年前那个夏夜&#xff0c;我对着泛蓝的屏幕敲下《给十年后的自己》&#xff0c;在2020年的疫情迷雾中编织着对未来的想象。此刻回望&#xff0c;第四届集创赛的参赛编号仍清晰如昨&#xff0c;而那个在家熬夜焊电路板的"不眠者"&#xff0c;现在…

端侧智能重构智能监控新路径 | 2025 高通边缘智能创新应用大赛第三场公开课来袭!

2025 高通边缘智能创新应用大赛初赛激战正酣&#xff0c;系列公开课持续输出硬核干货&#xff01; 5月20日晚8点&#xff0c;第三场重磅课程《端侧智能如何重构下一代智能监控》将准时开启&#xff0c;广翼智联高级产品市场经理伍理化将聚焦智能监控领域的技术变革与产业落地&…

【实战篇】低代码报表开发——平台运营日报表的开发实录

前言 myBuilder的推广有段时间了&#xff0c;想开发个报表看看平台运营的情况。采用myBuilder强大的报表、数据交换模块功能&#xff0c;直接开干。 1. 报表指标思考与概要设计 首先是报表模块的概要设计&#xff0c;先构思一下&#xff0c;我希望报表能查看新用户注册、活跃…

用MCP往ppt文件里插入系统架构图

文章目录 一、技术架构解析1. Markdown解析模块(markdown_to_hierarchy)2. 动态布局引擎(give_hierarchy_positions)3. PPTX生成模块(generate_pptx)二、核心技术亮点1. 自适应布局算法2. MCP服务集成三、工程实践建议1. 性能优化方向2. 样式扩展方案3. 部署实践四、应用…

ubuntu服务器版启动卡在start job is running for wait for...to be Configured

目录 前言 一、原因分析 二、解决方法 总结 前言 当 Ubuntu 服务器启动时&#xff0c;系统会显示类似 “start job is running for wait for Network to be Configured” 或 “start job is running for wait for Plymouth Boot Screen Service” 等提示信息&#xff0c;并且…

Midjourney 最佳创作思路与实战技巧深度解析【附提示词与学习资料包下载】

引言 在人工智能图像生成领域&#xff0c;Midjourney 凭借其强大的艺术表现力和灵活的创作模式&#xff0c;已成为设计师、艺术家和创意工作者的核心工具。作为 CSDN 博主 “小正太浩二”&#xff0c;我将结合多年实战经验&#xff0c;系统分享 Midjourney 的创作方法论&#x…