【深度学习】基于pytorch的胶囊网络实现

作者:筋斗云    阅读量:0

参考资料:

  • 阿里云实践:https://developer.aliyun.com/article/581717
  • 动态路由机制讲解:https://www.bilibili.com/video/BV1oW411H7G1/?spm_id_from=333.337.search-card.all.click&vd_source=3b5e1109bdab0d21b23a5c46c4ed667d

  • 原始论文:Sabour, Frosst, Hinton. Dynamic Routing Between Capsules (NIPS 2017)

代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np

# Public, reusable model components. train/test/main are script helpers and are
# deliberately not star-exported (a star-imported function named `test` would be
# collected as a test case by pytest).
__all__ = [
    "MNIST_MEAN",
    "MNIST_STD",
    "squash_vectors",
    "CapsuleLayer",
    "PrimaryCaps",
    "CapsNet",
    "MarginLoss",
    "reconstruction_loss",
]

# Statistics passed to transforms.Normalize in main(); reconstruction_loss uses them
# to map normalized images back to [0, 1], the range of the decoder's Sigmoid output.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def squash_vectors(input_tensor, dim=-1, eps=1e-8):
    """Apply the capsule squash non-linearity (Sabour et al., 2017, eq. 1).

    Each vector along ``dim`` keeps its direction and is rescaled to length
    ||s||^2 / (1 + ||s||^2). ``eps`` keeps the division finite for zero vectors,
    which the plain ``s / ||s||`` form would turn into NaN.
    """
    squared_norm = (input_tensor ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * input_tensor / torch.sqrt(squared_norm + eps)


class CapsuleLayer(nn.Module):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Args:
        num_capsules: number of output (parent) capsules, e.g. one per class.
        num_route_nodes: number of input (child) capsules.
        in_channels: dimension of each input capsule vector.
        out_channels: dimension of each output capsule vector.
        routing_iters: number of routing iterations; must be >= 1.
        batch_size: unused; kept only for backward compatibility.

    Raises:
        ValueError: if ``routing_iters`` < 1 (no output would be produced).
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, routing_iters=3, batch_size=128):
        super(CapsuleLayer, self).__init__()
        if routing_iters < 1:
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        self.num_route_nodes = num_route_nodes
        self.num_capsules = num_capsules
        self.routing_iters = routing_iters
        # One transformation matrix per (output capsule, input capsule) pair:
        # shape (1, num_capsules, num_route_nodes, in_channels, out_channels).
        self.W = nn.Parameter(torch.randn(1, num_capsules, num_route_nodes, in_channels, out_channels))

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: input capsules of shape (B, ..., in_channels) whose middle dimensions
               multiply to ``num_route_nodes``, e.g. (B, 32, 6, 6, 8).

        Returns:
            Output capsules of shape (B, num_capsules, 1, out_channels).
        """
        batch = x.size(0)
        # (B, num_route_nodes, in) -> (B, 1, num_route_nodes, 1, in). matmul broadcasts
        # against W over the output-capsule axis, so no explicit repeat() is needed.
        x = x.reshape(batch, self.num_route_nodes, -1)[:, None, :, None, :]
        u_hat = torch.matmul(x, self.W).squeeze(3)  # (B, num_capsules, num_route_nodes, out)

        # Routing logits on the same device/dtype as the predictions.
        b = torch.zeros_like(u_hat[..., :1])  # (B, num_capsules, num_route_nodes, 1)
        for i in range(self.routing_iters):
            # Eq. 3: each input capsule splits its output among the parent capsules,
            # so the softmax normalizes over the output-capsule axis (dim=1).
            c = F.softmax(b, dim=1)
            s = (c * u_hat).sum(dim=2, keepdim=True)  # (B, num_capsules, 1, out)
            v = self.squash(s)
            if i < self.routing_iters - 1:
                # Agreement u_hat . v raises the logits of parents the prediction agrees with.
                b = b + (u_hat * v).sum(dim=-1, keepdim=True)
        return v

    def squash(self, input_tensor):
        """Squash ``input_tensor`` along its last dimension (see ``squash_vectors``)."""
        return squash_vectors(input_tensor)


class PrimaryCaps(nn.Module):
    """Convolutional primary-capsule layer.

    One convolution yields ``out_channels * dim_capsule`` feature maps. Each group of
    ``dim_capsule`` consecutive channels at one spatial position forms one capsule
    vector, which is then squashed.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dim_capsule):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * dim_capsule, kernel_size=kernel_size, stride=stride)
        self.dim_capsule = dim_capsule

    def forward(self, x):
        """Return squashed capsules of shape (B, out_channels, H', W', dim_capsule)."""
        x = self.conv(x)  # (B, out_channels * dim_capsule, H', W')
        batch_size, channels, height, width = x.shape
        out_channels = channels // self.dim_capsule
        # Split channels into (capsule type, capsule component), then move the component
        # axis last. Viewing straight to (..., H', W', dim_capsule) would instead pack
        # neighbouring pixels of a single channel into one "capsule".
        x = x.view(batch_size, out_channels, self.dim_capsule, height, width)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        return squash_vectors(x)


