阅读量:16
在Torch中加载和处理数据集通常通过使用torch.utils.data.Dataset
和torch.utils.data.DataLoader
类来实现。以下是一个简单的示例代码:
import torch from torch.utils.data import Dataset, DataLoader # 定义自定义数据集类 class CustomDataset(Dataset): def __init__(self): # 初始化数据集 self.data = torch.randn(100, 10) self.labels = torch.randint(0, 2, (100,)) def __len__(self): # 返回数据集大小 return len(self.data) def __getitem__(self, idx): # 获取数据集中的一个样本 return self.data[idx], self.labels[idx] # 创建数据集实例 dataset = CustomDataset() # 创建数据加载器 dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # 遍历数据集 for data, labels in dataloader: # 处理每个批次的数据 print(data.shape, labels.shape)
在上面的示例中,定义了一个自定义的数据集类CustomDataset
,其中实现了__init__
、__len__
和__getitem__
方法。然后创建了dataset
实例和dataloader
对象,并使用for
循环遍历数据加载器,获取每个批次的数据。