训练模型Train和测试模型Test最终都是为了应用Inference。本篇将教你如何加载已经保存的.pth模型文件并用一张外部图片来检验它的分类能力。1. 验证流程的三大核心步骤准备测试环境包括加载模型、处理单张图片。图像预处理输入的图片必须经过与训练时完全相同的缩放Resize和归一化ToTensor。模型推理将图片送入模型并解析输出结果。2. 代码实战验证模型的分类结果文件演示了如何加载一个在 CIFAR-10 上训练好的模型并识别一张“狗”或“飞机”的图片。Pythonimport torch import torchvision from PIL import Image from torch import nn # 1. 读取外部图片 image_path dog.png # 或者是你本地的图片路径 img Image.open(image_path) # 如果是 RGBA 格式带透明度需转为 RGB img img.convert(RGB) # 2. 图像预处理必须与训练时保持一致 transform torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor() ]) img transform(img) # 增加 Batch 维度从 [3, 32, 32] 变为 [1, 3, 32, 32] img torch.reshape(img, (1, 3, 32, 32)) # 3. 加载训练好的模型 # 注意如果模型是用方式一保存的加载时需要能访问到网络定义类 model torch.load(tudui_29.pth, map_locationtorch.device(cpu)) # 4. 进入验证模式 model.eval() with torch.no_grad(): output model(img) # 5. 解析输出 # output 是一个长度为 10 的向量值最大的位置即为预测类别 print(output) predict_idx output.argmax(1).item() # CIFAR-10 的类别映射固定顺序 classes [airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck] print(f预测结果是{classes[predict_idx]})3. 关键细节分析为什么需要img.convert(RGB)PNG 图片通常包含 4 个通道RGBA其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用convert(RGB)可以保证无论输入什么格式的图片都能适配模型。map_location的妙用如果你的模型是在 GPU 上训练保存的但现在你想在只有 CPU 的电脑上运行验证加载时必须加上map_locationtorch.device(cpu)否则会报错。argmax(1)的逻辑模型输出的是图片属于这 10 个类别的“概率得分”。通过argmax(1)我们能直接提取出得分最高的那个位置的索引例如 5 代表狗。4. 总结从模型到应用分析完这个文件我们就完成了从数据采集到实战部署的全流程训练集/测试集-训练出高准确率模型。保存模型- 持久化存储。单张验证-将模型应用到真实场景中。 学习小结当你能成功地输入一张从网上下载的图片并让模型正确报出“cat”或“dog”时你就真正完成了深度学习的闭环。
117_PyTorch 实战:利用训练好的模型进行单张图片验证
发布时间:2026/5/27 7:23:07
训练模型Train和测试模型Test最终都是为了应用Inference。本篇将教你如何加载已经保存的.pth模型文件并用一张外部图片来检验它的分类能力。1. 验证流程的三大核心步骤准备测试环境包括加载模型、处理单张图片。图像预处理输入的图片必须经过与训练时完全相同的缩放Resize和归一化ToTensor。模型推理将图片送入模型并解析输出结果。2. 代码实战验证模型的分类结果文件演示了如何加载一个在 CIFAR-10 上训练好的模型并识别一张“狗”或“飞机”的图片。Pythonimport torch import torchvision from PIL import Image from torch import nn # 1. 读取外部图片 image_path dog.png # 或者是你本地的图片路径 img Image.open(image_path) # 如果是 RGBA 格式带透明度需转为 RGB img img.convert(RGB) # 2. 图像预处理必须与训练时保持一致 transform torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor() ]) img transform(img) # 增加 Batch 维度从 [3, 32, 32] 变为 [1, 3, 32, 32] img torch.reshape(img, (1, 3, 32, 32)) # 3. 加载训练好的模型 # 注意如果模型是用方式一保存的加载时需要能访问到网络定义类 model torch.load(tudui_29.pth, map_locationtorch.device(cpu)) # 4. 进入验证模式 model.eval() with torch.no_grad(): output model(img) # 5. 解析输出 # output 是一个长度为 10 的向量值最大的位置即为预测类别 print(output) predict_idx output.argmax(1).item() # CIFAR-10 的类别映射固定顺序 classes [airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck] print(f预测结果是{classes[predict_idx]})3. 关键细节分析为什么需要img.convert(RGB)PNG 图片通常包含 4 个通道RGBA其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用convert(RGB)可以保证无论输入什么格式的图片都能适配模型。map_location的妙用如果你的模型是在 GPU 上训练保存的但现在你想在只有 CPU 的电脑上运行验证加载时必须加上map_locationtorch.device(cpu)否则会报错。argmax(1)的逻辑模型输出的是图片属于这 10 个类别的“概率得分”。通过argmax(1)我们能直接提取出得分最高的那个位置的索引例如 5 代表狗。4. 总结从模型到应用分析完这个文件我们就完成了从数据采集到实战部署的全流程训练集/测试集-训练出高准确率模型。保存模型- 持久化存储。单张验证-将模型应用到真实场景中。 学习小结当你能成功地输入一张从网上下载的图片并让模型正确报出“cat”或“dog”时你就真正完成了深度学习的闭环。