class CapsNet(nn.Module):
    """CapsNet: conv -> primary capsules -> class capsules, plus a reconstruction decoder.

    Args:
        input_shape: (channels, height, width) of the input images; height and width
            must be at least 17 so that both 9x9 convolutions produce output.
        n_class: number of classes (one output capsule each).
        routings: number of dynamic-routing iterations.
    """

    def __init__(self, input_shape, n_class, routings):
        super(CapsNet, self).__init__()
        self.input_shape = input_shape
        self.n_class = n_class
        self.routings = routings

        # Primary-capsule grid size: conv1 (k=9, s=1) then primarycaps (k=9, s=2).
        # For a 28x28 input: 28 -> 20 -> 6, i.e. 32 * 6 * 6 = 1152 route nodes.
        grid_h = (input_shape[1] - 8 - 9) // 2 + 1
        grid_w = (input_shape[2] - 8 - 9) // 2 + 1
        if grid_h < 1 or grid_w < 1:
            raise ValueError(f"input_shape {tuple(input_shape)} is too small; height and width must be >= 17")

        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=256, kernel_size=9, stride=1)
        self.primarycaps = PrimaryCaps(dim_capsule=8, in_channels=256, out_channels=32, kernel_size=9, stride=2)
        self.digitcaps = CapsuleLayer(num_capsules=n_class, num_route_nodes=32 * grid_h * grid_w,
                                      in_channels=8, out_channels=16, routing_iters=routings)
        self.decoder = nn.Sequential(
            nn.Linear(16 * n_class, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, int(np.prod(input_shape))),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify and reconstruct a batch of images.

        Args:
            x: images of shape (B, *input_shape).
            y: optional integer labels of shape (B,). When given (training), the decoder
               sees only the true class's capsule; otherwise it sees the capsule with
               the largest length (the predicted class).

        Returns:
            (lengths, reconstructions): capsule lengths of shape (B, n_class) and
            reconstructions in [0, 1] of shape (B, *input_shape).
        """
        x = F.relu(self.conv1(x))  # (B, 256, H-8, W-8)
        x = self.primarycaps(x)  # (B, 32, h, w, 8)
        x = self.digitcaps(x)  # (B, n_class, 1, 16)

        # Length of each output capsule serves as the class score.
        lengths = x.norm(dim=-1).squeeze(2)  # (B, n_class)

        # Mask every capsule except the selected class before decoding.
        if y is None:
            y = lengths.argmax(dim=1)
        mask = F.one_hot(y, num_classes=self.n_class).to(x.dtype)  # (B, n_class)
        masked = (x.squeeze(2) * mask.unsqueeze(-1)).reshape(x.size(0), -1)  # (B, n_class * 16)
        reconstructions = self.decoder(masked).view(-1, *self.input_shape)
        return lengths, reconstructions


class MarginLoss(nn.Module):
    """Capsule margin loss (Sabour et al., 2017, eq. 4), summed over classes, averaged over the batch.

    Used instead of cross-entropy: capsule lengths lie in [0, 1), so a softmax over
    them is close to uniform and cross-entropy stays near ln(10) = 2.302585.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, target):
        """Return the mean margin loss for ``lengths`` (B, n_class) and integer ``target`` (B,)."""
        t = F.one_hot(target, num_classes=lengths.size(1)).to(lengths.dtype)
        present = t * F.relu(self.m_pos - lengths) ** 2
        absent = self.lambda_ * (1 - t) * F.relu(lengths - self.m_neg) ** 2
        return (present + absent).sum(dim=1).mean()


