用PyTorch实战CVPR 2016经典Context Encoder图像修复全流程解析当你在巴黎街头拍摄的照片中央出现一块碍眼的污渍或是老照片的某个角落因年代久远而破损时是否想过用AI技术让这些缺失的部分重获新生2016年CVPR会议上提出的Context Encoder正是解决这类图像修复问题的里程碑式工作。不同于传统的补丁复制方法它通过深度学习实现了真正的内容生成。本文将带你从零开始用PyTorch完整复现这一经典算法并分享实际训练中的关键技巧。1. 环境准备与核心架构解析在开始编写代码前我们需要明确Context Encoder的三大核心组件基于AlexNet的编码器、创新的通道全连接层Channel-wise FC以及上卷积构成的解码器。这个结构看似简单却蕴含着几个精妙的设计选择。首先创建Python环境并安装必要依赖conda create -n context_encoder python3.8 conda activate context_encoder pip install torch1.12.0 torchvision0.13.0 pillow9.2.0编码器部分采用AlexNet的前五个卷积层但需要注意三个关键调整点移除原始AlexNet的全连接层和分类头所有卷积层采用随机初始化而非预训练权重在conv5后添加额外的1x1卷积进行特征压缩import torch.nn as nn class Encoder(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( nn.Conv2d(3, 96, kernel_size11, stride4, padding2), # conv1 nn.ReLU(inplaceTrue), nn.MaxPool2d(kernel_size3, stride2), nn.Conv2d(96, 256, kernel_size5, padding2), # conv2 nn.ReLU(inplaceTrue), nn.MaxPool2d(kernel_size3, stride2), nn.Conv2d(256, 384, kernel_size3, padding1), # conv3 nn.ReLU(inplaceTrue), nn.Conv2d(384, 384, kernel_size3, padding1), # conv4 nn.ReLU(inplaceTrue), nn.Conv2d(384, 256, kernel_size3, padding1), # conv5 nn.ReLU(inplaceTrue), nn.Conv2d(256, 128, kernel_size1) # 特征压缩 )2. 通道全连接层的创新实现原文最具创新性的部分是通道全连接层Channel-wise FC它解决了传统全连接层的参数爆炸问题。具体实现时需要理解其数学本质对特征图的每个通道独立进行全连接操作相当于一组平行的1x1卷积。关键实现细节输入特征图尺寸128x6x6假设输入图像为128x128每个通道的6x6特征展平为36维向量对每个通道独立应用全连接层class ChannelWiseFC(nn.Module): def __init__(self, in_channels128, feat_size6): super().__init__() self.fc nn.Linear(feat_size*feat_size, feat_size*feat_size) self.in_channels in_channels self.feat_size feat_size def forward(self, x): bs, c, h, w x.shape x x.view(bs*c, h*w) # 展平每个通道 x self.fc(x) return x.view(bs, c, h, w) # 恢复原始维度与传统全连接层的参数对比连接类型输入维度输出维度参数量传统FC128x6x6460846084608x4608≈21M通道FC128个36维向量128个36维向量128x36x36≈165K参数减少约128倍这是模型能够实际训练的关键。3. 解码器设计与上卷积技巧解码器负责将压缩的特征表示恢复为完整图像其核心是五个上卷积转置卷积层。这里最容易出现的问题是棋盘效应需要通过精心设计核大小和步长来避免。class Decoder(nn.Module): def __init__(self): super().__init__() self.layers nn.Sequential( nn.ConvTranspose2d(128, 256, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(256, 384, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(384, 384, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(384, 256, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(256, 3, kernel_size4, stride2, padding1), nn.Tanh() ) def forward(self, x): return self.layers(x)上卷积层配置要点使用4x4核配合stride2实现2倍上采样每层后接ReLU激活最后一层用Tanh输出通道数镜像编码器的收缩过程使用padding1保持空间尺寸精确计算提示转置卷积容易产生不均匀重叠建议在关键层后添加PixelShuffle或插值卷积的替代方案来减轻伪影。4. 损失函数组合与训练策略Context Encoder使用重构损失L2和对抗损失的组合这是获得高质量修复效果的关键。我们需要分别实现这两个损失并设计合理的加权策略。4.1 重构损失实现重构损失确保修复区域与周围内容的结构一致性def reconstruction_loss(pred, target, mask): # pred: 预测图像 [B,3,H,W] # target: 真实图像 [B,3,H,W] # mask: 二值掩码 [B,1,H,W], 1表示缺失区域 diff (pred - target) * mask return torch.mean(diff**2)4.2 对抗损失集成对抗损失来自辅助的判别器网络帮助生成更真实的细节class Discriminator(nn.Module): def __init__(self): super().__init__() self.main nn.Sequential( nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid() ) def adversarial_loss(discriminator, pred, real): real_loss torch.log(discriminator(real)) fake_loss torch.log(1 - discriminator(pred.detach())) return -(torch.mean(real_loss) torch.mean(fake_loss))4.3 联合训练流程训练时需要交替优化编码器和判别器# 初始化模型 encoder Encoder() channel_fc ChannelWiseFC() decoder Decoder() discriminator Discriminator() # 优化器设置 gen_optimizer torch.optim.Adam( list(encoder.parameters()) list(channel_fc.parameters()) list(decoder.parameters()), lr0.0002) disc_optimizer torch.optim.Adam(discriminator.parameters(), lr0.0002) for epoch in range(100): for images, masks in dataloader: # 生成修复图像 features encoder(images) features channel_fc(features) outputs decoder(features) # 判别器训练 disc_loss adversarial_loss(discriminator, outputs, images) disc_optimizer.zero_grad() disc_loss.backward() disc_optimizer.step() # 生成器训练 recon_loss reconstruction_loss(outputs, images, masks) adv_loss -torch.log(discriminator(outputs)) total_loss 0.999*recon_loss 0.001*adv_loss gen_optimizer.zero_grad() total_loss.backward() gen_optimizer.step()注意对抗损失的权重系数(0.001)需要小心调整初期可先设为0纯用L2损失预热。5. 数据准备与掩码生成技巧Paris StreetView和ImageNet是原文使用的两个主要数据集。我们需要实现两种掩码生成策略中心矩形掩码和随机形状掩码。5.1 中心矩形掩码生成def generate_center_mask(batch_size, height, width, margin0.25): 生成中心矩形掩码 h_margin int(height * margin) w_margin int(width * margin) mask torch.ones(batch_size, 1, height, width) mask[:, :, h_margin:-h_margin, w_margin:-w_margin] 0 return mask5.2 随机形状掩码生成def generate_random_mask(batch_size, height, width, max_holes5, max_size0.3): 生成随机形状掩码 mask torch.zeros(batch_size, 1, height, width) for i in range(batch_size): num_holes random.randint(1, max_holes) for _ in range(num_holes): hole_size random.uniform(0.1, max_size) hole_h int(height * hole_size) hole_w int(width * hole_size) x random.randint(0, width - hole_w) y random.randint(0, height - hole_h) mask[i, :, y:yhole_h, x:xhole_w] 1 return mask5.3 数据增强策略为提高模型鲁棒性建议在训练时应用以下增强随机水平翻转颜色抖动亮度、对比度、饱和度微调小角度旋转±10度以内随机裁剪保持原始分辨率from torchvision import transforms train_transform transforms.Compose([ transforms.Resize(128), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.1, 0.1, 0.1), transforms.RandomRotation(10), transforms.RandomCrop(128), transforms.ToTensor(), transforms.Normalize(mean[0.5,0.5,0.5], std[0.5,0.5,0.5]) ])6. 训练技巧与问题调试在实际训练Context Encoder时有几个常见陷阱需要特别注意6.1 边缘模糊问题缓解原文提到的边缘模糊问题主要源于对抗损失仅作用于缺失区域L2损失的均值倾向解决方案在判别器输入中拼接完整图像而不仅是修复区域采用感知损失替代纯L2损失添加边缘一致性损失项def edge_aware_loss(pred, target, mask, sigma10): 边缘感知损失 # 计算图像梯度 pred_gray 0.299*pred[:,0] 0.587*pred[:,1] 0.114*pred[:,2] target_gray 0.299*target[:,0] 0.587*target[:,1] 0.114*target[:,2] pred_grad torch.abs(pred_gray[:,1:,:] - pred_gray[:,:-1,:]) \ torch.abs(pred_gray[:,:,1:] - pred_gray[:,:,:-1]) target_grad torch.abs(target_gray[:,1:,:] - target_gray[:,:-1,:]) \ torch.abs(target_gray[:,:,1:] - target_gray[:,:,:-1]) # 计算权重 weights torch.exp(-sigma * target_grad) loss torch.mean(weights * (pred_grad - target_grad)**2) return loss6.2 训练不稳定对策对抗训练常见问题及解决方法问题现象可能原因解决方案生成器输出全黑/全白判别器过强降低判别器学习率减少更新频率修复区域出现噪声对抗损失权重过大动态调整权重从0逐渐增加模式崩溃判别器过弱添加梯度惩罚(WGAN-GP)颜色偏差L2损失主导添加感知损失或VGG特征损失6.3 学习率调度策略推荐使用循环学习率(Cyclic LR)配合余弦退火from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR scheduler_gen CyclicLR( gen_optimizer, base_lr1e-5, max_lr2e-4, step_size_up2000, cycle_momentumFalse ) scheduler_disc CosineAnnealingLR( disc_optimizer, T_max10, eta_min1e-6 )7. 模型评估与效果展示完整的评估流程应包括定量指标和视觉质量评估两方面。7.1 定量评估指标指标名称计算公式意义PSNR$10 \cdot \log_{10}(\frac{MAX_I^2}{MSE})$峰值信噪比值越大越好SSIM$\frac{(2\mu_x\mu_y c_1)(2\sigma_{xy} c_2)}{(\mu_x^2 \mu_y^2 c_1)(\sigma_x^2 \sigma_y^2 c_2)}$结构相似性范围[0,1]FID$|\mu_1 - \mu_2|^2 Tr(\Sigma_1 \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2})$特征分布距离越小越好实现示例from piq import psnr, ssim, fid def evaluate(model, test_loader): psnr_values [] ssim_values [] real_features [] pred_features [] with torch.no_grad(): for img, mask in test_loader: output model(img) # 计算PSNR/SSIM仅针对修复区域 psnr_val psnr(output*mask, img*mask, data_range1.0) ssim_val ssim(output*mask, img*mask, data_range1.0) psnr_values.append(psnr_val) ssim_values.append(ssim_val) # 收集FID特征 real_features.append(fid._compute_feats(img)) pred_features.append(fid._compute_feats(output)) fid_score fid._compute_fid( torch.cat(real_features), torch.cat(pred_features) ) return { PSNR: torch.mean(torch.stack(psnr_values)), SSIM: torch.mean(torch.stack(ssim_values)), FID: fid_score }7.2 效果可视化建议将以下内容并排显示以便对比原始图像掩码图像缺失区域显示为黑色模型修复结果真实完整图像如有import matplotlib.pyplot as plt def visualize_results(images, masks, outputs, num_samples4): plt.figure(figsize(15, 10)) for i in range(num_samples): # 原始图像 plt.subplot(num_samples, 4, i*41) plt.imshow(images[i].permute(1,2,0).cpu().numpy()*0.50.5) # 掩码图像 plt.subplot(num_samples, 4, i*42) masked images[i] * (1 - masks[i]) plt.imshow(masked.permute(1,2,0).cpu().numpy()*0.50.5) # 修复结果 plt.subplot(num_samples, 4, i*43) comp images[i] * (1 - masks[i]) outputs[i] * masks[i] plt.imshow(comp.permute(1,2,0).cpu().numpy()*0.50.5) # 真实图像如有 plt.subplot(num_samples, 4, i*44) plt.imshow(images[i].permute(1,2,0).cpu().numpy()*0.50.5) plt.tight_layout() plt.show()在实际项目中修复128x128图像中心64x64区域时预期PSNR应达到25dB以上SSIM超过0.85这表明修复区域与周围内容在结构和纹理上都具有良好的一致性。对于更复杂的随机掩码场景这些指标会有所下降但通过调整损失权重和训练策略仍可获得视觉上令人满意的结果。
CVPR 2016经典论文复现:手把手教你用PyTorch实现Context Encoder图像修复
发布时间:2026/5/28 16:38:03
用PyTorch实战CVPR 2016经典Context Encoder图像修复全流程解析当你在巴黎街头拍摄的照片中央出现一块碍眼的污渍或是老照片的某个角落因年代久远而破损时是否想过用AI技术让这些缺失的部分重获新生2016年CVPR会议上提出的Context Encoder正是解决这类图像修复问题的里程碑式工作。不同于传统的补丁复制方法它通过深度学习实现了真正的内容生成。本文将带你从零开始用PyTorch完整复现这一经典算法并分享实际训练中的关键技巧。1. 环境准备与核心架构解析在开始编写代码前我们需要明确Context Encoder的三大核心组件基于AlexNet的编码器、创新的通道全连接层Channel-wise FC以及上卷积构成的解码器。这个结构看似简单却蕴含着几个精妙的设计选择。首先创建Python环境并安装必要依赖conda create -n context_encoder python3.8 conda activate context_encoder pip install torch1.12.0 torchvision0.13.0 pillow9.2.0编码器部分采用AlexNet的前五个卷积层但需要注意三个关键调整点移除原始AlexNet的全连接层和分类头所有卷积层采用随机初始化而非预训练权重在conv5后添加额外的1x1卷积进行特征压缩import torch.nn as nn class Encoder(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( nn.Conv2d(3, 96, kernel_size11, stride4, padding2), # conv1 nn.ReLU(inplaceTrue), nn.MaxPool2d(kernel_size3, stride2), nn.Conv2d(96, 256, kernel_size5, padding2), # conv2 nn.ReLU(inplaceTrue), nn.MaxPool2d(kernel_size3, stride2), nn.Conv2d(256, 384, kernel_size3, padding1), # conv3 nn.ReLU(inplaceTrue), nn.Conv2d(384, 384, kernel_size3, padding1), # conv4 nn.ReLU(inplaceTrue), nn.Conv2d(384, 256, kernel_size3, padding1), # conv5 nn.ReLU(inplaceTrue), nn.Conv2d(256, 128, kernel_size1) # 特征压缩 )2. 通道全连接层的创新实现原文最具创新性的部分是通道全连接层Channel-wise FC它解决了传统全连接层的参数爆炸问题。具体实现时需要理解其数学本质对特征图的每个通道独立进行全连接操作相当于一组平行的1x1卷积。关键实现细节输入特征图尺寸128x6x6假设输入图像为128x128每个通道的6x6特征展平为36维向量对每个通道独立应用全连接层class ChannelWiseFC(nn.Module): def __init__(self, in_channels128, feat_size6): super().__init__() self.fc nn.Linear(feat_size*feat_size, feat_size*feat_size) self.in_channels in_channels self.feat_size feat_size def forward(self, x): bs, c, h, w x.shape x x.view(bs*c, h*w) # 展平每个通道 x self.fc(x) return x.view(bs, c, h, w) # 恢复原始维度与传统全连接层的参数对比连接类型输入维度输出维度参数量传统FC128x6x6460846084608x4608≈21M通道FC128个36维向量128个36维向量128x36x36≈165K参数减少约128倍这是模型能够实际训练的关键。3. 解码器设计与上卷积技巧解码器负责将压缩的特征表示恢复为完整图像其核心是五个上卷积转置卷积层。这里最容易出现的问题是棋盘效应需要通过精心设计核大小和步长来避免。class Decoder(nn.Module): def __init__(self): super().__init__() self.layers nn.Sequential( nn.ConvTranspose2d(128, 256, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(256, 384, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(384, 384, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(384, 256, kernel_size4, stride2, padding1), nn.ReLU(), nn.ConvTranspose2d(256, 3, kernel_size4, stride2, padding1), nn.Tanh() ) def forward(self, x): return self.layers(x)上卷积层配置要点使用4x4核配合stride2实现2倍上采样每层后接ReLU激活最后一层用Tanh输出通道数镜像编码器的收缩过程使用padding1保持空间尺寸精确计算提示转置卷积容易产生不均匀重叠建议在关键层后添加PixelShuffle或插值卷积的替代方案来减轻伪影。4. 损失函数组合与训练策略Context Encoder使用重构损失L2和对抗损失的组合这是获得高质量修复效果的关键。我们需要分别实现这两个损失并设计合理的加权策略。4.1 重构损失实现重构损失确保修复区域与周围内容的结构一致性def reconstruction_loss(pred, target, mask): # pred: 预测图像 [B,3,H,W] # target: 真实图像 [B,3,H,W] # mask: 二值掩码 [B,1,H,W], 1表示缺失区域 diff (pred - target) * mask return torch.mean(diff**2)4.2 对抗损失集成对抗损失来自辅助的判别器网络帮助生成更真实的细节class Discriminator(nn.Module): def __init__(self): super().__init__() self.main nn.Sequential( nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 1, 4, 1, 0), nn.Sigmoid() ) def adversarial_loss(discriminator, pred, real): real_loss torch.log(discriminator(real)) fake_loss torch.log(1 - discriminator(pred.detach())) return -(torch.mean(real_loss) torch.mean(fake_loss))4.3 联合训练流程训练时需要交替优化编码器和判别器# 初始化模型 encoder Encoder() channel_fc ChannelWiseFC() decoder Decoder() discriminator Discriminator() # 优化器设置 gen_optimizer torch.optim.Adam( list(encoder.parameters()) list(channel_fc.parameters()) list(decoder.parameters()), lr0.0002) disc_optimizer torch.optim.Adam(discriminator.parameters(), lr0.0002) for epoch in range(100): for images, masks in dataloader: # 生成修复图像 features encoder(images) features channel_fc(features) outputs decoder(features) # 判别器训练 disc_loss adversarial_loss(discriminator, outputs, images) disc_optimizer.zero_grad() disc_loss.backward() disc_optimizer.step() # 生成器训练 recon_loss reconstruction_loss(outputs, images, masks) adv_loss -torch.log(discriminator(outputs)) total_loss 0.999*recon_loss 0.001*adv_loss gen_optimizer.zero_grad() total_loss.backward() gen_optimizer.step()注意对抗损失的权重系数(0.001)需要小心调整初期可先设为0纯用L2损失预热。5. 数据准备与掩码生成技巧Paris StreetView和ImageNet是原文使用的两个主要数据集。我们需要实现两种掩码生成策略中心矩形掩码和随机形状掩码。5.1 中心矩形掩码生成def generate_center_mask(batch_size, height, width, margin0.25): 生成中心矩形掩码 h_margin int(height * margin) w_margin int(width * margin) mask torch.ones(batch_size, 1, height, width) mask[:, :, h_margin:-h_margin, w_margin:-w_margin] 0 return mask5.2 随机形状掩码生成def generate_random_mask(batch_size, height, width, max_holes5, max_size0.3): 生成随机形状掩码 mask torch.zeros(batch_size, 1, height, width) for i in range(batch_size): num_holes random.randint(1, max_holes) for _ in range(num_holes): hole_size random.uniform(0.1, max_size) hole_h int(height * hole_size) hole_w int(width * hole_size) x random.randint(0, width - hole_w) y random.randint(0, height - hole_h) mask[i, :, y:yhole_h, x:xhole_w] 1 return mask5.3 数据增强策略为提高模型鲁棒性建议在训练时应用以下增强随机水平翻转颜色抖动亮度、对比度、饱和度微调小角度旋转±10度以内随机裁剪保持原始分辨率from torchvision import transforms train_transform transforms.Compose([ transforms.Resize(128), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.1, 0.1, 0.1), transforms.RandomRotation(10), transforms.RandomCrop(128), transforms.ToTensor(), transforms.Normalize(mean[0.5,0.5,0.5], std[0.5,0.5,0.5]) ])6. 训练技巧与问题调试在实际训练Context Encoder时有几个常见陷阱需要特别注意6.1 边缘模糊问题缓解原文提到的边缘模糊问题主要源于对抗损失仅作用于缺失区域L2损失的均值倾向解决方案在判别器输入中拼接完整图像而不仅是修复区域采用感知损失替代纯L2损失添加边缘一致性损失项def edge_aware_loss(pred, target, mask, sigma10): 边缘感知损失 # 计算图像梯度 pred_gray 0.299*pred[:,0] 0.587*pred[:,1] 0.114*pred[:,2] target_gray 0.299*target[:,0] 0.587*target[:,1] 0.114*target[:,2] pred_grad torch.abs(pred_gray[:,1:,:] - pred_gray[:,:-1,:]) \ torch.abs(pred_gray[:,:,1:] - pred_gray[:,:,:-1]) target_grad torch.abs(target_gray[:,1:,:] - target_gray[:,:-1,:]) \ torch.abs(target_gray[:,:,1:] - target_gray[:,:,:-1]) # 计算权重 weights torch.exp(-sigma * target_grad) loss torch.mean(weights * (pred_grad - target_grad)**2) return loss6.2 训练不稳定对策对抗训练常见问题及解决方法问题现象可能原因解决方案生成器输出全黑/全白判别器过强降低判别器学习率减少更新频率修复区域出现噪声对抗损失权重过大动态调整权重从0逐渐增加模式崩溃判别器过弱添加梯度惩罚(WGAN-GP)颜色偏差L2损失主导添加感知损失或VGG特征损失6.3 学习率调度策略推荐使用循环学习率(Cyclic LR)配合余弦退火from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR scheduler_gen CyclicLR( gen_optimizer, base_lr1e-5, max_lr2e-4, step_size_up2000, cycle_momentumFalse ) scheduler_disc CosineAnnealingLR( disc_optimizer, T_max10, eta_min1e-6 )7. 模型评估与效果展示完整的评估流程应包括定量指标和视觉质量评估两方面。7.1 定量评估指标指标名称计算公式意义PSNR$10 \cdot \log_{10}(\frac{MAX_I^2}{MSE})$峰值信噪比值越大越好SSIM$\frac{(2\mu_x\mu_y c_1)(2\sigma_{xy} c_2)}{(\mu_x^2 \mu_y^2 c_1)(\sigma_x^2 \sigma_y^2 c_2)}$结构相似性范围[0,1]FID$|\mu_1 - \mu_2|^2 Tr(\Sigma_1 \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2})$特征分布距离越小越好实现示例from piq import psnr, ssim, fid def evaluate(model, test_loader): psnr_values [] ssim_values [] real_features [] pred_features [] with torch.no_grad(): for img, mask in test_loader: output model(img) # 计算PSNR/SSIM仅针对修复区域 psnr_val psnr(output*mask, img*mask, data_range1.0) ssim_val ssim(output*mask, img*mask, data_range1.0) psnr_values.append(psnr_val) ssim_values.append(ssim_val) # 收集FID特征 real_features.append(fid._compute_feats(img)) pred_features.append(fid._compute_feats(output)) fid_score fid._compute_fid( torch.cat(real_features), torch.cat(pred_features) ) return { PSNR: torch.mean(torch.stack(psnr_values)), SSIM: torch.mean(torch.stack(ssim_values)), FID: fid_score }7.2 效果可视化建议将以下内容并排显示以便对比原始图像掩码图像缺失区域显示为黑色模型修复结果真实完整图像如有import matplotlib.pyplot as plt def visualize_results(images, masks, outputs, num_samples4): plt.figure(figsize(15, 10)) for i in range(num_samples): # 原始图像 plt.subplot(num_samples, 4, i*41) plt.imshow(images[i].permute(1,2,0).cpu().numpy()*0.50.5) # 掩码图像 plt.subplot(num_samples, 4, i*42) masked images[i] * (1 - masks[i]) plt.imshow(masked.permute(1,2,0).cpu().numpy()*0.50.5) # 修复结果 plt.subplot(num_samples, 4, i*43) comp images[i] * (1 - masks[i]) outputs[i] * masks[i] plt.imshow(comp.permute(1,2,0).cpu().numpy()*0.50.5) # 真实图像如有 plt.subplot(num_samples, 4, i*44) plt.imshow(images[i].permute(1,2,0).cpu().numpy()*0.50.5) plt.tight_layout() plt.show()在实际项目中修复128x128图像中心64x64区域时预期PSNR应达到25dB以上SSIM超过0.85这表明修复区域与周围内容在结构和纹理上都具有良好的一致性。对于更复杂的随机掩码场景这些指标会有所下降但通过调整损失权重和训练策略仍可获得视觉上令人满意的结果。