Aggregate ArrayType column to get another ArrayType column without UDF
Question
I have a dataframe like this, which has one ArrayType column:
simpleData = [("202305","Sales","NY",[1,2,3]),
("202306","Sales","NY",[4,2,3]),
("202305","Sales","CA",[4,5,3]),
("202306","Finance","CA",[4,5,6]),
("202305","Finance","NY",[5,6,7]),
("202306","Finance","NY",[6,7,8]),
]
schema = ["month","department","state","lis"]
df = spark.createDataFrame(data=simpleData, schema = schema)
What I would like to do is find the median value of the lis column, grouped by some columns. Pretty much like what we would do with percentile_approx() over a numeric column, but here I want the median to be computed over the union of the lists in each group.
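For reference, this is roughly what the numeric-column version looks like (a sketch; amount is a hypothetical numeric column, not part of the dataframe above):
from pyspark.sql.functions import percentile_approx

# hypothetical numeric column "amount"; this is the per-group median pattern meant above
df.groupBy('department').agg(percentile_approx('amount', 0.5).alias('amount_median'))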
Desired output:
df1 = df.groupBy('department').agg(desired_function(col('lis')).alias('lis_median'))
gives
+----------+----------+
|department|lis_median|
+----------+----------+
| Sales| 3|
| Finance| 6|
+----------+----------+
I tried doing
df1 = df.groupBy('department').agg(collect_list(col('lis')).alias('lis_median'))
to at least get the aggregated list for each group, but it only generates a list of lists.
I'd like some way to implement this. I don't want to use a UDF due to execution speed concerns.
Answer 1
Score: 1
Try the higher-order array functions flatten and array_sort for this case.
- If the size of the array is even, the median is (mid value + previous value) / 2.
- If the size of the array is odd, the mid value is the median.
For example, for the Finance group the flattened, sorted list is [4, 5, 5, 6, 6, 6, 7, 7, 8] (nine values, odd), so the median is the middle value, 6.
Example:
simpleData = [("202305","Sales","NY",[1,2,3]),
("202306","Sales","NY",[4,2,3]),
("202305","Sales","CA",[4,5,3]),
("202306","Finance","CA",[4,5,6]),
("202305","Finance","NY",[5,6,7]),
("202306","Finance","NY",[6,7,8]),
]
schema = ["month","department","state","lis"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.groupBy("department").agg(flatten(collect_set(col("lis"))).alias("cs")).\
withColumn('col', array_sort(expr('filter(cs, x -> x is not null)'))).\
withColumn('lis_median', when(size(col('cs')) % 2 == 0,
(expr('col[int(size(cs)/2)]') + expr('col[int(size(cs)/2)-1]'))/2
).otherwise(expr('col[int(size(cs)/2)]'))).\
drop(*['col','cs']).\
show(10,False)
#+----------+----------+
#|department|lis_median|
#+----------+----------+
#|Sales |3.0 |
#|Finance |6.0 |
#+----------+----------+
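A note on the design choice above: collect_set de-duplicates identical per-row arrays within a group before flattening. If duplicate lists should still count toward the median, the first aggregation could use collect_list instead (a sketch of just that line; the rest of the pipeline stays the same):
from pyspark.sql.functions import col, collect_list, flatten

df.groupBy("department").agg(flatten(collect_list(col("lis"))).alias("cs"))  # keeps duplicate lists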
Answer 2
Score: 1
I agree with @notNull: flatten the list and take the median according to its size.
Below is another approach you can follow.
You can use the posexplode function, which gives the position and value of each element of the list, and then aggregate and find the median directly, as below.
from pyspark.sql.functions import median

display(df.selectExpr("*", "posexplode(lis) as (pos, lis_value)")
        .groupBy('department')
        .agg(median("lis_value").alias('lis_median')))
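A compatibility note: pyspark.sql.functions.median is available from Spark 3.4 onward, and display() is a Databricks notebook helper. On older Spark versions the same explode-then-aggregate idea can be written with percentile_approx at the 0.5 percentile, which returns an approximate median (a sketch of that variant):
from pyspark.sql.functions import percentile_approx

# percentile_approx(col, 0.5) gives an approximate median; the Python helper exists since Spark 3.1
df.selectExpr("*", "posexplode(lis) as (pos, lis_value)") \
    .groupBy("department") \
    .agg(percentile_approx("lis_value", 0.5).alias("lis_median")) \
    .show()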