Deep Learning in Practice: Hemangioma Ultrasound Image Segmentation with UNet


Preface

Hello everyone, I'm 机长.

This column continuously collects and curates real-world deep learning projects. It is intended to help readers preparing for deep learning jobs or related research build up practical development experience. Each project can go on a résumé as a real development project, and each comes with complete code and a dataset, available via Baidu Cloud and ready to use out of the box.

Updates are ongoing~

Project Background

(Hemangioma ultrasound image segmentation with UNet)

Hemangiomas are a common vascular anomaly with a particularly high incidence in infants. Ultrasound, as a non-invasive diagnostic tool, provides clinicians with key information such as a hemangioma's location, shape, and extent, which is essential for guiding further treatment. At present, however, hemangioma lesions are segmented mainly by manual delineation by experts, which is time-consuming and labor-intensive, and the results depend on the annotator's clinical experience, introducing human error. Automatic, accurate segmentation of hemangioma ultrasound images with deep learning has therefore become a hot topic in medical applications and research.

Project Environment

  • Platform: Windows 10
  • Language: Python 3.8
  • Editor: PyCharm
  • PyTorch version: 1.8

1. Create and activate a virtual environment

python -m venv myenv
myenv\Scripts\activate.bat

2. Install the required packages with pip inside the virtual environment

pip install torch torchvision torchaudio

Note: only the PyTorch installation is shown here; other packages are installed the same way. You can also just run the code and install whichever packages the "missing module" errors report.
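Since this article targets PyTorch 1.8, pinning versions may help avoid API drift. The version pairing below is an assumption based on the usual release matching and should be checked against the official PyTorch install page:

pip install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0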

3.pycharm 运行环境配置

进入pytcharm =》点击file =》点击settings=》点击Project:...=》点击 Python Interpreter,进入如下界面

点击add =》点击Existing environment  =》 点击 ... =》选择第一步1创建虚拟环境目录myenv\Scripts\下的python.exe文件点击ok完成环境配置

Dataset

The dataset is split into training images and label (mask) images.

                

[Figure: training image example (left) and annotation mask example (right)]

Obtaining the training data:

Message the author privately to obtain it.

UNet Network Overview

U-Net is known for its distinctive U-shaped architecture, which consists of two main parts, an encoder and a decoder, and is particularly well suited to tasks such as medical image segmentation. The key points are explained below, with some additional detail.

Encoder
  • Structure: the left half of the network handles feature extraction. It extracts image features layer by layer through stacks of convolution layers (typically 3x3 kernels) followed by ReLU activations. After each convolution block, a 2x2 max-pooling layer is usually applied to halve the feature-map resolution and enlarge the receptive field. This downsampling helps the network capture global context (see the shape sketch just below).
  • Role: by progressively extracting and abstracting image features, the encoder supplies rich information for the subsequent segmentation task.
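A minimal sketch of this downsampling behavior; the layer sizes here are illustrative only, not the ones used later in this article:

import torch
import torch.nn as nn
import torch.nn.functional as F

# One encoder stage: two 3x3 convs (padding=1 preserves H and W), then a 2x2 max-pool.
stage = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
)

x = torch.randn(1, 1, 224, 224)   # (batch, channels, H, W)
feat = stage(x)                   # -> (1, 8, 224, 224): convolutions keep the resolution
down = F.max_pool2d(feat, 2)      # -> (1, 8, 112, 112): pooling halves H and W
print(feat.shape, down.shape)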
Decoder
  • Structure: the right half of the network restores the encoder's feature maps to the original image resolution so that pixel-level classification can be performed. It first enlarges the feature maps by upsampling (usually transposed convolution or bilinear interpolation), then fuses the upsampled maps with the encoder feature maps from the corresponding level via concatenation. The fused maps are then processed further by convolution layers and ReLU activations.
  • Feature concatenation: unlike FCN (fully convolutional network), which fuses features by element-wise addition, U-Net fuses feature maps from different levels by concatenation. Concatenation preserves more feature information and produces "thicker" feature maps, but it also needs more GPU memory to store them. A sketch contrasting the two fusion styles follows this list.
  • Role: through stepwise upsampling and feature fusion, the decoder combines the encoder's high-level features with low-level ones and recovers a fine-grained segmentation result.
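A minimal sketch of the two fusion styles, using toy tensors (shapes are illustrative only):

import torch

# Decoder feature map after upsampling, and the matching encoder skip feature.
up   = torch.randn(1, 64, 56, 56)
skip = torch.randn(1, 64, 56, 56)

fused_add = up + skip                     # FCN-style fusion: channel count stays 64
fused_cat = torch.cat([up, skip], dim=1)  # U-Net-style fusion: channels double to 128
print(fused_add.shape)  # torch.Size([1, 64, 56, 56])
print(fused_cat.shape)  # torch.Size([1, 128, 56, 56])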
Strengths and Challenges
  • Strengths: U-Net has a simple yet powerful architecture and is especially well suited to medical image segmentation. Its U-shaped structure and concatenation-based skip connections let the network capture global and local features simultaneously, enabling highly accurate segmentation.
  • Challenges: despite its strong performance, U-Net is relatively memory-hungry. When processing high-resolution images or training at scale, GPU memory can become the limiting factor. Designing and training the network also requires considerable expertise and experience.

In summary, U-Net's encoder-decoder structure and concatenation-based feature fusion have produced notable results in medical image segmentation. In practice, the memory cost remains a real concern (the sketch below shows one common mitigation), and more efficient variants continue to be explored.
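One common way to trade compute for memory in PyTorch is activation (gradient) checkpointing. A minimal sketch, using a toy block standing in for an encoder stage (not part of the original article's code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Activations inside a checkpointed block are recomputed during the backward
# pass instead of being stored, lowering peak memory at the cost of extra compute.
block = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
)

x = torch.randn(2, 1, 224, 224, requires_grad=True)
out = checkpoint(block, x)   # same result as block(x), lower peak memory
out.sum().backward()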

Implementing UNet in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv2d_bn(nn.Module):
    """Two 3x3 conv + BN + ReLU layers; spatial size is preserved."""
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class deconv2d_bn(nn.Module):
    """2x2 transposed conv + BN + ReLU; doubles the spatial size."""
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        # Encoder: channels 1 -> 8 -> 16 -> 32 -> 64 -> 128
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        # Decoder: each stage consumes the concatenated (skip + upsampled) maps
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        # Final 1-channel prediction map
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path with 2x2 max-pooling between stages
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)

        # Decoder path: upsample, concatenate the matching encoder features, convolve
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2, conv3], dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3, conv2], dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4, conv1], dim=1)
        conv9 = self.layer9_conv(concat4)

        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp
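A quick sanity check of the network's input and output shapes, assuming the class definitions above; the 224x224 size matches the transforms used later:

model = Unet()
dummy = torch.randn(1, 1, 224, 224)   # (batch, channels, H, W); grayscale input
with torch.no_grad():
    pred = model(dummy)
print(pred.shape)  # torch.Size([1, 1, 224, 224]), values in (0, 1) from the sigmoid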

Custom Dataset Loader

import os
import PIL.Image as Image
import torch.utils.data as data


class LiverDataset(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None, mode='train'):
        # Expects images under root/images and masks under root/mask,
        # named 1.png, 2.png, ... in matching pairs.
        n = len(os.listdir(root + '/images'))

        imgs = []

        # Note: the 'train' and non-'train' branches are currently identical;
        # the mode flag is kept as a hook for a different test-time layout.
        if mode == 'train':
            for i in range(n):
                img = os.path.join(root, 'images', "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])
        else:
            for i in range(n):
                img = os.path.join(root, 'images', "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
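A usage sketch, assuming the ./data/train/images and ./data/train/mask layout used in the complete code below:

from torchvision import transforms

to_tensor = transforms.ToTensor()
dataset = LiverDataset("./data/train", transform=to_tensor, target_transform=to_tensor)
img, mask = dataset[0]                     # one (image, mask) pair
print(len(dataset), img.shape, mask.shape)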

Complete Code

import os

import PIL.Image as Image
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')  # if "module 'backend_interagg' has no attribute 'FigureCanvas'" occurs, run: pip install matplotlib==3.5.0
import numpy as np  # if numpy errors occur, run: pip install numpy==1.23.5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
from tqdm import tqdm

epochs = 1
batch_size = 32
device = 'cpu'
best_model = None
best_loss = 999
save_path = './checkpoints/best_model.pkl'
directory_path = './checkpoints'

# Create the checkpoint directory if it does not exist
if not os.path.exists(directory_path):
    os.makedirs(directory_path)


class double_conv2d_bn(nn.Module):
    """Two 3x3 conv + BN + ReLU layers; spatial size is preserved."""
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class deconv2d_bn(nn.Module):
    """2x2 transposed conv + BN + ReLU; doubles the spatial size."""
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        # Encoder: channels 1 -> 8 -> 16 -> 32 -> 64 -> 128
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        # Decoder: each stage consumes the concatenated (skip + upsampled) maps
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        # Final 1-channel prediction map
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Encoder path with 2x2 max-pooling between stages
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)

        # Decoder path: upsample, concatenate the matching encoder features, convolve
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2, conv3], dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3, conv2], dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4, conv1], dim=1)
        conv9 = self.layer9_conv(concat4)

        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp


class LiverDataset(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None, mode='train'):
        # Expects images under root/images and masks under root/mask,
        # named 1.png, 2.png, ... in matching pairs.
        n = len(os.listdir(root + '/images'))

        imgs = []

        # Note: the 'train' and non-'train' branches are currently identical;
        # the mode flag is kept as a hook for a different test-time layout.
        if mode == 'train':
            for i in range(n):
                img = os.path.join(root, 'images', "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])
        else:
            for i in range(n):
                img = os.path.join(root, 'images', "%d.png" % (i + 1))
                mask = os.path.join(root, 'mask', "%d.png" % (i + 1))
                imgs.append([img, mask])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)


x_transform = transforms.Compose([
    transforms.ToTensor(),
    # A deterministic Resize is used here: applying RandomResizedCrop
    # independently to the image and the mask (as in the original snippet)
    # would crop them differently and misalign each training pair.
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1)
])

y_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224))
])

liver_dataset = LiverDataset("./data/train", transform=x_transform, target_transform=y_transform)
dataloader = torch.utils.data.DataLoader(liver_dataset, batch_size=batch_size, shuffle=True)

model = Unet()
criterion = torch.nn.BCELoss()
# criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    train_bar = tqdm(dataloader)
    for x, y in train_bar:
        optimizer.zero_grad()
        inputs = x.to(device)
        labels = y.to(device)
        outputs = model(inputs)
        labels = labels.bool().float()  # binarize the mask to {0, 1} for BCELoss
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print("[EPOCH] %s" % str(epoch + 1))
    print("Training loss: %s" % str(epoch_loss))

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model = model.state_dict()

    # Save the best model parameters when training ends
    if epoch == epochs - 1:
        torch.save(best_model, save_path)

print('Finished Training')

plt.figure('Sample training image')
pil_img = Image.open('./data/train/images/1.png')
np_img = np.array(pil_img)
plt.imshow(np_img)
plt.show()

plt.figure('Sample mask image')
pil_img = Image.open('./data/train/mask/1.png')
np_img = np.array(pil_img)
plt.imshow(np_img)
plt.show()
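After training, a minimal inference sketch that reuses the names defined in the script above; the sample file name and the 0.5 threshold are assumptions, not part of the original script:

# Load the saved weights and segment a single image
model = Unet()
model.load_state_dict(torch.load(save_path, map_location=device))
model.eval()

img = Image.open('./data/train/images/1.png')
inp = x_transform(img).unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    prob = model(inp)                 # sigmoid output in (0, 1)
pred_mask = (prob > 0.5).float()      # threshold into a binary mask

plt.figure('Predicted mask')
plt.imshow(pred_mask.squeeze().cpu().numpy(), cmap='gray')
plt.show()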

Start Training

Training Results
