最佳方式在PyTorch中使用Python迭代器作为数据集。

huangapple go评论63阅读模式
英文:

Best way to use Pyothon iterator as dataset in PyTorch

问题

The PyTorch DataLoader将数据集转化为可迭代对象。我已经有一个生成器,它生成我想用于训练和测试的数据样本。我之所以使用生成器,是因为样本总数太大,无法一次存储在内存中。我想要分批次加载样本进行训练。

如何以最佳方式实现这一目标?是否可以在不使用自定义DataLoader的情况下完成?PyTorch DataLoader似乎不接受生成器作为输入。下面是我想要做的最小示例,它会产生错误:"对象的类型为'generator'没有len()"。

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i
    
BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                        batch_size = BATCH_SIZE,
                        shuffle=False)

print(f"train_dataloader的长度:{len(train_dataloader)} 批次,每批次包含{BATCH_SIZE}个样本")

我正在尝试从迭代器中获取数据并利用PyTorch DataLoader的功能。我提供的示例是我想要实现的最小示例,但它会产生错误。

编辑:我希望能够对复杂生成器使用此功能,其中__len__在预先不知道的情况下。

英文:

The PyTorch DataLoader turns datasets into iterables. I already have a generator which yields data samples that I want to use for training and testing. The reason I use a generator is because the total number of samples is too large to store in memory. I would like to load the samples in batches for training.

What is the best way to do this? Can I do it without a custom DataLoader? The PyTorch dataloader doesn't like taking the generator as input. Below is a minimal example of what I want to do, which produces the error "object of type 'generator' has no len()".

import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i
    

BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                        batch_size = BATCH_SIZE,
                        shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")

I am trying to take the data from an iterator and take advantage of the functionality of the PyTorch DataLoader. The example I gave is a minimal example of what I would like to achieve, but it produces an error.

Edit: I want to be able to use this function for complex generators in which the len is not known in advance.

答案1

得分: 0

PyTorch的DataLoader实际上官方支持可迭代数据集,但它必须是torch.utils.data.IterableDataset子类的实例:

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3

train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size=BATCH_SIZE,
                              shuffle=False)
英文:

PyTorch's DataLoader actually has official support for an iterable dataset, but it just has to be an instance of a subclass of torch.utils.data.IterableDataset:

> An iterable-style dataset is an instance of a subclass of
> IterableDataset that implements the __iter__() protocol, and
> represents an iterable over data samples

So your code would be written as:

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3

train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size = BATCH_SIZE,
                              shuffle=False)

huangapple
  • 本文由 发表于 2023年4月11日 10:53:04
  • 转载请务必保留本文链接:https://go.coder-hub.com/75982081.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定