RNN的理解

对于RNN的理解

import torch
import torch.nn as nn
import torch.nn.functional as F# 手动实现一个简单的RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()# 定义权重矩阵和偏置项self.hidden_size = hidden_sizeself.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))  # 输入到隐藏层的权重#

#注:
input_size = 4
hidden_size = 3
W_xh = torch.randn(input_size, hidden_size)
生成的 W_xh 会是一个形状为 (4, 3) 的张量,可能是这样的(数字是随机生成的):
tensor([[ 0.2973, -1.1254, 0.7172],
[ 0.0983, 0.2856, -0.4586],
[-0.0105, 0.2317, 0.2716],
[ 1.0431, -1.3894, -0.1525]])
这个张量有 4 行 3 列。

        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))  # 隐藏层到隐藏层的权重self.b_h = nn.Parameter(torch.zeros(hidden_size))  # 隐藏层偏置self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))  # 隐藏层到输出层的权重self.b_y = nn.Parameter(torch.zeros(output_size))  # 输出层偏置def forward(self, x):# 初始化隐藏状态为0h_t = torch.zeros(x.size(0), self.hidden_size)  # 初始隐藏状态 [[5]]

注:
x 是输入数据,形状是 (3, 5, 4),其中:
3 是批量大小(batch_size),即我们一次性输入网络的样本数是 3。
5 是序列长度(seq_len),每个样本有 5 个时间步。
4 是每个时间步的输入特征数量。
self.hidden_size 假设是 6,表示隐藏层的维度是 6。
x.size(0) 获取输入张量 x 的第一个维度的大小,也就是批量大小 3。
torch.zeros(3, 6) 会创建一个形状为 (3, 6) 的张量,表示有 3 个样本,每个样本有 6 个隐藏状态神经元(即隐状态的维度是 6)。所有的元素都初始化为 0。

    # 遍历时间步,逐个处理输入序列for t in range(x.size(1)):  # x.size(1) 是序列长度x_t = x[:, t, :]  # 获取当前时间步的输入 (batch_size, input_size)

`
注:x = torch.tensor([[[0.1, 0.2, 0.3, 0.4], # 第 0 时间步的输入 (第一个样本)
[0.5, 0.6, 0.7, 0.8], # 第 1 时间步的输入 (第一个样本)
[0.9, 1.0, 1.1, 1.2]], # 第 2 时间步的输入 (第一个样本)
[[1.3, 1.4, 1.5, 1.6], # 第 0 时间步的输入 (第二个样本)
[1.7, 1.8, 1.9, 2.0], # 第 1 时间步的输入 (第二个样本)
[2.1, 2.2, 2.3, 2.4]]]) # 第 2 时间步的输入 (第二个样本)

第一次循环 t=0:
x_t = x[:, 0, :]
x[:, 0, :] 会提取出所有样本在第 0 时间步的输入:

第一个样本在第 0 时间步的输入是 [0.1, 0.2, 0.3, 0.4]。

第二个样本在第 0 时间步的输入是 [1.3, 1.4, 1.5, 1.6]。

因此,x_t 的值是:

tensor([[0.1, 0.2, 0.3, 0.4],
[1.3, 1.4, 1.5, 1.6]])

        # 更新隐藏状态:h_t = tanh(W_xh * x_t + W_hh * h_t + b_h)h_t = torch.tanh(x_t @ self.W_xh + h_t @ self.W_hh + self.b_h)  # [[4]]

`
注:1. 计算 x_t @ W_xh
x_t @ W_xh 是输入 x_t 和权重矩阵 W_xh 的矩阵乘法。我们有 2 个样本,每个样本有 3 个输入特征,权重矩阵 W_xh 的形状是 (3, 4),所以乘法的结果是一个形状为 (2, 4) 的张量,即每个样本的隐藏状态更新的部分。

对于第一个样本:

[0.5, 0.6, 0.7] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
我们可以计算它的结果:

= [0.5 * 0.1 + 0.6 * 0.3 + 0.7 * 0.7,
0.5 * 0.2 + 0.6 * 0.5 + 0.7 * (-0.1),
0.5 * -0.1 + 0.6 * 0.2 + 0.7 * 0.3,
0.5 * 0.4 + 0.6 * (-0.2) + 0.7 * 0.5]

= [0.05 + 0.18 + 0.49,
0.1 + 0.3 - 0.07,
-0.05 + 0.12 + 0.21,
0.2 - 0.12 + 0.35]

= [0.72, 0.33, 0.28, 0.43]
对于第二个样本:

[1.0, 1.2, 1.3] @ [[0.1, 0.2, -0.1, 0.4],
[0.3, 0.5, 0.2, -0.2],
[0.7, -0.1, 0.3, 0.5]]
计算结果:

= [1.0 * 0.1 + 1.2 * 0.3 + 1.3 * 0.7,
1.0 * 0.2 + 1.2 * 0.5 + 1.3 * (-0.1),
1.0 * -0.1 + 1.2 * 0.2 + 1.3 * 0.3,
1.0 * 0.4 + 1.2 * (-0.2) + 1.3 * 0.5]

= [0.1 + 0.36 + 0.91,
0.2 + 0.6 - 0.13,
-0.1 + 0.24 + 0.39,
0.4 - 0.24 + 0.65]

= [1.37, 0.67, 0.53, 0.81]
因此,x_t @ W_xh 的结果是:

tensor([[0.72, 0.33, 0.28, 0.43],
[1.37, 0.67, 0.53, 0.81]])

x_t @ self.W_xh:
x_t 是当前时间步的输入,形状是 (batch_size, input_size)。
self.W_xh 是输入到隐藏层的权重矩阵,形状是 (input_size, hidden_size)。
h_t @ self.W_hh:
h_t 是前一时间步的隐藏状态,形状是 (batch_size, hidden_size)。
self.W_hh 是隐藏层到隐藏层的权重矩阵,形状是 (hidden_size, hidden_size)。

    # 最后一个时间步的隐藏状态通过全连接层得到输出y_t = h_t @ self.W_hy + self.b_y  # 输出层return y_t

超参数设置

input_size = 10 # 输入特征维度
hidden_size = 20 # 隐藏层维度
output_size = 5 # 输出类别数
seq_length = 5 # 序列长度
batch_size = 3 # 批量大小

实例化模型

model = RNN(input_size, hidden_size, output_size)

打印模型结构

print(model)

创建随机输入数据 (batch_size, seq_length, input_size)

x = torch.randn(batch_size, seq_length, input_size)

前向传播

output = model(x)
print(“Output shape:”, output.shape) # 输出形状应为 (batch_size, output_size)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/48087.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 网络基础(二) (传输协议层:UDP、TCP)

目录 一、传输层的意义 二、端口号 1、五元组标识一个通信 2、端口号范围划分 3、知名端口号(Well-Know Port Number) (1)查看端口号 4、绑定端口号数目问题 5、pidof & netstat 命令 (1)ne…

守护进程编程

一、守护进程 守护进程的含义: 守护进程是计算机中的一种特殊后台服务进程,通常在系统启动时自动运行,独立于用户终端,长期驻留在系统中执行特定任务。守护进程是操作系统服务可靠性的关键,确保核心功能持续可用而不受…

JUC复习及面试题学习

资源来自沉默王二、小林coding、竹子爱熊猫、代码随想录 一、JUC 1、进程与线程 进程是对运行程序的封装,是系统进行资源调度和分配的最小单位。 线程是进程的子任务,是CPU调度分配的基本单位 不同的进程之间很难数据共享,同进程下的不同线…

python-图片分割

图片分割是图像处理中的一个重要任务,它的目标是将图像划分为多个区域或者对象,例如分割出物体、前景背景或特定的部分。在 Python 中,常用的图片分割方法包括传统的图像处理技术(例如阈值分割、区域生长等)和深度学习…

STM32嵌入式

一、创建工程项目 1、进入软件首页 2、新建项目,【file】->【new project】 3、选择需要的芯片 4、系统内核部分设置 ① 选择晶振(使用外部的高速晶振) ② 选择debug形式(SW类型) 5、时钟设置 6、选择自己需要的引脚设置&a…

7.QT-常用控件-QWidget|font|toolTip|focusPolicy|styleSheet(C++)

