Torch中怎么加载和处理数据集

avatar
作者
猴君
阅读量:16

在Torch中加载和处理数据集通常通过使用torch.utils.data.Datasettorch.utils.data.DataLoader类来实现。以下是一个简单的示例代码:

import torch from torch.utils.data import Dataset, DataLoader  # 定义自定义数据集类 class CustomDataset(Dataset):     def __init__(self):         # 初始化数据集         self.data = torch.randn(100, 10)         self.labels = torch.randint(0, 2, (100,))              def __len__(self):         # 返回数据集大小         return len(self.data)          def __getitem__(self, idx):         # 获取数据集中的一个样本         return self.data[idx], self.labels[idx]  # 创建数据集实例 dataset = CustomDataset()  # 创建数据加载器 dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # 遍历数据集 for data, labels in dataloader:     # 处理每个批次的数据     print(data.shape, labels.shape) 

在上面的示例中,定义了一个自定义的数据集类CustomDataset,其中实现了__init____len____getitem__方法。然后创建了dataset实例和dataloader对象,并使用for循环遍历数据加载器,获取每个批次的数据。

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!