如何平衡 PyTorch 数据集?

huangapple go评论60阅读模式
英文:

How to balance a PyTorch dataset?

问题

我有一个不平衡的PyTorch数据集。
A和V样本的数量远远低于其他样本
我想平衡我的数据集,即使我必须删除属于占主导地位类别的样本。我该如何做呢?

现在,我只是删除某些类别的样本,如果它们的数量超过了某个固定值。这在技术上很复杂,也不方便。也许有一些sklearn或PyTorch的方法可以使这个算法更容易实现?

英文:

I have an imbalanced PyTorch dataset.
The number of A and V samples is much lower than the others
I would like to ballanase my dataset, even if I have to delete samples that belong to the prevailing class. How can do it?

Now I just remove samples of certain classes if their number exceeds some fixed value. This is technically complicated and not convenient. Maybe there is some sklearn or PyTorch method that makes this algorithm much easier to implement?

答案1

得分: 0

不建议从现有类别中移除样本的策略:

  1. 丢失重要信息,
  2. 可能导致模型对少数类别产生偏见。

相反,有几种策略可用于平衡数据集,包括:

  1. 过采样:为少数类别生成新样本,以增加它们在数据集中的表示。这可以通过技术来完成,例如:

    a. 合成少数过采样技术(SMOTE)
    b. 自适应合成采样(ADASYN)。

  2. 欠采样(您正在进行的操作):减少多数类别的样本数量,以与少数类别的样本数量相匹配。这可以通过技术来完成,例如:

    a. 随机欠采样
    b. Tomek链接。

  3. 过采样和欠采样的结合:这涉及使用过采样和欠采样技术的组合来平衡数据集。

在PyTorch中有几种方法可用于帮助平衡数据集:

  1. WeightedRandomSampler:此采样器允许您为每个类别指定权重,可用于过采样少数类别或欠采样多数类别。
  2. DataLoader:此类提供了几个选项用于洗牌和批处理数据,可以确保每个批次包含各类别的平衡表示。
英文:

Removing samples from the prevailing classes is not a recommended strategy:

  1. Loss of important information,
  2. Might be incur bias in the model towards the minority classes.

Instead, there are several strategies you can use to balance the dataset, including:

  1. Oversampling: Generating new samples for the minority classes to increase their representation in the dataset. This can be done through techniques such as:

    a. Synthetic Minority Over-sampling Technique (SMOTE)
    b. Adaptive Synthetic Sampling (ADASYN).

  2. Under-sampling (which you are doing): Reducing the number of samples for the majority classes to match the number of samples for the minority classes. This can be done through techniques such as:

    a. Random Undersampling
    b. Tomek Links.

  3. Combination of oversampling and under-sampling: This involves using a combination of oversampling and under-sampling techniques to balance the dataset.

There are several methods available in both PyTorch to help balance the dataset:

  1. WeightedRandomSampler: This sampler allows you to specify weights for each class, which can be used to oversample the minority classes or undersample the majority classes.
  2. DataLoader: This class provides several options for shuffling and batching the data, which can help ensure that each batch contains a balanced representation of the classes.

huangapple
  • 本文由 发表于 2023年3月10日 00:37:34
  • 转载请务必保留本文链接:https://go.coder-hub.com/75687534.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定