【PyTorch】多对象分割项目

avatar
作者
筋斗云
阅读量:0

【PyTorch】多对象分割项目

对象分割任务的目标是找到图像中目标对象的边界。实际应用例如自动驾驶汽车和医学成像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务。多对象分割的主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义:在分割掩码中,属于同一类目标对象的所有像素都被赋予该类别预定义的同一个标签值。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

import numpy as np
import torch
from torchvision.datasets import VOCSegmentation
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image


class myVOCSegmentation(VOCSegmentation):
    """VOCSegmentation variant that returns an (image tensor, label tensor) pair.

    The `transforms` callable, when given, is invoked as
    `transforms(image=<ndarray>, mask=<ndarray>)` and must return a mapping
    with the keys 'image' and 'mask'.
    """

    def __getitem__(self, index):
        """Load sample `index` and return `(img, target)`.

        Returns:
            img: result of `to_tensor` applied to the (optionally transformed) image.
            target: `torch.long` tensor built from the (optionally transformed) mask,
                with every label value above 20 replaced by 0.
        """
        # Convert to numpy up front so the label clean-up and torch.from_numpy
        # below also work when no transform is configured (a PIL image supports
        # neither boolean-mask assignment nor from_numpy).
        img = np.array(Image.open(self.images[index]).convert('RGB'))
        target = np.array(Image.open(self.masks[index]))

        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=target)
            img = augmented['image']
            target = augmented['mask']

        # Labels above 20 are folded into class 0.
        # NOTE(review): presumably this targets the VOC border/void label (255) - confirm.
        target[target > 20] = 0

        img = to_tensor(img)
        target = torch.from_numpy(target).type(torch.long)
        return img, target


from albumentations import (
    HorizontalFlip,
    Compose,
    Resize,
    Normalize)

# Per-channel normalisation constants and the common input size.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
h, w = 520, 520

# Training pipeline: resize, random horizontal flip, normalise.
transform_train = Compose([Resize(h, w),
                           HorizontalFlip(p=0.5),
                           Normalize(mean=mean, std=std)])

# Validation pipeline: resize and normalise only (no random augmentation).
transform_val = Compose([Resize(h, w),
                         Normalize(mean=mean, std=std)])

path2data = "./data/"

train_ds = myVOCSegmentation(path2data,
                             year='2012',
                             image_set='train',
                             download=False,
                             transforms=transform_train)
print(len(train_ds))

val_ds = myVOCSegmentation(path2data,
                           year='2012',
                           image_set='val',
                           download=False,
                           transforms=transform_val)
print(len(val_ds))
import torch import numpy as np from skimage.segmentation import mark_boundaries import matplotlib.pylab as plt %matplotlib inline np.random.seed(0) num_classes=21 COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")  def show_img_target(img, target):     if torch.is_tensor(img):         img=to_pil_image(img)         target=target.numpy()     for ll in range(num_classes):         mask=(target==ll)         img=mark_boundaries(np.array(img) ,                              mask,                             outline_color=COLORS[ll],                             color=COLORS[ll])     plt.imshow(img)   def re_normalize (x, mean = mean, std= std):     x_r= x.clone()     for c, (mean_c, std_c) in enumerate(zip(mean, std)):         x_r [c] *= std_c         x_r [c] += mean_c     return x_r

 展示训练数据集示例图像

# Fetch one training sample and report tensor shape, type and maximum value.
img, mask = train_ds[10]
print(img.shape, img.type(), img.max())
print(mask.shape, mask.type(), mask.max())

plt.figure(figsize=(20, 20))

# Undo the normalisation so the image can be displayed.
img_r = re_normalize(img)

# Panel 1: the de-normalised image.
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

# Panel 2: the raw label mask.
plt.subplot(1, 3, 2)
plt.imshow(mask)

# Panel 3: the image with class boundaries drawn on it.
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)

展示验证数据集示例图像

# Fetch one validation sample and report tensor shape, type and maximum value.
img, mask = val_ds[10]
print(img.shape, img.type(), img.max())
print(mask.shape, mask.type(), mask.max())

plt.figure(figsize=(20, 20))

# Undo the normalisation so the image can be displayed.
img_r = re_normalize(img)

# Panel 1: the de-normalised image.
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

# Panel 2: the raw label mask.
plt.subplot(1, 3, 2)
plt.imshow(mask)

# Panel 3: the image with class boundaries drawn on it.
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)

创建数据加载器

通过torch.utils.data为训练集和验证集分别创建DataLoader,并各取一个批次打印其形状和数据类型以观察效果

from torch.utils.data import DataLoader

# Training batches are shuffled; validation batches keep dataset order.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False)

# Print shape and dtype of the first batch of each loader, train first.
for loader in (train_dl, val_dl):
    for img_b, mask_b in loader:
        print(img_b.shape, img_b.dtype)
        print(mask_b.shape, mask_b.dtype)
        break

创建模型

创建deeplabv3_resnet101模型(使用预训练权重),并打印模型结构

from torchvision.models.segmentation import deeplabv3_resnet101
import torch

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the 21-class DeepLabV3/ResNet-101 model with pretrained weights and
# move it to the selected device.
model = deeplabv3_resnet101(pretrained=True, num_classes=21).to(device)
print(model)

部署模型

