用PyTorch从零搭建U-Net手把手教你实现医学图像分割附完整代码与DRIVE数据集处理视网膜血管分割是医学影像分析中的经典任务它能帮助医生快速识别糖尿病视网膜病变等疾病。2015年提出的U-Net架构因其在小型医学数据集上的出色表现成为这一领域的标杆模型。本文将带您从零开始用PyTorch实现一个完整的U-Net解决方案包含数据处理、模型构建、训练优化等全流程代码。1. 环境准备与数据加载在开始编码前我们需要配置合适的开发环境。推荐使用Python 3.8和PyTorch 1.10版本这些组合经过验证具有最佳兼容性。以下是核心依赖的安装命令pip install torch torchvision numpy pillow matplotlib tqdmDRIVE数据集是视网膜血管分割的标准benchmark包含40张眼底图像20训练20测试每张图像都配有专家标注的血管掩码。数据集目录结构应如下DRIVE/ ├── training/ │ ├── images/ # 原始图像(.tif) │ ├── 1st_manual/ # 专家标注(.gif) │ └── mask/ # ROI区域(.gif) └── test/ ├── images/ ├── 1st_manual/ └── mask/数据加载的关键在于正确处理图像与掩码的对应关系。我们创建DriveDataset类继承PyTorch的Dataset核心逻辑包括def __getitem__(self, idx): img Image.open(self.img_list[idx]).convert(RGB) manual Image.open(self.manual[idx]).convert(L) manual np.array(manual) / 255 # 归一化到[0,1] roi_mask 255 - np.array(Image.open(self.roi_mask[idx]).convert(L)) mask np.clip(manual roi_mask, 0, 255) # 合并标注与ROI return self.transforms(img, Image.fromarray(mask))注意DRIVE数据集的掩码需要进行特殊处理将专家标注与ROI区域结合确保非关注区域不被计入损失计算。2. 数据增强策略医学影像数据有限恰当的数据增强能显著提升模型泛化能力。我们设计了一套针对视网膜图像的增强流水线trans [ T.RandomResize(282, 678), # 随机缩放(50%-120% of 565) T.RandomHorizontalFlip(0.5), # 水平翻转 T.RandomVerticalFlip(0.5), # 垂直翻转 T.RandomCrop(480), # 随机裁剪 T.ToTensor(), T.Normalize(mean[0.709,0.381,0.224], std[0.127,0.079,0.043]) ]关键参数说明增强类型参数设置医学影像适用性说明随机缩放min_size282, max_size678保持血管结构比例不变随机翻转概率0.5视网膜图像具有旋转对称性随机裁剪size480保留中心区域关键特征标准化数据集特定均值/标准差消除光照差异影响验证集只需进行最基本的转换eval_trans [ T.ToTensor(), T.Normalize(mean[0.709,0.381,0.224], std[0.127,0.079,0.043]) ]3. U-Net模型架构实现标准的U-Net由编码器下采样和解码器上采样组成中间通过跳跃连接融合多尺度特征。我们采用改进版设计class DoubleConv(nn.Sequential): def __init__(self, in_channels, out_channels): super().__init__( nn.Conv2d(in_channels, out_channels, 3, padding1, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue), nn.Conv2d(out_channels, out_channels, 3, padding1, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue) ) class Down(nn.Sequential): def __init__(self, in_channels, out_channels): super().__init__( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) class Up(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.up nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size2, stride2) self.conv DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 self.up(x1) # 处理尺寸不匹配问题 diffY x2.size()[2] - x1.size()[2] diffX x2.size()[3] - x1.size()[3] x1 F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2]) x torch.cat([x2, x1], dim1) return self.conv(x)模型改进亮点对称填充卷积所有卷积层设置padding1保持特征图尺寸不变避免原始U-Net中的裁剪操作批量归一化每个卷积后加入BN层加速训练并提升模型稳定性自适应上采样转置卷积后自动计算padding处理奇数尺寸输入完整的UNet类组织这些模块class UNet(nn.Module): def __init__(self, in_channels3, num_classes1, base_c64): super().__init__() self.in_conv DoubleConv(in_channels, base_c) self.down1 Down(base_c, base_c*2) self.down2 Down(base_c*2, base_c*4) self.down3 Down(base_c*4, base_c*8) self.down4 Down(base_c*8, base_c*16) self.up1 Up(base_c*16, base_c*8) self.up2 Up(base_c*8, base_c*4) self.up3 Up(base_c*4, base_c*2) self.up4 Up(base_c*2, base_c) self.out_conv nn.Conv2d(base_c, num_classes, 1) def forward(self, x): x1 self.in_conv(x) x2 self.down1(x1) x3 self.down2(x2) x4 self.down3(x3) x5 self.down4(x4) x self.up1(x5, x4) x self.up2(x, x3) x self.up3(x, x2) x self.up4(x, x1) return self.out_conv(x)4. 训练流程与技巧训练视网膜分割网络需要特别注意损失函数选择和评估指标。我们采用以下配置损失函数组合Dice Loss处理类别不平衡问题BCE Loss提供像素级梯度信号def criterion(inputs, target): bce_loss F.binary_cross_entropy_with_logits(inputs, target) dice_loss 1 - dice_coeff(torch.sigmoid(inputs), target) return bce_loss dice_loss优化器设置optimizer torch.optim.SGD( model.parameters(), lr0.01, momentum0.9, weight_decay1e-4 ) lr_scheduler create_lr_scheduler(optimizer, len(train_loader), epochs100)训练监控指标Dice系数衡量分割区域重叠度混淆矩阵计算精确率、召回率def evaluate(model, data_loader, device): model.eval() confmat ConfusionMatrix(num_classes2) dice 0 with torch.no_grad(): for image, target in data_loader: image, target image.to(device), target.to(device) output model(image) confmat.update(target.flatten(), output.argmax(1).flatten()) dice dice_coeff(torch.sigmoid(output), target) return confmat, dice / len(data_loader)实际训练中发现几个关键技巧使用混合精度训练可减少显存占用允许更大batch size渐进式学习率预热能避免初期梯度爆炸在验证集Dice系数不再提升时早停可防止过拟合5. 预测与结果可视化训练完成后我们可以加载最佳模型进行预测model UNet().to(device) model.load_state_dict(torch.load(best_model.pth)) model.eval() with torch.no_grad(): output model(test_image.unsqueeze(0).to(device)) pred torch.sigmoid(output).squeeze().cpu().numpy() binary_mask (pred 0.5).astype(np.uint8)可视化对比结果plt.figure(figsize(12,4)) plt.subplot(131); plt.imshow(test_image.permute(1,2,0)) plt.title(Original Image) plt.subplot(132); plt.imshow(gt_mask, cmapgray) plt.title(Ground Truth) plt.subplot(133); plt.imshow(binary_mask, cmapgray) plt.title(Prediction) plt.show()典型分割结果会显示模型能准确识别主要血管分支但在处理微小血管时可能出现断裂。这可以通过以下方式改进增加模型深度提升base_c到128使用注意力机制增强微小特征提取引入边缘感知损失函数完整项目应包含以下目录结构unet-retina/ ├── data/ # 数据集 ├── src/ │ ├── model.py # U-Net实现 │ ├── dataset.py # 数据加载 │ └── transforms.py # 数据增强 ├── train.py # 训练脚本 ├── predict.py # 预测脚本 └── utils/ # 辅助工具
用PyTorch从零搭建U-Net:手把手教你实现医学图像分割(附完整代码与DRIVE数据集处理)
发布时间:2026/6/3 15:33:25
用PyTorch从零搭建U-Net手把手教你实现医学图像分割附完整代码与DRIVE数据集处理视网膜血管分割是医学影像分析中的经典任务它能帮助医生快速识别糖尿病视网膜病变等疾病。2015年提出的U-Net架构因其在小型医学数据集上的出色表现成为这一领域的标杆模型。本文将带您从零开始用PyTorch实现一个完整的U-Net解决方案包含数据处理、模型构建、训练优化等全流程代码。1. 环境准备与数据加载在开始编码前我们需要配置合适的开发环境。推荐使用Python 3.8和PyTorch 1.10版本这些组合经过验证具有最佳兼容性。以下是核心依赖的安装命令pip install torch torchvision numpy pillow matplotlib tqdmDRIVE数据集是视网膜血管分割的标准benchmark包含40张眼底图像20训练20测试每张图像都配有专家标注的血管掩码。数据集目录结构应如下DRIVE/ ├── training/ │ ├── images/ # 原始图像(.tif) │ ├── 1st_manual/ # 专家标注(.gif) │ └── mask/ # ROI区域(.gif) └── test/ ├── images/ ├── 1st_manual/ └── mask/数据加载的关键在于正确处理图像与掩码的对应关系。我们创建DriveDataset类继承PyTorch的Dataset核心逻辑包括def __getitem__(self, idx): img Image.open(self.img_list[idx]).convert(RGB) manual Image.open(self.manual[idx]).convert(L) manual np.array(manual) / 255 # 归一化到[0,1] roi_mask 255 - np.array(Image.open(self.roi_mask[idx]).convert(L)) mask np.clip(manual roi_mask, 0, 255) # 合并标注与ROI return self.transforms(img, Image.fromarray(mask))注意DRIVE数据集的掩码需要进行特殊处理将专家标注与ROI区域结合确保非关注区域不被计入损失计算。2. 数据增强策略医学影像数据有限恰当的数据增强能显著提升模型泛化能力。我们设计了一套针对视网膜图像的增强流水线trans [ T.RandomResize(282, 678), # 随机缩放(50%-120% of 565) T.RandomHorizontalFlip(0.5), # 水平翻转 T.RandomVerticalFlip(0.5), # 垂直翻转 T.RandomCrop(480), # 随机裁剪 T.ToTensor(), T.Normalize(mean[0.709,0.381,0.224], std[0.127,0.079,0.043]) ]关键参数说明增强类型参数设置医学影像适用性说明随机缩放min_size282, max_size678保持血管结构比例不变随机翻转概率0.5视网膜图像具有旋转对称性随机裁剪size480保留中心区域关键特征标准化数据集特定均值/标准差消除光照差异影响验证集只需进行最基本的转换eval_trans [ T.ToTensor(), T.Normalize(mean[0.709,0.381,0.224], std[0.127,0.079,0.043]) ]3. U-Net模型架构实现标准的U-Net由编码器下采样和解码器上采样组成中间通过跳跃连接融合多尺度特征。我们采用改进版设计class DoubleConv(nn.Sequential): def __init__(self, in_channels, out_channels): super().__init__( nn.Conv2d(in_channels, out_channels, 3, padding1, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue), nn.Conv2d(out_channels, out_channels, 3, padding1, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue) ) class Down(nn.Sequential): def __init__(self, in_channels, out_channels): super().__init__( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) class Up(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.up nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size2, stride2) self.conv DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 self.up(x1) # 处理尺寸不匹配问题 diffY x2.size()[2] - x1.size()[2] diffX x2.size()[3] - x1.size()[3] x1 F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2]) x torch.cat([x2, x1], dim1) return self.conv(x)模型改进亮点对称填充卷积所有卷积层设置padding1保持特征图尺寸不变避免原始U-Net中的裁剪操作批量归一化每个卷积后加入BN层加速训练并提升模型稳定性自适应上采样转置卷积后自动计算padding处理奇数尺寸输入完整的UNet类组织这些模块class UNet(nn.Module): def __init__(self, in_channels3, num_classes1, base_c64): super().__init__() self.in_conv DoubleConv(in_channels, base_c) self.down1 Down(base_c, base_c*2) self.down2 Down(base_c*2, base_c*4) self.down3 Down(base_c*4, base_c*8) self.down4 Down(base_c*8, base_c*16) self.up1 Up(base_c*16, base_c*8) self.up2 Up(base_c*8, base_c*4) self.up3 Up(base_c*4, base_c*2) self.up4 Up(base_c*2, base_c) self.out_conv nn.Conv2d(base_c, num_classes, 1) def forward(self, x): x1 self.in_conv(x) x2 self.down1(x1) x3 self.down2(x2) x4 self.down3(x3) x5 self.down4(x4) x self.up1(x5, x4) x self.up2(x, x3) x self.up3(x, x2) x self.up4(x, x1) return self.out_conv(x)4. 训练流程与技巧训练视网膜分割网络需要特别注意损失函数选择和评估指标。我们采用以下配置损失函数组合Dice Loss处理类别不平衡问题BCE Loss提供像素级梯度信号def criterion(inputs, target): bce_loss F.binary_cross_entropy_with_logits(inputs, target) dice_loss 1 - dice_coeff(torch.sigmoid(inputs), target) return bce_loss dice_loss优化器设置optimizer torch.optim.SGD( model.parameters(), lr0.01, momentum0.9, weight_decay1e-4 ) lr_scheduler create_lr_scheduler(optimizer, len(train_loader), epochs100)训练监控指标Dice系数衡量分割区域重叠度混淆矩阵计算精确率、召回率def evaluate(model, data_loader, device): model.eval() confmat ConfusionMatrix(num_classes2) dice 0 with torch.no_grad(): for image, target in data_loader: image, target image.to(device), target.to(device) output model(image) confmat.update(target.flatten(), output.argmax(1).flatten()) dice dice_coeff(torch.sigmoid(output), target) return confmat, dice / len(data_loader)实际训练中发现几个关键技巧使用混合精度训练可减少显存占用允许更大batch size渐进式学习率预热能避免初期梯度爆炸在验证集Dice系数不再提升时早停可防止过拟合5. 预测与结果可视化训练完成后我们可以加载最佳模型进行预测model UNet().to(device) model.load_state_dict(torch.load(best_model.pth)) model.eval() with torch.no_grad(): output model(test_image.unsqueeze(0).to(device)) pred torch.sigmoid(output).squeeze().cpu().numpy() binary_mask (pred 0.5).astype(np.uint8)可视化对比结果plt.figure(figsize(12,4)) plt.subplot(131); plt.imshow(test_image.permute(1,2,0)) plt.title(Original Image) plt.subplot(132); plt.imshow(gt_mask, cmapgray) plt.title(Ground Truth) plt.subplot(133); plt.imshow(binary_mask, cmapgray) plt.title(Prediction) plt.show()典型分割结果会显示模型能准确识别主要血管分支但在处理微小血管时可能出现断裂。这可以通过以下方式改进增加模型深度提升base_c到128使用注意力机制增强微小特征提取引入边缘感知损失函数完整项目应包含以下目录结构unet-retina/ ├── data/ # 数据集 ├── src/ │ ├── model.py # U-Net实现 │ ├── dataset.py # 数据加载 │ └── transforms.py # 数据增强 ├── train.py # 训练脚本 ├── predict.py # 预测脚本 └── utils/ # 辅助工具