How can I have a multi-indexed torch Dataset?

Question

I have several time series that I want to classify with RNNs. I would like the __getitem__ method to take in either a dict or a tuple specifying both the time idx and the sequence idx. I have implemented that here:

from typing import Dict, Tuple

import numpy as np
import torch

# project-specific names (zrfr, SIGNAL_NAMES, FINGER_NAMES) come from the rest of the module

class MLDataWrangler(zrfr.ZRFReader, torch.utils.data.Dataset):
  ...
  ...
  
  def __len__(self) -> int:
    # return len(self._all_flexions) # number of timeseries's
    return len(self._all_flexions), self._sequence_length # can I return multiple lengths?

  def __getitem__(self, idx) -> Tuple[np.ndarray, Dict[str, bool]]:
    if isinstance(idx, dict):
      frames = idx["frames"]
      sequence = idx["flexion"]
    elif hasattr(idx, "__iter__"):
      frames = idx[1]
      sequence = idx[0]
    else:
      raise TypeError("idx must be a dict or a (sequence, frames) tuple")

    zrf_path = self._all_flexions[sequence]
    video_pixels = self.read(zrf_path).data


    # signal level (input)
    if self._signame == SIGNAL_NAMES[1]:
      img_pixels = video_pixels.zsig[frames].astype(np.float32) # new since stacked
    elif self._signame == SIGNAL_NAMES[0]:
      img_pixels = video_pixels.rfsig[frames].astype(np.float32) # new since stacked


    # get finger ground truth
    labels = self._data_labels[sequence][1] # since stacked
    labels = np.array([labels[fing] for fing in FINGER_NAMES]).astype(np.float32) # convert dict to np array

    # transform as necessary
    if self.transform:
      img_pixels = self.transform(img_pixels)
    if self.target_transform:
      labels = self.target_transform(labels)
    return img_pixels.astype(np.float32), labels

My question is: how do I use a DataLoader on this? I assume I need a custom Sampler, but I've tried searching and can't figure out what I need to implement.

Answer 1

Score: 1

TL;DR: there's no good way to do exactly this, because Python iterables are inherently 1-D, but you can probably accomplish the same outcome with other approaches.

From the docs, "Every Sampler subclass has to provide an __iter__() method". Your proposed dataset has two dimensions/indices, so you would either need to flatten it, i.e. enumerate all possible combinations of time idx and sequence idx and map them onto a 1-D, list-like iterable (that's probably how I'd recommend doing it), or select a random time_idx inside __getitem__ to eliminate the 2-D indexing. Exactly which you pick will probably depend on how you intend to train.
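
The second option, drawing a random time index inside __getitem__, could be sketched roughly like this. It assumes, as in your code, that every sequence has self._sequence_length frames and that numpy is imported as np; _load_sample is a hypothetical placeholder for the loading and labelling logic already in your __getitem__:

    # inside the same MLDataWrangler class as in the question
    def __len__(self) -> int:
        return len(self._all_flexions)  # one entry per timeseries

    def __getitem__(self, sequence_idx):
        # draw a random time index so a single integer index is enough
        frames = np.random.randint(self._sequence_length)
        # _load_sample is a hypothetical stand-in for the existing loading code
        return self._load_sample(sequence_idx, frames)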

And for the first (recommended) option, here's the flattened enumeration of all (sequence, frame) combinations implemented:

class MLDataWrangler(zrfr.ZRFReader, torch.utils.data.Dataset):

    def __init__(self, ...):
        # map a single flat index onto a (sequence idx, frame idx) pair
        self.idx_map = []
        for i in range(len(self._all_flexions)):
            for j in range(self._sequence_length):  # frames per sequence
                self.idx_map.append([i, j])
        ...

    def __len__(self) -> int:
        return len(self.idx_map)  # one entry per (sequence, frame) combination

    def __getitem__(self, single_idx):
        idx = self.idx_map[single_idx]
        # now you can use idx[0] (sequence) and idx[1] (frames) as before
        ...
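
With the dataset flattened like this, and __len__ returning len(self.idx_map), you don't need a custom Sampler at all: the default samplers used by DataLoader only ever hand the dataset a single integer index. A minimal usage sketch (the MLDataWrangler constructor arguments are placeholders, not the real signature):

from torch.utils.data import DataLoader

dataset = MLDataWrangler(...)  # placeholder for the real constructor arguments
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for img_pixels, labels in loader:
    ...  # each batch is ready for the RNN training loop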