在验证数据集的数据批次上部署模型观察效果 

from torch import nn

# Run the model on the first validation batch only, without gradient tracking.
model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))["out"].cpu()
        print(yb_pred.shape)
        # Collapse dimension 1 of the output to the index of its largest score.
        yb_pred = yb_pred.argmax(dim=1)
        break
print(yb_pred.shape)

plt.figure(figsize=(20, 20))

# Pick one sample of the batch and its predicted mask for display.
n = 2
img, mask = xb[n], yb_pred[n]
img_r = re_normalize(img)

# Panel 1: the de-normalised input image.
plt.subplot(1, 3, 1)
plt.imshow(to_pil_image(img_r))

# Panel 2: the predicted label mask.
plt.subplot(1, 3, 2)
plt.imshow(mask)

# Panel 3: the image with predicted class boundaries drawn on it.
plt.subplot(1, 3, 3)
show_img_target(img_r, mask)

可见勾勒对象方面效果很好 

定义损失函数和优化器

from torch import nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Summed (not averaged) cross-entropy; the epoch loop normalises it later.
criterion = nn.CrossEntropyLoss(reduction="sum")
opt = optim.Adam(model.parameters(), lr=1e-6)


def loss_batch(loss_func, output, target, opt=None):
    """Compute the loss for one batch and optionally take an optimizer step.

    Args:
        loss_func: callable applied as `loss_func(output, target)`.
        output: model output for the batch.
        target: ground truth for the batch.
        opt: optimizer; when None the loss is only evaluated.

    Returns:
        `(loss_value, None)`, where `loss_value` is `loss.item()` and the
        second slot is an unused metric placeholder.
    """
    loss = loss_func(output, target)

    if opt is None:
        return loss.item(), None

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item(), None


# Halve the learning rate when the monitored value has not improved for 20 steps.
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)


def get_lr(opt):
    """Return the 'lr' of the optimizer's first parameter group (None if there is none)."""
    return next((group['lr'] for group in opt.param_groups), None)


current_lr = get_lr(opt)
print('current lr={}'.format(current_lr))

训练和验证模型

def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    """Run one pass over `dataset_dl` and return the per-sample loss.

    Args:
        model: network whose output is a mapping with an "out" entry.
        loss_func: loss callable forwarded to `loss_batch`.
        dataset_dl: iterable yielding `(xb, yb)` batches.
        sanity_check: when truthy, stop after the first batch.
        opt: optimizer forwarded to `loss_batch`; None means evaluation only.

    Returns:
        `(loss, None)`, where `loss` is the accumulated batch loss divided by
        the number of samples actually processed; the second slot is an
        unused metric placeholder.

    Raises:
        ValueError: if `dataset_dl` yields no batches.
    """
    running_loss = 0.0
    samples_seen = 0

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)

        output = model(xb)["out"]
        loss_b, _ = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        samples_seen += xb.size(0)

        if sanity_check:
            break

    if samples_seen == 0:
        raise ValueError("dataset_dl yielded no batches")

    # Normalise by the samples actually processed rather than len(dataset):
    # with sanity_check only one batch is seen, so dividing by the full
    # dataset size would under-report the loss.
    loss = running_loss / float(samples_seen)
    return loss, None


import copy


def train_val(model, params):
    """Train and validate `model`, keeping the weights with the lowest validation loss.

    Args:
        model: network to train.
        params: dict with the keys "num_epochs", "loss_func", "optimizer",
            "train_dl", "val_dl", "sanity_check", "lr_scheduler" and
            "path2weights".

    Returns:
        `(model, loss_history, metric_history)`; `model` holds the best
        weights, and both histories are dicts with "train" and "val" lists
        containing one entry per epoch (metric entries are None).
    """
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {
        "train": [],
        "val": []}

    metric_history = {
        "train": [],
        "val": []}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))

        model.train()
        train_loss, train_metric = loss_epoch(model, loss_func, train_dl, sanity_check, opt)

        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)

        model.eval()
        with torch.no_grad():
            val_loss, val_metric = loss_epoch(model, loss_func, val_dl, sanity_check)
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)

        # Snapshot the weights in memory and on disk whenever validation improves.
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")

        # If the scheduler changed the learning rate, restart from the best weights.
        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts)

        print("train loss: %.6f" % (train_loss))
        print("val loss: %.6f" % (val_loss))
        print("-" * 10)

    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history
import os

# Fresh optimizer and scheduler for this run.
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

# Make sure the output directory for the weights exists.
path2models = "./models/"
if not os.path.exists(path2models):
    os.mkdir(path2models)

# sanity_check=True makes every epoch stop after a single batch.
params_train = dict(
    num_epochs=10,
    optimizer=opt,
    loss_func=criterion,
    train_dl=train_dl,
    val_dl=val_dl,
    sanity_check=True,
    lr_scheduler=lr_scheduler,
    path2weights=path2models + "sanity_weights.pt",
)

model, loss_hist, _ = train_val(model, params_train)

绘制训练和验证损失曲线

# Plot training and validation loss against 1-based epoch numbers.
num_epochs = params_train["num_epochs"]
epochs = range(1, num_epochs + 1)

plt.title("Train-Val Loss")
for split in ("train", "val"):
    plt.plot(epochs, loss_hist[split], label=split)
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!