如何使用分组列拆分数据为训练集和测试集

huangapple go评论55阅读模式
英文:

How to split data into train and test using groupby column

问题

train:
Question | Hint | Cluster Label|
q1 |q1_h1 |1
q1 |q1_h2 |1
q1 |q1_h3 |1
q2 |q2_h1 |2
q2 |q2_h2 |2

test:
Question | Hint | Cluster Label|
q3 |q3_h1 |1
q4 |q4_h1 |2
q4 |q4_h2 |2

英文:

Let's say I have a dataframe that looks something like this:<br>
The following table is an example, I have like 120000 questions <br> <br>
Question | Hint | Cluster Label|
<br>q1 |q1_h1 |1
<br>q1 |q1_h2 |1
<br>q1 |q1_h3 |1
<br>q2 |q2_h1 |2
<br>q2 |q2_h2 |2
<br>q3 |q3_h1 |1
<br>q4 |q4_h1 |2
<br>q4 |q4_h2 |2

I want to groupby question and split dataframe into train and test such that associated question and hints are captured together and stratified on label.
So output that I require would be:

train:
<br>
Question | Hint | Cluster Label|
<br>q1 |q1_h1 |1
<br>q1 |q1_h2 |1
<br>q1 |q1_h3 |1
<br>q2 |q2_h1 |2
<br>q2 |q2_h2 |2

test:<br>
Question | Hint | Cluster Label|<br>
q3 |q3_h1 |1
<br>q4 |q4_h1 |2
<br>q4 |q4_h2 |2

答案1

得分: 1

你可以根据Hint列的值简单地拆分DataFrame:

df_train = df[(df['Hint'].str.contains('q1')) | (df['Hint'].str.contains('q2'))]

同样适用于df_test。

英文:

You can simply split the DataFrame according to the value in Hint:

df_train= df[(df[&#39;Hint&#39;].str.contains(&#39;q1&#39;)) | (df[&#39;Hint&#39;].str.contains(&#39;q2&#39;))]

and similarly for df_test

答案2

得分: 0

看起来你需要使用GroupKFoldStratifiedGroupKFold

根据用户手册GroupKFold是"k-fold"的一种变体,确保相同的组在测试集和训练集中都不会出现。

要使用它,你可以像正常情况下一样调用构造函数:

gkf = GroupKFold(n_splits=5)

当你调用gkfsplit方法时,你需要指定要分组的变量(在你的情况下是'Question')。

如果你在GridSearchCV或类似的情境中使用它,你需要在调用GridSearchCV时将分组变量指定为'groups'。参见此前的答案

英文:

Looks like you need to use GroupKFold or StratifiedGroupKFold.

From the user manual, GroupKFold "is a variation of k-fold which ensures that the same group is not represented in both testing and training sets."

To use it, you call the constructor as normal:

gkf = GroupKFold(n_splits = 5)

and when you call the split method of gkf you specify the variable to group on (in your case 'Question').

If you're using it in GridSearchCV or similar, you specify the group in as the 'groups' variable in the call to GridSearchCV. See previous answer here.

huangapple
  • 本文由 发表于 2023年2月6日 05:48:58
  • 转载请务必保留本文链接:https://go.coder-hub.com/75355754.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定