I have a balanced dataset, after I split it to train & test set, the test set is imbalance, what is the reason?

huangapple go评论78阅读模式
英文:

I have a balanced dataset, after I split it to train & test set, the test set is imbalance, what is the reason?

问题

我的整个数据集有8个类别,每个类别有100个主题。在进行分类时,我将其分割,测试集不平衡。

我进行了分割:

_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], data.iloc[:, -1], test_size=0.2, random_state=42)

RFC的混淆矩阵是
在此输入图片描述
例如:第二类只有10个样本,为什么会这样,我应该进行平衡吗?

谢谢大家。

英文:

My whole dataset is 8 classes, 100 subjects for each class.
When I do the classification, I split it, the test set is imbalance.

I do split:

_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], data.iloc[:, -1], test_size=0.2, random_state=42)

The confusion matrix of RFC is
enter image description here
eg: There's only 10 from the second class, why and should I balance it?

Thank you all.

答案1

得分: 2

train_test_split 函数来自 scikit-learn,以随机方式分割类别数。

要保持测试集和训练集中类别数量相等,您需要添加 "stratify" 参数。

查看文档

_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], 
    data.iloc[:, -1], 
    test_size=0.2, 
    random_state=42,
    stratify=data.iloc[:, -1])
英文:

The train_test_split function from scikit-learn splits classes number in random fashion.

To keep the number of classes equal in the test and train set you need to add the "stratify" argument.

See documentation

_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], 
    data.iloc[:, -1], 
    test_size=0.2, 
    random_state=42,
    stratify=data.iloc[:, -1])

huangapple
  • 本文由 发表于 2023年3月9日 17:08:42
  • 转载请务必保留本文链接:https://go.coder-hub.com/75682459.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定