font API说明font()获取当前widget的字体信息.返回QFont对象.setFont(const QFont& font)设置当前widget的字体信息. 属性说明family字体家族.⽐如"楷体",“宋体”,"微软雅⿊"等.pointSize字体⼤⼩weight字体粗细.以数值⽅式表⽰粗细程度取值范围为[…

蓝桥杯之前缀和

一维前缀 解题思路 看到“区间之和”问题,直接想到“前缀和” 前缀和的核心公式: sum[i]sum[i−1]a[i] 利用前缀和求区间和 [l,r] 的公式: 区间和sum[r]−sum[l−1] 解题步骤模板 输入数组: 读取数组长度 n 和查询次数 m。 读…

⭐ Unity 使用Odin Inspector增强编辑器的功能:UIManager脚本实例

先看一下测试效果: 在Unity开发中,Odin Inspector已经成为了一个非常受欢迎的工具,它通过增强编辑器的功能,使得开发者在工作中更加高效,尤其是在处理复杂数据和自定义编辑器方面。今天,我们将通过一个简…

JBoss + WildFly 本地开发环境完全指南

JBoss WildFly 本地开发环境完全指南 本篇笔记主要实现在本地通过 docker 创建 JBoss 和 WildFly 服务器这一功能,基于红帽的禁制 EAP 版本的重新分发,所以我这里没办法放 JBoss EAP 的 zip 文件。WildFly 是免费开源的版本,可以在红帽官网找…

IDEA使用jclasslib Bytecode Viewer查看jvm字节码

学习jvm的时候,想查看字节码和局部变量表,可以使用idea安装jclasslib Bytecode View插件查看。 (1)安装工具: 安装完成后需要重启idea. (2)准备一段代码,编译运行 package com.te…

STM32控制DRV8825驱动42BYGH34步进电机

最近想玩一下人工智能,然后买了个步进电机想玩一下,刚到了一脸懵逼,发现驱动器20多块,有点超预算,然后整了个驱动板,方便自己画线路板,经过各种搜索,终于转起来了,记录一…

第十四节:实战场景-何实现全局状态管理?

React.createElement调用示例 Babel插件对JSX的转换逻辑 React 全局状态管理实战与 JSX 转换原理深度解析 一、React 全局状态管理实现方案 1. Context API useReducer 方案(轻量级首选) // 创建全局 Context 对象 const GlobalContext createConte…

QT网络拓扑图绘制实验

前言 在网络通讯中,我qt常用的是TCP或者UDP协议,就比方说TCP吧,一台服务器有时可能会和多台客户端相连接,我之前都是处理单链接情况,最近研究图结构的时候,突然就想到了这个问题。那么如何解决这个问题呢&…

【深度学习—李宏毅教程笔记】各式各样的 Attention

目录 一、普通 Self-Attention 的痛点 二、对 Self-Attention 的优化方式 1、Local Attention / Truncated Attention 2、Stride Attention 3、Global Attention 4、知名的 Self-Attention 的变形的应用 (1)Longformer (2&#xff09…

OceanBases数据库单机社区版保姆级安装

目录 背景 简介 安装 OceanBase 下载地址 上传解压安装包 ​编辑 执行安装命令 ​编辑 应用环境配置 执行以下命令,快速部署 OceanBase 数据库(仅用于简单使用,不应用于生产)。 查看一下数据库状态 ​编辑连接数据库 用户创建 使用工具Navi…

Linux守护进程

一、相关概念 QQ邮箱关于三种协议的解释:SMTP/IMAP服务 1.SMTP协议 SMTP(​​Simple Mail Transfer Protocol​​,简单邮件传输协议)是一种用于发送电子邮件的互联网标准。它在TCP/IP协议族中,通常使用25端口进行通…

Java【网络原理】(4)HTTP协议

目录 1.前言 2.正文 2.1自定义协议 2.2HTTP协议 2.2.1抓包工具 2.2.2请求响应格式 2.2.2.1URL 2.2.2.2urlencode 2.2.3认识方法 2.2.3.1GET与POST 2.2.3.2PUT与DELETE 2.2.4请求头关键属性 3.小结 1.前言 哈喽大家好啊,今天来继续给大家带来Java中网络…

【版本控制】idea中使用git

大家好,我是jstart千语。接下来继续对git的内容进行讲解。也是在开发中最常使用,最重要的部分,在idea中操作git。目录在右侧哦。 如果需要git命令的详解: 【版本控制】git命令使用大全-CSDN博客 一、配置git 要先关闭项目&#xf…

【中间件】redis使用

一、redis介绍 redis是一种NoSQL类型的数据库,其数据存储在内存中,因此其数据查询效率很高,很快。常被用作数据缓存,分布式锁 等。SpringBoot集成了Redis,可查看开发文档Redis开发文档。Redis有自己的可视化工具Redis …

一文粗通 Celery 分布式任务队列

目录 简介什么是 CeleryCelery 的基本组成Celery 的应用场景快速开始 设置热重载开发脚本基本任务管理绑定任务本身设置任务的执行超时时间允许任务重试自定义任务名称实现任务优先级 高级任务管理任务延迟执行指定时间执行任务超时自动取消任务优先级重试任务 任务链与工作流简…