PyTorch中的数据加载器怎么使用

avatar
作者
猴君
阅读量:0

在PyTorch中,数据加载器可以通过torch.utils.data.DataLoader来实现。数据加载器可以帮助用户批量加载数据,并可以在训练过程中对数据进行随机排列、并行加载等操作。

下面是一个简单的示例,演示如何使用数据加载器来加载一个简单的数据集:

import torch from torch.utils.data import Dataset, DataLoader  # 创建一个自定义的数据集类 class CustomDataset(Dataset):     def __init__(self):         self.data = torch.randn(100, 3)  # 100个3维的随机数据         self.targets = torch.randint(0, 2, (100,))  # 100个随机目标标签      def __len__(self):         return len(self.data)      def __getitem__(self, idx):         return self.data[idx], self.targets[idx]  # 创建数据集实例 dataset = CustomDataset()  # 创建数据加载器实例 data_loader = DataLoader(dataset, batch_size=32, shuffle=True)  # 遍历数据加载器 for i, (data, target) in enumerate(data_loader):     print(f'Batch {i}:')     print('Data:', data)     print('Target:', target) 

在上述示例中,首先定义了一个自定义的数据集类CustomDataset,然后创建了一个数据集实例dataset。接着利用DataLoader类来创建一个数据加载器实例data_loader,并指定了批量大小为32且开启了数据随机排列。最后通过对数据加载器进行遍历,便可以逐批次地获取数据和标签。

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!