Torch中怎么加载和使用预训练模型

avatar
作者
筋斗云
阅读量:0

在Torch中加载和使用预训练模型通常通过使用torchvision.models模块来实现。以下是一个简单的示例,演示如何加载预训练的ResNet模型并使用它对图像进行预测:

import torch import torchvision.models as models import torchvision.transforms as transforms from PIL import Image  # 加载预训练的ResNet模型 model = models.resnet18(pretrained=True) model.eval()  # 定义图片预处理步骤 transform = transforms.Compose([     transforms.Resize(256),     transforms.CenterCrop(224),     transforms.ToTensor(),     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ])  # 加载并预处理图像 img = Image.open('image.jpg') img = transform(img).unsqueeze(0)  # 使用模型进行预测 output = model(img)  # 获取预测结果 _, predicted = torch.max(output, 1)  print('Predicted class:', predicted.item()) 

在上面的示例中,我们首先加载了预训练的ResNet模型,并将其设置为评估模式。然后定义了图像预处理步骤,并加载并预处理了一个示例图像。最后,我们使用模型对图像进行预测,并输出预测结果。

请注意,这只是一个简单的示例,实际应用中可能会有更复杂的预处理步骤和模型的使用方式,具体取决于你的应用场景和需求。

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!