How to assess previous row values for the current row iteratively in PySpark

Question


I am trying to take the previous row's value of column B, reduce it by 1, and assign it to the current row, repeating this until I reach 0 or hit a non-null row.

col_a  col_b
1      null
2      3
3      null
4      null
5      null
6      null
7      6
8      null
9      null

Here's what I'm hoping to get.

col_a  col_b
1      null
2      3
3      2
4      1
5      0
6      null
7      6
8      5
9      4

My code so far

from pyspark.sql import functions as F
from pyspark.sql.functions import when
from pyspark.sql.window import Window

data = [(1, None), (2, 3), (3, None), (4, None), (5, None), (6, None), (7, 6), (8, None), (9, None)]
df = spark.createDataFrame(data, ["col_a", "col_b"])

window_spec = Window.orderBy("col_a")

df = df.withColumn('col_b',
                   when(
                       (F.col('col_b').isNull()) & (F.lag(F.col('col_b')).over(window_spec) != 0),
                       (F.lag(F.col('col_b')).over(window_spec) - 1)
                   ).otherwise(F.col('col_b'))
                   )

df.show()

My code only reads the previous row's original value, not the values that were just assigned during the same transformation. How do I get around this?

I know I can collect this column, process it, and add it back to the df, but that is currently too computationally expensive for me since the dataset is very large.

Answer 1

Score: 1

See the following example, which can help. Note that I've created a separate column for each calculation, but you can merge a few of them to make it more concise.

from pyspark.sql.window import Window as wd
import pyspark.sql.functions as func
import sys

wSpec = wd.partitionBy('id').orderBy('c1')

# data_sdf is the question's input DataFrame; c1 corresponds to col_a and c2 to col_b
data_sdf. \
    withColumn('id', func.lit('dummy_id')). \
    withColumn('reptval',
               func.last('c2', ignorenulls=True).over(wSpec.rowsBetween(-sys.maxsize, 0))
               ). \
    withColumn('changeflag',
               func.coalesce(func.col('reptval') != func.lag('reptval').over(wSpec), func.lit(True)).cast('int')
               ). \
    withColumn('cflag_csum',
               func.sum('changeflag').over(wSpec.rowsBetween(-sys.maxsize, 0))
               ). \
    withColumn('rn',
               func.row_number().over(wd.partitionBy('id', 'cflag_csum').orderBy('c1')) - 1
               ). \
    withColumn('c_interim', func.col('reptval') - func.col('rn')). \
    withColumn('c_fnl',
               func.when(func.col('c_interim') < 0, func.lit(None)).
               otherwise(func.col('c_interim'))
               ). \
    show()

# +---+----+--------+-------+----------+----------+---+---------+-----+
# | c1|  c2|      id|reptval|changeflag|cflag_csum| rn|c_interim|c_fnl|
# +---+----+--------+-------+----------+----------+---+---------+-----+
# |  1|null|dummy_id|   null|         1|         1|  0|     null| null|
# |  2|   3|dummy_id|      3|         1|         2|  0|        3|    3|
# |  3|null|dummy_id|      3|         0|         2|  1|        2|    2|
# |  4|null|dummy_id|      3|         0|         2|  2|        1|    1|
# |  5|null|dummy_id|      3|         0|         2|  3|        0|    0|
# |  6|null|dummy_id|      3|         0|         2|  4|       -1| null|
# |  7|   6|dummy_id|      6|         1|         3|  0|        6|    6|
# |  8|null|dummy_id|      6|         0|         3|  1|        5|    5|
# |  9|null|dummy_id|      6|         0|         3|  2|        4|    4|
# +---+----+--------+-------+----------+----------+---+---------+-----+

  • The reptval field carries the last non-null column value forward until the next non-null value appears.
  • changeflag flags the rows where the value changes with respect to the previous row.
  • The cflag_csum field is the cumulative sum of the change flag; this is done to create partitions.
  • rn is the row number (starting at 0) within each cflag_csum partition. Please add your actual partition columns within the partitionBy as well.
  • Then, all that's needed is to subtract the generated row number from the repeated value (a condensed sketch using the question's column names follows below).
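
For reference, here is a condensed sketch of the same idea written against the question's own column names (col_a, col_b). It assumes the df built in the question and uses a single implicit partition (so Spark will warn about the unpartitioned window); the helper column names filled, chg, grp, and rn are illustrative.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.orderBy("col_a")

result = (
    df
    # carry the last non-null col_b value forward
    .withColumn("filled",
                F.last("col_b", ignorenulls=True)
                 .over(w.rowsBetween(Window.unboundedPreceding, 0)))
    # flag rows where the carried value changes (first row counts as a change)
    .withColumn("chg",
                F.coalesce((F.col("filled") != F.lag("filled").over(w)).cast("int"),
                           F.lit(1)))
    # cumulative sum of the change flag -> grouping key
    .withColumn("grp",
                F.sum("chg").over(w.rowsBetween(Window.unboundedPreceding, 0)))
    # 0-based position within each group
    .withColumn("rn",
                F.row_number().over(Window.partitionBy("grp").orderBy("col_a")) - 1)
    # subtract the position; anything that would drop below 0 becomes null
    .withColumn("col_b",
                F.when(F.col("filled") - F.col("rn") >= 0, F.col("filled") - F.col("rn")))
    .drop("filled", "chg", "grp", "rn")
)

result.show()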

Answer 2

Score: 0

You can use the function row_number() and subtract it (minus 1, because it starts at 1) instead of subtracting 1 from the previous row. However, for this to work you need to partition your window. To get the grouping, add a column with a cumulative sum that increases whenever there's a non-null value in col_b (you can find other Stack Overflow questions that show how to do a cumulative sum). Hope this helps.
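
A minimal sketch of that approach, assuming the question's df (col_a, col_b) is already defined; the helper column names grp and first_val are illustrative, and the final when() step (which nulls out values below 0, per the question's requirement) is added here even though it isn't spelled out above.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.orderBy("col_a")

result = (
    df
    # cumulative count of non-null col_b values -> grouping key
    .withColumn("grp",
                F.sum(F.col("col_b").isNotNull().cast("int"))
                 .over(w.rowsBetween(Window.unboundedPreceding, 0)))
    # the non-null value that started each group
    .withColumn("first_val",
                F.first("col_b", ignorenulls=True)
                 .over(Window.partitionBy("grp").orderBy("col_a")))
    # row_number() starts at 1, so subtract an extra 1
    .withColumn("col_b",
                F.col("first_val")
                - (F.row_number().over(Window.partitionBy("grp").orderBy("col_a")) - 1))
    # null out anything that would go below 0 (the question stops at 0)
    .withColumn("col_b", F.when(F.col("col_b") >= 0, F.col("col_b")))
    .drop("grp", "first_val")
)

result.show()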
