pyspark dataframe limiting on multiple columns


Question


I wonder if anyone can point me in the right direction with the following problem. In a rather large pyspark dataframe with about 50-odd columns, two of them represent, say, 'make' and 'model'. Something like:

21234234322 (unique id) .. .. .. Nissan  Navara .. .. ..
73647364736             .. .. .. BMW     X5     .. .. ..

What I would like to know is what the top 2 models per brand are. I can group by both columns and add a count, no problem (sketched below), but how do I then limit (or filter) that result? I.e. how do I keep the (up to) 2 most popular models per brand and remove the rest?

Whatever I try, I end up iterating over the brands that exist in the original dataframe manually. Is there another way?
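
For reference, a minimal sketch of that counting step (assuming the two columns are literally named make and model; adjust to your actual schema):

from pyspark.sql import functions as func

# Count rows per (make, model) pair -- this part is straightforward;
# the open question is keeping only the 2 most frequent models per make.
model_counts = df.groupBy('make', 'model').count()
model_counts.orderBy('make', func.desc('count')).show()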

Answer 1

Score: 1


You can use rank() with a Window and filter():

from pyspark.sql import functions as func
from pyspark.sql.window import Window

df = spark.createDataFrame(
    [
        ('a', 1, 1),
        ('a', 1, 2),
        ('a', 1, 3),
        ('a', 2, 1),
        ('a', 2, 2),
        ('a', 3, 1)
    ],
    schema=['col1', 'col2', 'col3']
)

df.printSchema()
df.show(10, False)
+----+----+----+
|col1|col2|col3|
+----+----+----+
|a   |1   |1   |
|a   |1   |2   |
|a   |1   |3   |
|a   |2   |1   |
|a   |2   |2   |
|a   |3   |1   |
+----+----+----+

where col1 and col2 are grouping columns and col3 is your unique id:

df.groupBy(
    'col1', 'col2'
).agg(
    func.countDistinct('col3').alias('dcount')
).withColumn(
    'rank', func.rank().over(Window.partitionBy('col1').orderBy(func.desc('dcount')))
).filter(
    func.col('rank') <= 2
).show(
    10, False
)
+----+----+------+----+
|col1|col2|dcount|rank|
+----+----+------+----+
|a   |1   |3     |1   |
|a   |2   |2     |2   |
+----+----+------+----+

You can use rank() after the grouping and aggregation to keep only the top 2 values in each group (col1).
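
One caveat worth noting (an addition, not part of the original answer): rank() keeps ties, so if two models tie on dcount the filter can return more than 2 rows per group. If you need strictly at most 2 per group, swapping in row_number() is a common alternative, sketched here:

from pyspark.sql import functions as func
from pyspark.sql.window import Window

# row_number() assigns unique, consecutive numbers within each partition,
# so the filter keeps at most 2 rows per col1 even when counts tie.
w = Window.partitionBy('col1').orderBy(func.desc('dcount'))

df.groupBy(
    'col1', 'col2'
).agg(
    func.countDistinct('col3').alias('dcount')
).withColumn(
    'rn', func.row_number().over(w)
).filter(
    func.col('rn') <= 2
).show(10, False)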
