How to retrieve the Dataset object from a DataLoader?

Question

I have a PyTorch DataLoader and want to retrieve the Dataset object that the loader wraps. Is this possible, and if so, how? Or does a Dataset object only exist for the built-in datasets that torchvision downloads (such as CIFAR10)?

The end goal is to plug data that I currently have as a DataLoader into existing code that was written for a Dataset (e.g. CIFAR10).

The original code looks like this:

from torchvision import datasets
from torch.utils.data import Dataset

def get_dataset(dataset: str, split: str) -> Dataset:
    if dataset == "CIFAR10":
        return _cifar10(split)

def _cifar10(split: str) -> Dataset:
    # datasets.CIFAR10 is itself a Dataset, so it supports len() and dataset[i]
    if split == "train":
        return datasets.CIFAR10("./dataset_cache", train=True, download=True)

dataset = get_dataset("CIFAR10", "train")  # split is passed as a string
for i in range(len(dataset)):
    ...  # per-sample processing (elided in the question), e.g. using dataset[i]

I tried loading my own image dataset all at once, by wrapping an ImageFolder in a DataLoader whose batch size equals the total number of images:

import os

import torch
from torchvision import datasets
from torch.utils.data import Dataset

def get_dataset(dataset: str, split: str):
    if dataset == "CIFAR10":
        return _cifar10(split)  # as defined above
    elif dataset == "mydataset":
        return _mydataset(split)

def _mydataset(split: str) -> Dataset:
    files = ...  # list of class sub-folders (left unspecified in the question)
    # Count every image under <dataset_directory>/<split>/<class folder>
    total_num_images = 0
    for file in files:
        number_images = len(os.listdir(os.path.join(dataset_directory, split, file)))
        total_num_images += number_images
    if split == "train":
        # Note: despite the annotation, this returns a DataLoader, not a Dataset
        mydataset = torch.utils.data.DataLoader(
            datasets.ImageFolder(os.path.join(dataset_directory, "train")),
            batch_size=total_num_images)
        return mydataset

dataset = get_dataset("mydataset", "train")
for i in range(len(dataset)):
    ...  # per-sample code using dataset[i] fails here

But this raises TypeError: 'DataLoader' object is not subscriptable.
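A minimal sketch of why the loop fails, assuming a small in-memory TensorDataset as a hypothetical stand-in for the ImageFolder data: a DataLoader supports len() (which counts batches, not samples) and iteration, but not indexing, while the Dataset it wraps can be indexed one sample at a time.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for datasets.ImageFolder(...): 10 random 3x32x32 "images" with labels.
images = torch.randn(10, 3, 32, 32)
labels = torch.randint(0, 2, (10,))
ds = TensorDataset(images, labels)

# Same pattern as in the question: a single batch holding every sample.
loader = DataLoader(ds, batch_size=len(ds))

print(len(ds))      # number of samples
print(len(loader))  # number of batches, so only 1 here

# A DataLoader defines no __getitem__, so indexing it raises TypeError.
try:
    loader[0]
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)

# The wrapped Dataset is indexable, one (image, label) pair at a time.
x, y = loader.dataset[0]
print(x.shape, y.item())
```

This also suggests why range(len(dataset)) in the question would iterate only once even without the indexing error: assuming total_num_images matches the number of images ImageFolder finds, the DataLoader holds exactly one batch.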

Answer 1

Score: 2

You can read the dataset attribute of a torch.utils.data.DataLoader to get the underlying torch.utils.data.Dataset object, as can be seen in the DataLoader source code.
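Applied to the question, a minimal sketch: the hypothetical helper unwrap_dataset (not a torch API) returns loader.dataset when given a DataLoader and passes a Dataset through unchanged, so code written for the CIFAR10 Dataset can also accept a DataLoader. A small TensorDataset is assumed as a stand-in for the ImageFolder data.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def unwrap_dataset(data) -> Dataset:
    """Hypothetical helper: return the Dataset behind a DataLoader, or the input if it is already a Dataset."""
    if isinstance(data, DataLoader):
        return data.dataset
    return data


# Stand-in for the question's DataLoader(ImageFolder(...), batch_size=total_num_images).
ds = TensorDataset(torch.arange(5).float(), torch.arange(5))
loader = DataLoader(ds, batch_size=len(ds))

dataset = unwrap_dataset(loader)
print(dataset is ds)  # the very same object that was passed to DataLoader

# The original per-sample loop now works, because a Dataset is indexable.
for i in range(len(dataset)):
    x, y = dataset[i]
    print(i, x.item(), y.item())
```

loader.dataset is the same object that was passed to the DataLoader, so any transform attached to it still applies, but the loader's own settings (batch_size, shuffle, sampler, collate_fn) are not carried over. For the question's _mydataset, returning datasets.ImageFolder(...) directly instead of wrapping it in a DataLoader would also work, since ImageFolder is itself a Dataset subclass.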

huangapple
  • Posted on March 9, 2023 at 18:35:03
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75683408.html