pyspark dataframe limiting on multiple columns
Question
I wonder if anyone can point me in the right direction with the following problem. In a rather large pyspark dataframe with about 50-odd columns, two of them represent, say, 'make' and 'model'. Something like:
21234234322 (unique id) .. .. .. Nissan Navara .. .. ..
73647364736 .. .. .. BMW X5 .. .. ..
What I would like to know is what the top 2 models per brand are. I can group by both columns and add a count, no problem, but how do I then limit (or filter) that result? I.e. how do I keep the (up to) 2 most popular models per brand and remove the rest?
Whatever I try, I end up iterating manually over the brands that exist in the original dataframe. Is there another way?
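For reference, a minimal sketch of the "group by both columns and add a count" step described above. The dataframe name cars_df is assumed for illustration; only the 'make' and 'model' column names come from the description.

# Hypothetical dataframe `cars_df` with 'make' and 'model' columns, as described above.
model_counts = cars_df.groupBy('make', 'model').count()   # adds a 'count' column per make/model
model_counts.show()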
Answer 1
Score: 1
You can use rank() with a Window and filter():
from pyspark.sql import functions as func
from pyspark.sql.window import Window

# Sample data: col1 is the grouping key (e.g. brand), col2 the sub-group
# (e.g. model), and col3 a unique id.
df = spark.createDataFrame(
    [
        ('a', 1, 1),
        ('a', 1, 2),
        ('a', 1, 3),
        ('a', 2, 1),
        ('a', 2, 2),
        ('a', 3, 1)
    ],
    schema=['col1', 'col2', 'col3']
)

df.printSchema()
df.show(10, False)
+----+----+----+
|col1|col2|col3|
+----+----+----+
|a |1 |1 |
|a |1 |2 |
|a |1 |3 |
|a |2 |1 |
|a |2 |2 |
|a |3 |1 |
+----+----+----+
where col1 and col2 are the grouping columns and col3 is your unique id:
# Count distinct ids per (col1, col2), rank those counts within each col1,
# and keep only the top 2 per group.
df.groupBy(
    'col1', 'col2'
).agg(
    func.countDistinct('col3').alias('dcount')
).withColumn(
    'rank', func.rank().over(Window.partitionBy('col1').orderBy(func.desc('dcount')))
).filter(
    func.col('rank') <= 2
).show(10, False)
+----+----+------+----+
|col1|col2|dcount|rank|
+----+----+------+----+
|a |1 |3 |1 |
|a |2 |2 |2 |
+----+----+------+----+
You can use rank() after the grouping and aggregation to keep only the top 2 values in each group (col1).
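Applied to the columns from the question, the same pattern would look roughly like the sketch below. The dataframe and column names (cars_df, 'make', 'model', 'id') are assumptions for illustration; they are not defined in the original post.

from pyspark.sql import functions as func
from pyspark.sql.window import Window

# Hypothetical dataframe `cars_df` with columns 'make', 'model' and a unique 'id'.
w = Window.partitionBy('make').orderBy(func.desc('cnt'))

top2_models = (
    cars_df
    .groupBy('make', 'model')
    .agg(func.countDistinct('id').alias('cnt'))   # count ids per make/model
    .withColumn('rank', func.rank().over(w))      # rank models within each make
    .filter(func.col('rank') <= 2)                # keep the (up to) top 2 per make
)
top2_models.show()

Note that rank() assigns the same rank to tied counts, so a tie at rank 2 can return more than two models per make; row_number() would enforce a hard cut of exactly two.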