How to retrieve Dataset object from DataLoader?
Question
I have a PyTorch DataLoader and want to retrieve the Dataset object that the loader wraps. If this is possible, how? Or does a Dataset object only exist for the datasets that come pre-packaged with torch?
The end goal is to easily plug data that I have as a DataLoader into code that was set up to work with a Dataset (e.g. CIFAR10).
The original code looks like this:
```python
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset


def get_dataset(dataset, split):
    if dataset == "CIFAR10":
        return _cifar10(split)


def _cifar10(split: str) -> Dataset:
    if split == "train":
        return datasets.CIFAR10("./dataset_cache", train=True, download=True)


dataset = get_dataset("CIFAR10", "train")
for i in range(len(dataset)):
    ...  # e.g. image, label = dataset[i]
```
I tried loading my own dataset all at once:
```python
from torchvision import transforms, datasets
from typing import *
import torch
import os
from torch.utils.data import Dataset


def get_dataset(dataset, split):
    if dataset == "CIFAR10":
        return _cifar10(split)
    elif dataset == "mydataset":
        return _mydataset(split)


def _mydataset(split: str) -> Dataset:
    files = ...  # list of class sub-folders (elided in the original question)
    total_num_images = 0
    for file in files:
        # dataset_directory is defined elsewhere in the original code
        number_images = len(os.listdir(dataset_directory + '/' + split + '/' + file))
        total_num_images += number_images
    if split == "train":
        # Load every image as a single batch
        mydataset = torch.utils.data.DataLoader(
            datasets.ImageFolder(dataset_directory + '/train'), batch_size=total_num_images)
        return mydataset


dataset = get_dataset("mydataset", "train")
for i in range(len(dataset)):
    ...
```
But this raises `TypeError: 'DataLoader' object is not subscriptable`.
Answer 1
Score: 2
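You can access the `dataset` attribute of a `data.DataLoader` (i.e. `torch.utils.data.DataLoader`) to get the underlying `data.Dataset` object it wraps, as the DataLoader source code shows. This works for any dataset passed to the loader, not only the built-in torchvision ones.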
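A minimal sketch of the attribute in use, assuming only `torch` is installed. A small `TensorDataset` of random tensors stands in for CIFAR10 or `ImageFolder` (the names `images`, `labels` and `base_dataset` are illustrative). The loader hands back the very object it was built from; that object supports `len()` and `dataset[i]` indexing, while the loader itself is only iterable and raises the same `TypeError` as in the question when indexed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-in for CIFAR10 / ImageFolder: 10 random 3x4x4 "images"
# with integer labels 0..9.
images = torch.randn(10, 3, 4, 4)
labels = torch.arange(10)
base_dataset = TensorDataset(images, labels)

loader = DataLoader(base_dataset, batch_size=4, shuffle=True)

# The DataLoader keeps a reference to the Dataset it was constructed with.
dataset = loader.dataset
print(dataset is base_dataset)  # True

# len() of the loader counts batches (ceil(10 / 4)); len() of the dataset counts samples.
print(len(loader))   # 3
print(len(dataset))  # 10

# The Dataset supports per-sample indexing ...
img, label = dataset[0]
print(img.shape, int(label))

# ... while indexing the DataLoader fails, as in the question.
try:
    loader[0]
except TypeError as err:
    print("TypeError:", err)
```

Shuffling only changes the order in which the loader yields batches; `dataset[0]` always returns the first stored sample.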
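Applied to the question, `_mydataset` could return `datasets.ImageFolder(dataset_directory + '/train')` directly, or `mydataset.dataset` if the loader is needed elsewhere. If the calling code should accept either kind of object, a small unwrapping step is enough. The sketch below uses a hypothetical helper `as_dataset` and a hypothetical `count_labels` function that mirrors the question's `for i in range(len(dataset))` loop; a tiny in-memory `TensorDataset` stands in for `ImageFolder` so it runs without image files.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def as_dataset(data) -> Dataset:
    # Hypothetical helper: unwrap a DataLoader to the Dataset it wraps,
    # and pass any other object (assumed to be a Dataset) through unchanged.
    if isinstance(data, DataLoader):
        return data.dataset
    return data


def count_labels(data) -> dict:
    # Hypothetical consumer with the same loop shape as the question:
    # index samples one at a time with dataset[i].
    dataset = as_dataset(data)
    counts = {}
    for i in range(len(dataset)):
        _, label = dataset[i]
        counts[int(label)] = counts.get(int(label), 0) + 1
    return counts


# Tiny in-memory stand-in for the question's ImageFolder (6 samples, 2 classes).
samples = TensorDataset(torch.randn(6, 3, 8, 8), torch.tensor([0, 1, 0, 1, 0, 1]))
loader = DataLoader(samples, batch_size=len(samples))

# Both calls index the same underlying Dataset, so the results match.
print(count_labels(samples))  # {0: 3, 1: 3}
print(count_labels(loader))   # {0: 3, 1: 3}
```

Unwrapping at the boundary keeps the existing Dataset-based code unchanged; the DataLoader is only needed where batched iteration is actually wanted.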