pytorch如何打印网络结构

avatar
作者
猴君
阅读量:0

要打印PyTorch网络结构,可以使用print函数或者torchsummary库来实现。

使用print函数来打印网络结构示例如下:

import torch import torch.nn as nn  class Net(nn.Module):     def __init__(self):         super(Net, self).__init__()         self.conv1 = nn.Conv2d(3, 6, 3)         self.pool = nn.MaxPool2d(2, 2)         self.conv2 = nn.Conv2d(6, 16, 3)         self.fc1 = nn.Linear(16 * 6 * 6, 120)         self.fc2 = nn.Linear(120, 84)         self.fc3 = nn.Linear(84, 10)      def forward(self, x):         x = self.pool(F.relu(self.conv1(x)))         x = self.pool(F.relu(self.conv2(x)))         x = x.view(-1, 16 * 6 * 6)         x = F.relu(self.fc1(x))         x = F.relu(self.fc2(x))         x = self.fc3(x)         return x  net = Net() print(net) 

使用torchsummary库来打印网络结构示例如下:

from torchsummary import summary  net = Net() summary(net, input_size=(3, 32, 32)) 

以上两种方法都可以用来打印PyTorch网络结构,可以根据需要选择其中一种方法。

    广告一刻

    为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!