how to vectorize this function to pandas or polars

Question

Hello everyone, I need to replicate the behavior below using vectorized operations (it is currently written in Cython, but still too slow because I call it over many groups). However, I am not sure whether this behavior can be vectorized at all. Any help would be appreciated.
Python version / pseudocode:
import numpy as np

def calc_pulses(time, pgt, dt):
    curr_pulse = 1
    res = np.empty(len(time))   # float array, so it can hold NaN
    begin = time[0]
    for k in range(len(time)):
        if time[k] - begin <= pgt:
            res[k] = curr_pulse
        elif time[k] - begin > pgt + dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res
Cython version:
cpdef np.ndarray[np.double_t, ndim=1] calc_pulses(np.ndarray time, double pgt, double dt):
    cdef double begin
    cdef int curr_pulse = 1
    cdef Py_ssize_t k, n = len(time)
    # res must be a float array: np.nan cannot be stored in an int array
    cdef np.ndarray[np.double_t, ndim=1] res = np.empty(n, dtype=np.double)
    begin = time[0]
    for k in range(n):
        if time[k] - begin <= pgt:
            res[k] = curr_pulse
        elif time[k] - begin > pgt + dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res
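To make the intended behavior concrete, here is the pure-Python loop above run on a small, arbitrary sample (the values are made up for illustration):

```python
import numpy as np

def calc_pulses(time, pgt, dt):
    curr_pulse = 1
    res = np.empty(len(time))             # float output so NaN fits
    begin = time[0]
    for k in range(len(time)):
        if time[k] - begin <= pgt:        # within pgt of the pulse start
            res[k] = curr_pulse
        elif time[k] - begin > pgt + dt:  # past the dead time: new pulse
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:                             # inside the dead time: no label
            res[k] = np.nan
    return res

time = np.array([0.0, 1.0, 2.0, 2.5, 4.0, 5.0, 8.0])
print(calc_pulses(time, 2.0, 1.0))
# three samples in pulse 1, one NaN in the dead time, two in pulse 2, one in pulse 3
```

Note that `begin` and `curr_pulse` carry state from one iteration to the next, which is exactly what makes a straightforward pandas/polars vectorization hard.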
Answer 1

Score: 4
Before you go through the effort of doing that, I would suggest looking at the annotated output of Cython (generated with cython -a yourfile.pyx, which writes an HTML report). It can often identify inefficient things.

Here's what it shows in this case:

The yellower a line is, the more it interacts with Python. There are a bunch of yellow lines outside the loop; those don't matter. There are also more yellow lines inside the loop; those do matter.
To give an example of why this function might be slow, let's click on the line
if time[k] - begin <= pgt:
It shows the following C code is used to implement that:
__pyx_t_6 = __Pyx_GetItemInt(((PyObject *)__pyx_v_time), __pyx_v_k, Py_ssize_t, 1, PyInt_FromSsize_t, 0, 1, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_6);
__pyx_t_2 = PyFloat_FromDouble(__pyx_v_begin); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_4 = PyNumber_Subtract(__pyx_t_6, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__pyx_t_2 = PyFloat_FromDouble(__pyx_v_pgt); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_6 = PyObject_RichCompare(__pyx_t_4, __pyx_t_2, Py_LE); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__pyx_t_11 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely(__pyx_t_11 < 0)) __PYX_ERR(0, 11, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
if (__pyx_t_11) {
/* … */
goto __pyx_L5;
}
What is this huge block of code doing?

- It gets an item from the time array. time doesn't have a type for its elements, so the item might be a double, but it could also be a Fraction object or a Decimal object. Cython must cover all these possibilities by calling __Pyx_GetItemInt().
- Because time[k] might not be a double, begin must be converted into a double object with PyFloat_FromDouble() so that the subtraction can be done.
- Since time[k] could be any object, that object might have overridden how subtraction works. Therefore, Cython calls PyNumber_Subtract(), which can handle any kind of subtraction.
- Since the result of that subtraction could be any object, it too might have overridden how comparisons work. Therefore pgt must also be converted into a double object, and PyObject_RichCompare() must be called, which can handle any kind of comparison.
- Comparison functions normally return True or False, but that is not technically required; they could return any truthy value. Cython must cover this possibility by converting the result to a boolean with __Pyx_PyObject_IsTrue().
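The first point is easy to see from plain Python: an untyped object array really can hold anything, which is why Cython must fall back on the generic object protocol (PyNumber_Subtract(), PyObject_RichCompare()) for every element. A small illustration:

```python
import numpy as np
from fractions import Fraction
from decimal import Decimal

# With dtype=object, each element is an arbitrary Python object, so every
# subtraction and comparison dispatches through the generic object protocol --
# exactly the PyNumber_Subtract()/PyObject_RichCompare() path described above.
arr = np.array([Fraction(3, 2), Decimal("2.5"), 4.0], dtype=object)

print(arr[0] - 1)    # Fraction(1, 2): Fraction.__sub__ runs
print(arr[1] - 1)    # Decimal('1.5'): Decimal.__sub__ runs
print(arr[2] - 1)    # 3.0: plain float subtraction
print(arr[0] <= 2)   # True: rich comparison on a Fraction
```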
Exhausting, right? That's why this is so slow.

Fortunately, there is something you can do to make this much faster: add an element type and number of dimensions to the time argument (the return array is also declared double here, since the result must be able to hold NaN):

cpdef np.ndarray[np.double_t, ndim=1] calc_pulses(np.ndarray[np.double_t, ndim=1] time, double pgt, double dt):

This gives a 10x performance improvement in my benchmark, because every one of those steps is easier when Cython knows the type.
No types:
    491 µs ± 6.38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
With types:
    39.8 µs ± 274 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
There are other optimizations you can do, but this is the most important.
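Separately, since the question mentions calling this over many groups: whichever implementation you end up with, you can apply it per group from pandas without a Python-level merge step. A sketch (the column names g and t and the sample values are made up; the pure-Python loop stands in for the compiled version):

```python
import numpy as np
import pandas as pd

def calc_pulses(time, pgt, dt):
    curr_pulse = 1
    res = np.empty(len(time))
    begin = time[0]
    for k in range(len(time)):
        if time[k] - begin <= pgt:
            res[k] = curr_pulse
        elif time[k] - begin > pgt + dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res

df = pd.DataFrame({
    "g": ["a", "a", "a", "a", "b", "b", "b"],
    "t": [0.0, 1.0, 2.5, 4.0, 0.0, 3.5, 4.0],
})
# transform() keeps the original row order; .to_numpy() hands the kernel
# a raw float array instead of a Series, keeping per-group overhead low
df["pulse"] = df.groupby("g")["t"].transform(
    lambda s: calc_pulses(s.to_numpy(), 2.0, 1.0)
)
print(df)
```

The per-call Python overhead of the lambda remains, but the inner loop runs at whatever speed the kernel provides, so this composes well with the typed Cython version above.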