Return multiple images from a custom dataset

Question

I have a set of bags of frames obtained from YouTube videos and I would like to return an entire bag when I iterate my dataset. My custom dataset class is the following one:

dataset_path = Path('/content/VideoClassificationDataset')

class VideoDataset(Dataset):
  def __init__(self, dictionary, transform = None):
    self.l_dict = list(dictionary.items())
    self.transform = transform
  
  def __len__(self):
    return len(self.l_dict)

  def __get_item__(self, index):
    item = self.l_dict[index]

    images_path = item[0]
    images = [Image.open(f'{dataset_path}/{images_path}/{image}') for image in os.listdir(f'{dataset_path}/{images_path}')]
    
    y_labels = torch.tensor(item[1])

    if self.transform:
      for image in images: self.transform(image)
    
    return images, y_labels

I also split the training data and build the data loaders as follows:

def spit_train(train_data, perc_val_size):
  train_size = len(train_data)
  val_size = int((train_size * perc_val_size) // 100)
  train_size -= val_size

  return random_split(train_data, [int(train_size), int(val_size)])

train_data, val_data = spit_train(VideoDataset(train_dict, transform=train_transform()), 20)
test_data = VideoDataset(dictionary=test_dict, transform=test_transform())


BATCH_SIZE = 16
NUM_WORKERS = os.cpu_count()

def generate_dataloaders(train_data, test_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):

  train_dl = DataLoader(dataset = train_data, 
                                batch_size = BATCH_SIZE,
                                num_workers = NUM_WORKERS,
                                shuffle = True)

  val_dl = DataLoader(dataset = val_data, 
                                batch_size = BATCH_SIZE,
                                num_workers = NUM_WORKERS,
                                shuffle = True)

  test_dl = DataLoader(dataset = test_data, 
                              batch_size = BATCH_SIZE, 
                              num_workers = NUM_WORKERS, 
                              shuffle = False) # don't need to shuffle testing data when we are considering a time series dataset

  return train_dl, val_dl, test_dl

train_dl, val_dl, test_dl = generate_dataloaders(train_data, test_data)

train_dict and test_dict are dictionaries that map the path of each bag of shots (the key) to its list of labels (the value), like so:

{'train/iqGq-8vHEJs/bag_of_shots0': [2],
 'train/iqGq-8vHEJs/bag_of_shots1': [2],
 'train/gnw83R8R6jU/bag_of_shots0': [119],
 'train/gnw83R8R6jU/bag_of_shots1': [119],
...
}

The problem appears when I try to inspect what the data loaders contain:

train_features_batch, train_labels_batch = next(iter(train_dl))
print(train_features_batch.shape, train_labels_batch.shape)

val_features_batch, val_labels_batch = next(iter(val_dl))
print(val_features_batch.shape, val_labels_batch.shape)

I get:

NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py", line 295, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataset.py", line 53, in __getitem__
    raise NotImplementedError
NotImplementedError

At this point I am not sure whether I can return a set of images from my __get_item__() function at all.

Answer 1

Score: 2

There is a typo in the method name: it should be __getitem__, not __get_item__.

Because __getitem__ is not defined in your custom dataset class, indexing falls back to the method inherited from the base class (torch.utils.data.Dataset). The base class does not implement it, since every dataset that inherits from it is expected to provide its own, so you get a NotImplementedError.
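
The mechanism can be illustrated without PyTorch. The following is a minimal sketch, not PyTorch's actual source: BaseDataset is a hypothetical stand-in for torch.utils.data.Dataset, and MisspelledDataset, CorrectDataset and try_index are illustrative names. It shows that indexing with dataset[index] always calls __getitem__, so a method named __get_item__ is never consulted.

```python
class BaseDataset:
    """Hypothetical stand-in for torch.utils.data.Dataset (an assumption, not the real class)."""

    def __getitem__(self, index):
        # The base class leaves item access to its subclasses.
        raise NotImplementedError


class MisspelledDataset(BaseDataset):
    def __init__(self, items):
        self.items = list(items)

    def __get_item__(self, index):
        # An ordinary method: the [] operator never looks up this name.
        return self.items[index]


class CorrectDataset(BaseDataset):
    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        # Overrides the base method, so dataset[index] reaches this code.
        return self.items[index]


def try_index(dataset, index):
    """Return dataset[index], or the string 'NotImplementedError' if indexing fails that way."""
    try:
        return dataset[index]
    except NotImplementedError:
        return "NotImplementedError"


bags = ["bag_of_shots0", "bag_of_shots1"]
print(try_index(MisspelledDataset(bags), 0))  # falls through to the base class
print(try_index(CorrectDataset(bags), 0))     # uses the subclass implementation
```

Renaming the method in VideoDataset to __getitem__ is the same change as going from MisspelledDataset to CorrectDataset here.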

More documentation on this can be found here.
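
One more detail in the question's code may be worth checking once the method is renamed: the loop "for image in images: self.transform(image)" calls the transform but discards its return value, so images would still hold the untransformed objects (assuming the transform returns a new object rather than modifying its input). A minimal sketch of keeping the results, using a hypothetical helper apply_transform and a stand-in transform so that it runs without any image files:

```python
def apply_transform(images, transform=None):
    """Return a new list holding transform(image) for every image.

    apply_transform is a hypothetical helper, not part of PyTorch.
    """
    if transform is None:
        return list(images)
    # Keep the return values instead of discarding them.
    return [transform(image) for image in images]


# Stand-in "images" (strings) and a stand-in transform (str.upper).
frames = ["frame0", "frame1"]
transformed = apply_transform(frames, transform=str.upper)
print(transformed)
```

Inside the dataset class, the equivalent change would be to assign the list of transformed images back to images before returning it.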

huangapple
  • Published on February 18, 2023, 22:37:53
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75494056.html