References:
- Alibaba Cloud practice write-up: https://developer.aliyun.com/article/581717
- Video explaining the dynamic routing mechanism: https://www.bilibili.com/video/BV1oW411H7G1/?spm_id_from=333.337.search-card.all.click&vd_source=3b5e1109bdab0d21b23a5c46c4ed667d
- Hinton's paper: Dynamic Routing Between Capsules
Code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np


class CapsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, routing_iters=3):
        super(CapsuleLayer, self).__init__()
        self.num_route_nodes = num_route_nodes
        self.num_capsules = num_capsules
        self.routing_iters = routing_iters
        # Transformation matrices W, shared across the batch
        self.W = nn.Parameter(torch.randn(1, num_capsules, num_route_nodes, in_channels, out_channels))  # ([1, 10, 1152, 8, 16])

    def forward(self, x):  # ([128, 32, 6, 6, 8])
        x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3), x.size(4))  # ([128, 1152, 8])
        x = x.unsqueeze(1)  # ([128, 1, 1152, 8])
        x = x.repeat(1, self.num_capsules, 1, 1)  # ([128, 10, 1152, 8])
        x = x.unsqueeze(3)  # ([128, 10, 1152, 1, 8])
        u_hat = torch.matmul(x, self.W)  # prediction vectors, ([128, 10, 1152, 1, 16])
        u_hat = u_hat.squeeze(3)  # ([128, 10, 1152, 16])
        # Routing logits b, initialized to zero on the same device as the input
        b = torch.zeros(x.size(0), self.num_capsules, self.num_route_nodes, 1, device=x.device)  # ([128, 10, 1152, 1])
        for i in range(self.routing_iters):
            c = F.softmax(b, dim=2)  # coupling coefficients
            s = (c * u_hat).sum(dim=2, keepdim=True)  # weighted sum over route nodes, ([128, 10, 1, 16])
            v = self.squash(s)  # ([128, 10, 1, 16])
            if i < self.routing_iters - 1:
                # Agreement update: b += u_hat . v
                b = b + (u_hat * v).sum(dim=-1, keepdim=True)
        return v.squeeze(dim=-1)  # dim -1 has size 16, so this squeeze is a no-op; output stays ([128, 10, 1, 16])

    def squash(self, input_tensor):
        # squash(s) = (|s|^2 / (1 + |s|^2)) * s / |s|; note there is no epsilon guard against |s| = 0
        squared_norm = (input_tensor ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * input_tensor / torch.sqrt(squared_norm)


class PrimaryCaps(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, dim_capsule):
        super(PrimaryCaps, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * dim_capsule, kernel_size=kernel_size, stride=stride)
        self.dim_capsule = dim_capsule

    def forward(self, x):
        x = self.conv(x)
        batch_size = x.size(0)
        out_channels = x.size(1) // self.dim_capsule
        height = x.size(2)
        width = x.size(3)
        # Reshape to [batch_size, out_channels, height, width, dim_capsule]
        x = x.view(batch_size, out_channels, height, width, self.dim_capsule)
        return x


class CapsNet(nn.Module):
    def __init__(self, input_shape, n_class, routings):
        super(CapsNet, self).__init__()
        self.input_shape = input_shape
        self.n_class = n_class
        self.routings = routings
        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=256, kernel_size=9, stride=1)
        self.primarycaps = PrimaryCaps(dim_capsule=8, in_channels=256, out_channels=32, kernel_size=9, stride=2)
        self.digitcaps = CapsuleLayer(num_capsules=n_class, num_route_nodes=32 * 6 * 6, in_channels=8, out_channels=16, routing_iters=routings)
        self.decoder = nn.Sequential(
            nn.Linear(16 * n_class, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, np.prod(input_shape)),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = F.relu(self.conv1(x))  # ([128, 256, 20, 20])
        x = self.primarycaps(x)  # ([128, 32, 6, 6, 8])
        x = self.digitcaps(x)  # ([128, 10, 1, 16])
        # Length of output capsules serves as the class score
        lengths = x.norm(dim=-1).squeeze(2)  # ([128, 10])
        # Reconstruction
        x = x.view(x.size(0), -1)  # ([128, 160])
        reconstructions = self.decoder(x)
        reconstructions = reconstructions.view(-1, *self.input_shape)
        return lengths, reconstructions


# Create CapsNet model
input_shape = (1, 28, 28)  # Example for MNIST
n_class = 10  # Number of classes
routings = 3  # Number of routing iterations
model = CapsNet(input_shape, n_class, routings)
print(model)
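As a quick sanity check (not in the original post), a random MNIST-sized batch can be pushed through the freshly built model to confirm the shapes annotated in the comments above; the dummy tensor and batch size below are illustrative only:

# Shape check with a random input (illustrative sketch)
dummy = torch.randn(2, 1, 28, 28)  # two fake 28x28 grayscale images
lengths, reconstructions = model(dummy)
print(lengths.shape)          # torch.Size([2, 10]) -- one capsule length per class
print(reconstructions.shape)  # torch.Size([2, 1, 28, 28])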
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Load MNIST dataset
batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CapsNet(input_shape=(1, 28, 28), n_class=10, routings=3).to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training function
def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        lengths, reconstructions = model(data)
        classification_loss = criterion(lengths, target)
        reconstructions_loss = F.mse_loss(reconstructions, data)
        loss = classification_loss + reconstructions_loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred = torch.argmax(lengths, dim=1)
        correct += pred.eq(target).sum().item()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss(c/r): {loss.item():.6f} '
                  f'({classification_loss.item():.6f} / {reconstructions_loss.item():.6f})')
    train_loss /= len(train_loader.dataset)
    accuracy = 100. * correct / len(train_loader.dataset)
    print(f'Train Epoch: {epoch}\tAverage loss: {train_loss:.4f}\tAccuracy: {accuracy:.2f}%')

# Testing function
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            lengths, reconstructions = model(data)
            loss = criterion(lengths, target) + F.mse_loss(reconstructions, data)
            test_loss += loss.item()
            pred = lengths.argmax(dim=-1)
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')

# Train the model
epochs = 10
for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, criterion, epoch)
    test(model, test_loader, criterion)
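The logged run below drops the reconstruction term from the total loss, which is also why its output prints a single "Loss:" value rather than "Loss(c/r):". A minimal sketch of that change, assuming the rest of the training step stays as above (the exact variant used for the run was not shown):

# Hypothetical one-line change for the "no reconstruction" run logged below
loss = classification_loss  # was: classification_loss + reconstructions_loss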
Training results with the reconstruction loss removed:
CapsNet(
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (primarycaps): PrimaryCaps(
    (conv): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digitcaps): CapsuleLayer()
  (decoder): Sequential(
    (0): Linear(in_features=160, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=1024, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=1024, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.301852
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.302585
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.302585
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.302585
Train Epoch: 1 [51200/60000 (85%)]	Loss: 2.302585
Train Epoch: 1	Average loss: 0.0180	Accuracy: 15.21%

Test set: Average loss: 0.0182, Accuracy: 2033/10000 (20.33%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 2.302584
Train Epoch: 2 [12800/60000 (21%)]	Loss: 2.302584
Train Epoch: 2 [25600/60000 (43%)]	Loss: 2.302585
Train Epoch: 2 [38400/60000 (64%)]	Loss: 2.302584
Train Epoch: 2 [51200/60000 (85%)]	Loss: 2.302584
Train Epoch: 2	Average loss: 0.0180	Accuracy: 16.46%

Test set: Average loss: 0.0182, Accuracy: 1753/10000 (17.53%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 2.302584
Train Epoch: 3 [12800/60000 (21%)]	Loss: 2.302584
Train Epoch: 3 [25600/60000 (43%)]	Loss: 2.302583
Train Epoch: 3 [38400/60000 (64%)]	Loss: 2.302582
Train Epoch: 3 [51200/60000 (85%)]	Loss: 2.302579
Train Epoch: 3	Average loss: 0.0180	Accuracy: 18.76%

Test set: Average loss: 0.0182, Accuracy: 2707/10000 (27.07%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 2.302577
Train Epoch: 4 [12800/60000 (21%)]	Loss: 2.302552
Train Epoch: 4 [25600/60000 (43%)]	Loss: 2.302583
Train Epoch: 4 [38400/60000 (64%)]	Loss: 2.302581
Train Epoch: 4 [51200/60000 (85%)]	Loss: 2.302576
Train Epoch: 4	Average loss: 0.0180	Accuracy: 21.94%

Test set: Average loss: 0.0182, Accuracy: 1935/10000 (19.35%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 2.302576
Train Epoch: 5 [12800/60000 (21%)]	Loss: 2.302568
Train Epoch: 5 [25600/60000 (43%)]	Loss: 2.302579
Train Epoch: 5 [38400/60000 (64%)]	Loss: 2.302577
Train Epoch: 5 [51200/60000 (85%)]	Loss: 2.302578
Train Epoch: 5	Average loss: 0.0180	Accuracy: 23.75%

Test set: Average loss: 0.0182, Accuracy: 2909/10000 (29.09%)

Train Epoch: 6 [0/60000 (0%)]	Loss: 2.302570
Train Epoch: 6 [12800/60000 (21%)]	Loss: 2.302553
Train Epoch: 6 [25600/60000 (43%)]	Loss: 2.302286
Train Epoch: 6 [38400/60000 (64%)]	Loss: 1.515582
Train Epoch: 6 [51200/60000 (85%)]	Loss: 1.519818
Train Epoch: 6	Average loss: 0.0147	Accuracy: 66.09%

Test set: Average loss: 0.0119, Accuracy: 9556/10000 (95.56%)

Train Epoch: 7 [0/60000 (0%)]	Loss: 1.531795
Train Epoch: 7 [12800/60000 (21%)]	Loss: 1.491780
Train Epoch: 7 [25600/60000 (43%)]	Loss: 1.485394
Train Epoch: 7 [38400/60000 (64%)]	Loss: 1.485206
Train Epoch: 7 [51200/60000 (85%)]	Loss: 1.483377
Train Epoch: 7	Average loss: 0.0117	Accuracy: 97.54%

Test set: Average loss: 0.0118, Accuracy: 9801/10000 (98.01%)

Train Epoch: 8 [0/60000 (0%)]	Loss: 1.492681
Train Epoch: 8 [12800/60000 (21%)]	Loss: 1.511164
Train Epoch: 8 [25600/60000 (43%)]	Loss: 1.471053
Train Epoch: 8 [38400/60000 (64%)]	Loss: 1.472341
Train Epoch: 8 [51200/60000 (85%)]	Loss: 1.496702
Train Epoch: 8	Average loss: 0.0116	Accuracy: 98.42%

Test set: Average loss: 0.0118, Accuracy: 9787/10000 (97.87%)

Train Epoch: 9 [0/60000 (0%)]	Loss: 1.476871
Train Epoch: 9 [12800/60000 (21%)]	Loss: 1.483141
Train Epoch: 9 [25600/60000 (43%)]	Loss: 1.489299
Train Epoch: 9 [38400/60000 (64%)]	Loss: 1.478493
Train Epoch: 9 [51200/60000 (85%)]	Loss: 1.472141
Train Epoch: 9	Average loss: 0.0116	Accuracy: 98.89%

Test set: Average loss: 0.0117, Accuracy: 9856/10000 (98.56%)

Train Epoch: 10 [0/60000 (0%)]	Loss: 1.468742
Train Epoch: 10 [12800/60000 (21%)]	Loss: 1.479285
Train Epoch: 10 [25600/60000 (43%)]	Loss: 1.471160
Train Epoch: 10 [38400/60000 (64%)]	Loss: 1.468553
Train Epoch: 10 [51200/60000 (85%)]	Loss: 1.478799
Train Epoch: 10	Average loss: 0.0115	Accuracy: 99.06%

Test set: Average loss: 0.0117, Accuracy: 9819/10000 (98.19%)
As the logs show, accuracy jumps sharply starting at epoch 6, but improves very slowly before that, and training itself is quite slow. Note that the early loss sits almost exactly at 2.302585 = ln(10), the cross-entropy of a near-uniform softmax over ten classes; this is likely because the squash function bounds every capsule length below 1, so the "logits" fed to CrossEntropyLoss differ by less than 1 and the softmax stays nearly flat, giving tiny gradients for the first several epochs.
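Since the Hinton paper trains capsule lengths with a margin loss rather than softmax cross-entropy, swapping the criterion is one natural experiment for the slow start. Below is a sketch of the paper's margin loss (its Eq. 4, with m+ = 0.9, m- = 0.1, lambda = 0.5); this is an alternative to try, not the loss that produced the logs above:

# Margin loss from "Dynamic Routing Between Capsules" (Eq. 4); an alternative to
# nn.CrossEntropyLoss on raw capsule lengths. Not the loss used for the logs above.
def margin_loss(lengths, target, m_pos=0.9, m_neg=0.1, lam=0.5):
    # lengths: ([batch, n_class]) capsule norms; target: ([batch]) class indices
    t = F.one_hot(target, num_classes=lengths.size(1)).float()
    present = t * F.relu(m_pos - lengths) ** 2             # push the correct class above m_pos
    absent = lam * (1 - t) * F.relu(lengths - m_neg) ** 2  # push the other classes below m_neg
    return (present + absent).sum(dim=1).mean()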