Python-polars: rolling_sum where the window_size comes from another column
Question
Consider the following dataframe:
import polars as pl
from datetime import datetime

df = pl.DataFrame(
    {
        "date": pl.date_range(
            low=datetime(2023, 2, 1),
            high=datetime(2023, 2, 5),
            interval="1d"),
        "periods": [2, 2, 2, 1, 1],
        "quantity": [10, 12, 14, 16, 18],
        "calculate": [22, 26, 30, 16, 18]
    }
)
The calculate column is what I want. It is a rolling_sum where the window_size parameter is taken from the periods column rather than being a fixed value.
I can do the following (window_size=2):
df.select(pl.col("quantity").rolling_sum(window_size=2))
However, I get an error when I try this:
df.select(pl.col("quantity").rolling_sum(window_size=pl.col("periods")))
This is the error:
TypeError: argument 'window_size': 'Expr' object cannot be converted to 'PyString'
How do I pass the value of window_size based on another column? I also looked at using groupby_rolling but could not figure that out either.
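For reference, the expected calculate values can be reproduced in plain Python. This is only a sketch of the intended semantics, not polars code: forward_window_sum is a hypothetical helper name, and it assumes each window starts at the current row and extends forward over the next periods rows (which is what the sample data implies, since 22 = 10 + 12).

```python
def forward_window_sum(values, window_sizes):
    """Sum values[i : i + window_sizes[i]] for every row i.

    Assumption: the window starts at the current row and looks forward.
    Slicing clips at the end of the list, so a window that would run
    past the last row is simply shorter.
    """
    return [sum(values[i:i + w]) for i, w in enumerate(window_sizes)]


quantity = [10, 12, 14, 16, 18]
periods = [2, 2, 2, 1, 1]

calculate = forward_window_sum(quantity, periods)
print(calculate)  # [22, 26, 30, 16, 18]
```

Any polars solution should produce the same numbers for this input.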
Answer 1
Score: 1
It seems like this should be easier to do, which suggests I may be missing something obvious.
As a workaround, you could use the row count to generate row indexes for the windows.
(
    df
    .with_row_count()
    .with_columns(
        window =
            pl.arange(
                pl.col("row_nr"),
                pl.col("row_nr") + pl.col("periods")))
)
shape: (5, 6)
┌────────┬─────────────────────┬─────────┬──────────┬───────────┬───────────┐
│ row_nr ┆ date                ┆ periods ┆ quantity ┆ calculate ┆ window    │
│ ---    ┆ ---                 ┆ ---     ┆ ---      ┆ ---       ┆ ---       │
│ u32    ┆ datetime[μs]        ┆ i64     ┆ i64      ┆ i64       ┆ list[i64] │
╞════════╪═════════════════════╪═════════╪══════════╪═══════════╪═══════════╡
│ 0      ┆ 2023-02-01 00:00:00 ┆ 2       ┆ 10       ┆ 22        ┆ [0, 1]    │
│ 1      ┆ 2023-02-02 00:00:00 ┆ 2       ┆ 12       ┆ 26        ┆ [1, 2]    │
│ 2      ┆ 2023-02-03 00:00:00 ┆ 2       ┆ 14       ┆ 30        ┆ [2, 3]    │
│ 3      ┆ 2023-02-04 00:00:00 ┆ 1       ┆ 16       ┆ 16        ┆ [3]       │
│ 4      ┆ 2023-02-05 00:00:00 ┆ 1       ┆ 18       ┆ 18        ┆ [4]       │
└────────┴─────────────────────┴─────────┴──────────┴───────────┴───────────┘
You could .explode() the window and use .take() + .search_sorted() to find the corresponding values. .groupby() can then be used to combine the window values again.
(
    df
    .with_row_count()
    .with_columns(
        window =
            pl.arange(
                pl.col("row_nr"),
                pl.col("row_nr") + pl.col("periods")))
    .explode("window")
    .with_columns(
        rolling =
            pl.col("quantity")
            .take(pl.col("row_nr").search_sorted("window")))
    .groupby("row_nr", maintain_order=True)
    .agg([
        pl.exclude("rolling").first(),
        pl.col("rolling").sum()
    ])
)
shape: (5, 7)
┌────────┬─────────────────────┬─────────┬──────────┬───────────┬────────┬─────────┐
│ row_nr ┆ date                ┆ periods ┆ quantity ┆ calculate ┆ window ┆ rolling │
│ ---    ┆ ---                 ┆ ---     ┆ ---      ┆ ---       ┆ ---    ┆ ---     │
│ u32    ┆ datetime[μs]        ┆ i64     ┆ i64      ┆ i64       ┆ i64    ┆ i64     │
╞════════╪═════════════════════╪═════════╪══════════╪═══════════╪════════╪═════════╡
│ 0      ┆ 2023-02-01 00:00:00 ┆ 2       ┆ 10       ┆ 22        ┆ 0      ┆ 22      │
│ 1      ┆ 2023-02-02 00:00:00 ┆ 2       ┆ 12       ┆ 26        ┆ 1      ┆ 26      │
│ 2      ┆ 2023-02-03 00:00:00 ┆ 2       ┆ 14       ┆ 30        ┆ 2      ┆ 30      │
│ 3      ┆ 2023-02-04 00:00:00 ┆ 1       ┆ 16       ┆ 16        ┆ 3      ┆ 16      │
│ 4      ┆ 2023-02-05 00:00:00 ┆ 1       ┆ 18       ┆ 18        ┆ 4      ┆ 18      │
└────────┴─────────────────────┴─────────┴──────────┴───────────┴────────┴─────────┘
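To make the explode / look up / group steps easier to follow, here is a rough plain-Python sketch of the same idea. It does not use polars; the variable names are illustrative only, and skipping indexes past the last row is an assumption of this sketch rather than something the polars pipeline above is known to do.

```python
from collections import defaultdict

quantity = [10, 12, 14, 16, 18]
periods = [2, 2, 2, 1, 1]

# Step 1: one window of row indexes per row (the "window" list column).
windows = [list(range(i, i + p)) for i, p in enumerate(periods)]

# Step 2: "explode" the windows into (row_nr, window_index) pairs,
# one pair per window element.
exploded = [(row_nr, idx) for row_nr, window in enumerate(windows) for idx in window]

# Step 3: look up the quantity at each window index, then group by the
# original row number and sum.
rolling = defaultdict(int)
for row_nr, idx in exploded:
    if idx < len(quantity):  # assumption: ignore indexes past the last row
        rolling[row_nr] += quantity[idx]

result = [rolling[row_nr] for row_nr in range(len(quantity))]
print(result)  # [22, 26, 30, 16, 18]
```

The exploded intermediate has one entry per window element, so its size is the sum of the periods column.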
Answer 2
Score: 1
Very similar to @jqurious's answer, but (I think) a bit simplified:
df.lazy() \
    .with_row_count('i') \
    .with_columns(
        window =
            pl.arange(
                pl.col("i"),
                pl.col("i") + pl.col("periods")),
        qty=pl.col('quantity').list()
    ) \
    .with_columns(
        rollsum=pl.col('qty').arr.take(pl.col('window')).arr.sum()
    ) \
    .select(pl.exclude(['window','qty','i'])) \
    .collect()
It works on the same concept, but it essentially recreates the whole quantity column as a list in every row, then uses the window column to select the corresponding values from that list and sums them up.
Another method is to just use a loop, which will be more memory efficient.
First, get all the unique values of periods, then initialize a column in the df for the rolling sum, reverse the row order, and replace the column with a calculation for every period. The rows are reversed because rolling_sum looks backward from the current row, while the desired window looks forward. At the end, put the rows back in the original order.
periods = df.get_column('periods').unique()
df = df.with_columns(pl.lit(None).cast(pl.Float64()).alias("rollsum")).sort('date', reverse=True)
for period in periods:
    df = df.with_columns(
        pl.when(pl.col('periods') == period)
        .then(pl.col('quantity').rolling_sum(window_size=period))
        .otherwise(pl.col('rollsum'))
        .alias('rollsum'))
df = df.sort('date')
df
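The reverse-then-loop logic can be sketched in plain Python as follows. This is only an illustration of the idea, not polars code: fixed_rolling_sum is a hypothetical helper, written on the assumption that a fixed-size rolling sum looks backward and yields None until the window is full.

```python
def fixed_rolling_sum(values, window_size):
    """Backward-looking fixed-window sum; None until the window is full."""
    return [
        sum(values[i - window_size + 1:i + 1]) if i + 1 >= window_size else None
        for i in range(len(values))
    ]


quantity = [10, 12, 14, 16, 18]
periods = [2, 2, 2, 1, 1]

# Reverse the rows so a backward-looking window covers the rows that
# follow the current one in the original order.
reversed_qty = quantity[::-1]
reversed_periods = periods[::-1]
rollsum = [None] * len(quantity)

# One pass per distinct window size; keep the result only on rows whose
# own period matches the window size used in this pass.
for period in sorted(set(periods)):
    fixed = fixed_rolling_sum(reversed_qty, period)
    rollsum = [
        fixed[i] if reversed_periods[i] == period else rollsum[i]
        for i in range(len(rollsum))
    ]

# Put the rows back in the original order.
rollsum = rollsum[::-1]
print(rollsum)  # [22, 26, 30, 16, 18]
```

The number of passes equals the number of distinct values in periods, so this approach suits data with only a few different window sizes.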