def reconstruction_loss(reconstructions, images, weight=0.0005):
    """Per-image sum of squared errors, scaled by ``weight`` and averaged over the batch.

    ``images`` are expected to be normalized with MNIST_MEAN / MNIST_STD (as done in
    main()). They are mapped back to [0, 1] so they match the decoder's Sigmoid output.
    ``weight`` defaults to the paper's 0.0005 so reconstruction does not dominate.
    """
    batch = images.size(0)
    target = (images * MNIST_STD + MNIST_MEAN).reshape(batch, -1)
    return weight * F.mse_loss(reconstructions.reshape(batch, -1), target, reduction='sum') / batch


def train(model, train_loader, optimizer, criterion, epoch):
    """Run one training epoch and print progress, average per-batch loss and accuracy."""
    device = next(model.parameters()).device
    model.train()
    train_loss = 0.0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Pass the labels so the decoder reconstructs from the true class's capsule.
        lengths, reconstructions = model(data, target)
        classification_loss = criterion(lengths, target)
        reconstructions_loss = reconstruction_loss(reconstructions, data)
        loss = classification_loss + reconstructions_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += lengths.argmax(dim=1).eq(target).sum().item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss(c/r): {loss.item():.6f} '
                  f'({classification_loss.item():.6f} / {reconstructions_loss.item():.6f})')

    # loss.item() is already a per-batch mean, so average over batches, not samples.
    train_loss /= max(len(train_loader), 1)
    accuracy = 100. * correct / max(len(train_loader.dataset), 1)
    print(f'Train Epoch: {epoch}\tAverage loss: {train_loss:.4f}\tAccuracy: {accuracy:.2f}%')


