Best way to use Python iterator as dataset in PyTorch

Question
The PyTorch DataLoader turns datasets into iterables. I already have a generator which yields data samples that I want to use for training and testing. I use a generator because the total number of samples is too large to store in memory, and I would like to load the samples in batches for training.
What is the best way to do this? Can I do it without a custom DataLoader? The PyTorch DataLoader doesn't accept a generator as input. Below is a minimal example of what I want to do, which produces the error "object of type 'generator' has no len()".
import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i

BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                              batch_size=BATCH_SIZE,
                              shuffle=False)
print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
I am trying to feed the data from an iterator into the PyTorch DataLoader and take advantage of its functionality. The example I gave is a minimal example of what I would like to achieve, but it produces an error.
Edit: I want to be able to use this for complex generators in which the length is not known in advance.
Answer 1
Score: 0
PyTorch's DataLoader actually has official support for an iterable dataset, but it just has to be an instance of a subclass of torch.utils.data.IterableDataset:

> An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples.
So your code would be written as:
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3
train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size=BATCH_SIZE,
                              shuffle=False)
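Regarding the edit: len(train_dataloader) will still raise a TypeError with this setup, because an IterableDataset defines no __len__; you simply iterate until the generator is exhausted. One more caveat: a generator object can only be consumed once, so the dataset above is empty after the first epoch. Below is a minimal sketch of one way around that, wrapping the generator function itself rather than a generator object (the class name MyReusableIterableDataset is just for illustration):

from torch.utils.data import DataLoader, IterableDataset

def example_generator():
    for i in range(10):
        yield i

class MyReusableIterableDataset(IterableDataset):
    def __init__(self, generator_fn):
        # Store the generator *function*, not a generator object.
        self.generator_fn = generator_fn

    def __iter__(self):
        # Calling the function creates a fresh generator each time,
        # so the dataset can be re-iterated across epochs.
        return self.generator_fn()

BATCH_SIZE = 3
train_dataloader = DataLoader(MyReusableIterableDataset(example_generator),
                              batch_size=BATCH_SIZE,
                              shuffle=False)

for epoch in range(2):
    for batch in train_dataloader:
        print(epoch, batch)  # e.g. 0 tensor([0, 1, 2]) ... 1 tensor([9])

Note that shuffle=True is not allowed for an iterable-style dataset (the DataLoader raises a ValueError), and with num_workers > 0 each worker process receives its own copy of the iterator, so samples are duplicated unless you shard the stream inside __iter__.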