阅读量:0
在PyTorch中,数据加载器可以通过torch.utils.data.DataLoader
来实现。数据加载器可以帮助用户批量加载数据,并可以在训练过程中对数据进行随机排列、并行加载等操作。
下面是一个简单的示例,演示如何使用数据加载器来加载一个简单的数据集:
import torch from torch.utils.data import Dataset, DataLoader # 创建一个自定义的数据集类 class CustomDataset(Dataset): def __init__(self): self.data = torch.randn(100, 3) # 100个3维的随机数据 self.targets = torch.randint(0, 2, (100,)) # 100个随机目标标签 def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.targets[idx] # 创建数据集实例 dataset = CustomDataset() # 创建数据加载器实例 data_loader = DataLoader(dataset, batch_size=32, shuffle=True) # 遍历数据加载器 for i, (data, target) in enumerate(data_loader): print(f'Batch {i}:') print('Data:', data) print('Target:', target)
在上述示例中,首先定义了一个自定义的数据集类CustomDataset
,然后创建了一个数据集实例dataset
。接着利用DataLoader
类来创建一个数据加载器实例data_loader
,并指定了批量大小为32且开启了数据随机排列。最后通过对数据加载器进行遍历,便可以逐批次地获取数据和标签。