def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print average per-batch loss and accuracy."""
    device = next(model.parameters()).device
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # No labels at test time: the decoder uses the predicted class's capsule.
            lengths, reconstructions = model(data)
            loss = criterion(lengths, target) + reconstruction_loss(reconstructions, data)
            test_loss += loss.item()
            correct += lengths.argmax(dim=-1).eq(target).sum().item()

    test_loss /= max(len(test_loader), 1)
    accuracy = 100. * correct / max(len(test_loader.dataset), 1)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')


def main(epochs=10, batch_size=128):
    """Train and evaluate a CapsNet on MNIST."""
    # Local import: only this entry point needs torchvision.
    from torchvision import datasets, transforms

    input_shape = (1, 28, 28)
    n_class = 10
    routings = 3

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CapsNet(input_shape, n_class, routings).to(device)
    print(model)
    optimizer = optim.Adam(model.parameters())
    criterion = MarginLoss()

    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, criterion, epoch)
        test(model, test_loader, criterion)


if __name__ == "__main__":
    main()

下面是去掉 reconstruction 损失(只使用分类损失)时的训练输出:

CapsNet(   (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))   (primarycaps): PrimaryCaps(     (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))   )   (digitcaps): CapsuleLayer()   (decoder): Sequential(     (0): Linear(in_features=160, out_features=512, bias=True)     (1): ReLU(inplace=True)     (2): Linear(in_features=512, out_features=1024, bias=True)     (3): ReLU(inplace=True)     (4): Linear(in_features=1024, out_features=784, bias=True)     (5): Sigmoid()   ) ) Train Epoch: 1 [0/60000 (0%)]	Loss: 2.301852 Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.302585 Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.302585 Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.302585 Train Epoch: 1 [51200/60000 (85%)]	Loss: 2.302585 Train Epoch: 1	Average loss: 0.0180	Accuracy: 15.21%  Test set: Average loss: 0.0182, Accuracy: 2033/10000 (20.33%)  Train Epoch: 2 [0/60000 (0%)]	Loss: 2.302584 Train Epoch: 2 [12800/60000 (21%)]	Loss: 2.302584 Train Epoch: 2 [25600/60000 (43%)]	Loss: 2.302585 Train Epoch: 2 [38400/60000 (64%)]	Loss: 2.302584 Train Epoch: 2 [51200/60000 (85%)]	Loss: 2.302584 Train Epoch: 2	Average loss: 0.0180	Accuracy: 16.46%  Test set: Average loss: 0.0182, Accuracy: 1753/10000 (17.53%)  Train Epoch: 3 [0/60000 (0%)]	Loss: 2.302584 Train Epoch: 3 [12800/60000 (21%)]	Loss: 2.302584 Train Epoch: 3 [25600/60000 (43%)]	Loss: 2.302583 Train Epoch: 3 [38400/60000 (64%)]	Loss: 2.302582 Train Epoch: 3 [51200/60000 (85%)]	Loss: 2.302579 Train Epoch: 3	Average loss: 0.0180	Accuracy: 18.76%  Test set: Average loss: 0.0182, Accuracy: 2707/10000 (27.07%)  Train Epoch: 4 [0/60000 (0%)]	Loss: 2.302577 Train Epoch: 4 [12800/60000 (21%)]	Loss: 2.302552 Train Epoch: 4 [25600/60000 (43%)]	Loss: 2.302583 Train Epoch: 4 [38400/60000 (64%)]	Loss: 2.302581 Train Epoch: 4 [51200/60000 (85%)]	Loss: 2.302576 Train Epoch: 4	Average loss: 0.0180	Accuracy: 21.94%  Test set: Average loss: 0.0182, Accuracy: 1935/10000 (19.35%)  Train Epoch: 5 [0/60000 (0%)]	Loss: 2.302576 Train Epoch: 5 
[12800/60000 (21%)]	Loss: 2.302568 Train Epoch: 5 [25600/60000 (43%)]	Loss: 2.302579 Train Epoch: 5 [38400/60000 (64%)]	Loss: 2.302577 Train Epoch: 5 [51200/60000 (85%)]	Loss: 2.302578 Train Epoch: 5	Average loss: 0.0180	Accuracy: 23.75%  Test set: Average loss: 0.0182, Accuracy: 2909/10000 (29.09%)  Train Epoch: 6 [0/60000 (0%)]	Loss: 2.302570 Train Epoch: 6 [12800/60000 (21%)]	Loss: 2.302553 Train Epoch: 6 [25600/60000 (43%)]	Loss: 2.302286 Train Epoch: 6 [38400/60000 (64%)]	Loss: 1.515582 Train Epoch: 6 [51200/60000 (85%)]	Loss: 1.519818 Train Epoch: 6	Average loss: 0.0147	Accuracy: 66.09%  Test set: Average loss: 0.0119, Accuracy: 9556/10000 (95.56%)  Train Epoch: 7 [0/60000 (0%)]	Loss: 1.531795 Train Epoch: 7 [12800/60000 (21%)]	Loss: 1.491780 Train Epoch: 7 [25600/60000 (43%)]	Loss: 1.485394 Train Epoch: 7 [38400/60000 (64%)]	Loss: 1.485206 Train Epoch: 7 [51200/60000 (85%)]	Loss: 1.483377 Train Epoch: 7	Average loss: 0.0117	Accuracy: 97.54%  Test set: Average loss: 0.0118, Accuracy: 9801/10000 (98.01%)  Train Epoch: 8 [0/60000 (0%)]	Loss: 1.492681 Train Epoch: 8 [12800/60000 (21%)]	Loss: 1.511164 Train Epoch: 8 [25600/60000 (43%)]	Loss: 1.471053 Train Epoch: 8 [38400/60000 (64%)]	Loss: 1.472341 Train Epoch: 8 [51200/60000 (85%)]	Loss: 1.496702 Train Epoch: 8	Average loss: 0.0116	Accuracy: 98.42%  Test set: Average loss: 0.0118, Accuracy: 9787/10000 (97.87%)  Train Epoch: 9 [0/60000 (0%)]	Loss: 1.476871 Train Epoch: 9 [12800/60000 (21%)]	Loss: 1.483141 Train Epoch: 9 [25600/60000 (43%)]	Loss: 1.489299 Train Epoch: 9 [38400/60000 (64%)]	Loss: 1.478493 Train Epoch: 9 [51200/60000 (85%)]	Loss: 1.472141 Train Epoch: 9	Average loss: 0.0116	Accuracy: 98.89%  Test set: Average loss: 0.0117, Accuracy: 9856/10000 (98.56%)  Train Epoch: 10 [0/60000 (0%)]	Loss: 1.468742 Train Epoch: 10 [12800/60000 (21%)]	Loss: 1.479285 Train Epoch: 10 [25600/60000 (43%)]	Loss: 1.471160 Train Epoch: 10 [38400/60000 (64%)]	Loss: 1.468553 Train Epoch: 10 [51200/60000 (85%)]	Loss: 1.478799 
Train Epoch: 10	Average loss: 0.0115	Accuracy: 99.06%  Test set: Average loss: 0.0117, Accuracy: 9819/10000 (98.19%)

可以看到,从第 6 个 epoch 开始 accuracy 陡然上升,但前 5 个 epoch 增长非常缓慢,而且整体训练速度很慢。前期的 loss 一直停在 2.302585 ≈ ln(10),说明 softmax 给出的类别概率接近均匀分布(每类约 0.1),分类损失几乎没有下降。

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!