如何将这个函数向量化到 pandas 或 polars 中

huangapple go评论69阅读模式
英文:

how to vectorize this function to pandas or polars

问题

大家好,我需要使用向量化操作来复制以下行为。 (目前使用cython编写,但仍然太慢,因为我在许多组上调用它)。 但是,我不确定是否可能实现这种行为。 感谢任何帮助。

Python版本/伪代码:

def calc_pulses(np.ndarray time, double pgt, double dt):
    curr_pulse = 1
    res = np.empty(len(time))
    begin = time[0]
    for k in range(len(time)):
        if time[k] - begin <= pgt:
            res[k] = curr_pulse
        elif time[k] - begin > pgt+dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res

Cython版本:

cpdef np.ndarray[np.int] calc_pulses(np.ndarray time, double pgt, double dt):
    cdef double begin
    cdef int curr_pulse = 1
    cdef Py_ssize_t k, n = len(time)
    cdef np.ndarray[np.int_t, ndim=1] res = np.empty(n, dtype=np.int)
    begin = time[0]
    for k in range(n):
        if time[k] - begin <= pgt:
            res[k] = curr_pulse
        elif time[k] - begin > pgt+dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res
英文:

Hello everyone I need to replicate the below behavior using vectorized operations. (currently written in cython and it is still too slow as I call it over many groups). However, I am not sure if this behavior is possible. Any help would be appreciated.

python version / pseudocode:

def calc_pulses(np.ndarray time, double pgt, double dt):
    curr_pulse = 1
    res = np.empty(len(time))
    begin = time[0]
    for k in range(len(time)):
        if time[k] - begin &lt;= pgt:
            res[k] = curr_pulse
        elif time[k] - begin &gt; pgt+dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res

cython version:

cpdef np.ndarray[np.int] calc_pulses(np.ndarray time, double pgt, double dt):
    cdef double begin
    cdef int curr_pulse = 1
    cdef Py_ssize_t k, n = len(time)
    cdef np.ndarray[np.int_t, ndim=1] res = np.empty(n, dtype=np.int)
    begin = time[0]
    for k in range(n):
        if time[k] - begin &lt;= pgt:
            res[k] = curr_pulse
        elif time[k] - begin &gt; pgt+dt:
            begin = time[k]
            curr_pulse += 1
            res[k] = curr_pulse
        else:
            res[k] = np.nan
    return res

答案1

得分: 4

在你付出这个努力之前,我建议查看Cython的注释输出。它通常能够识别出低效的部分。

以下是在这种情况下它显示的内容:

如何将这个函数向量化到 pandas 或 polars 中

一行越黄,它与Python的交互越多。循环外有一堆黄线,这些不重要。循环内还有更多的黄线,这些才重要。

举个例子,让我们点击以下代码行

if time[k] - begin &lt;= pgt:

它显示以下C代码用于实现:

    __pyx_t_6 = __Pyx_GetItemInt(((PyObject *)__pyx_v_time), __pyx_v_k, Py_ssize_t, 1, PyInt_FromSsize_t, 0, 1, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_6);
    __pyx_t_2 = PyFloat_FromDouble(__pyx_v_begin); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_2);
    __pyx_t_4 = PyNumber_Subtract(__pyx_t_6, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_4);
    __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
    __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
    __pyx_t_2 = PyFloat_FromDouble(__pyx_v_pgt); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_2);
    __pyx_t_6 = PyObject_RichCompare(__pyx_t_4, __pyx_t_2, Py_LE); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
    __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
    __pyx_t_11 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely(__pyx_t_11 &lt; 0)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
    if (__pyx_t_11) {
/* … */
      goto __pyx_L5;
    }

这个大块代码在做什么?

  • 它从time数组中获取一个项目。time没有为其元素指定类型。因此,项目可能是一个double,但也可能是一个Fraction对象或Decimal对象。Cython必须通过调用__Pyx_GetItemInt()来涵盖所有这些可能性。
  • 因为time[k]可能不是一个double,所以它必须将begin转换为double对象,以便进行减法运算。
  • 由于time[k]可以是任何对象,该对象可能已覆盖减法的工作方式。因此,它调用PyNumber_Subtract(),该函数可以处理任何类型的减法。
  • 由于该减法的结果可能是任何对象,它也可能已覆盖了比较的工作方式。因此,pgt也必须转换为double对象,并且必须调用PyObject_RichCompare(),该函数可以处理任何类型的比较。
  • 比较函数通常返回True或False,但从技术上讲,这不是必需的。从技术上讲,它们也可以返回任何真值。Cython必须通过__Pyx_PyObject_IsTrue()将对象转换为布尔值来涵盖这种可能性。

真是费力,不是吗?这就是为什么这么慢。

幸运的是,有一些方法可以让它运行得更快:为time参数添加类型和维度。

cpdef np.ndarray[np.int] calc_pulses(np.ndarray[np.double_t, ndim=1] time, double pgt, double dt):

在我的基准测试中,这会提高10倍的性能,因为当Cython知道类型时,所有这些步骤都更容易。

没有类型
491 微秒 ± 6.38 微秒每次循环(7次运行的平均值 ± 标准差,每次循环1,000次)

有类型
39.8 微秒 ± 274 纳秒每次循环(7次运行的平均值 ± 标准差,每次循环10,000次)

还有其他优化方法,但这是最重要的。

英文:

Before you go through the effort of doing that, I would suggest looking at the annotated output of Cython. It can often identify inefficient things.

Here's what it shows in this case:

如何将这个函数向量化到 pandas 或 polars 中

The yellower a line is, the more it interacts with Python. There are a bunch of yellow lines outside the loop. Those don't matter. There are also more yellow lines inside the loop. Those do matter.

To give an example of why this function might be slow, let's click on the line

if time[k] - begin &lt;= pgt:

It shows the following C code is used to implement that:

    __pyx_t_6 = __Pyx_GetItemInt(((PyObject *)__pyx_v_time), __pyx_v_k, Py_ssize_t, 1, PyInt_FromSsize_t, 0, 1, 1); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_6);
    __pyx_t_2 = PyFloat_FromDouble(__pyx_v_begin); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_2);
    __pyx_t_4 = PyNumber_Subtract(__pyx_t_6, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_4);
    __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
    __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
    __pyx_t_2 = PyFloat_FromDouble(__pyx_v_pgt); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_2);
    __pyx_t_6 = PyObject_RichCompare(__pyx_t_4, __pyx_t_2, Py_LE); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
    __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
    __pyx_t_11 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely(__pyx_t_11 &lt; 0)) __PYX_ERR(0, 11, __pyx_L1_error)
    __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
    if (__pyx_t_11) {
/* … */
      goto __pyx_L5;
    }

What is it doing in this huge block of code?

  • It gets an item from the time array. time doesn't have a type for its elements. For that reason, item might be a double, but it could be a Fraction object, or a Decimal object. Cython must cover all these possibilities by calling __Pyx_GetItemInt().
  • Because time[k] might not be a double, it must convert begin into a double object with PyFloat_FromDouble(), so that the subtraction can be done.
  • Since time[k] could be any object, that object might have overridden how subtraction works. Therefore, it calls PyNumber_Subtract(), which can handle any kind of subtraction.
  • Since the result of that subtraction could be any object, it could have overridden how comparisons work. Thefore pgt must also be converted into a double object, and PyObject_RichCompare() must be called, which can handle any kind of comparison.
  • Comparison functions normally return True or False, but that is technically not required. Technically, they could return any truthy value as well. Cython must cover this possibility by converting the object to a boolean with __Pyx_PyObject_IsTrue().

Exhausting, right? That's why this is so slow.

Fortunately, there is something you can do to make this much faster: add a type and number of dimensions to the time argument.

cpdef np.ndarray[np.int] calc_pulses(np.ndarray[np.double_t, ndim=1] time, double pgt, double dt):

This gives a 10x performance improvement in my benchmark, because every one of those steps are easier when Cython knows the type.

No types
491 &#181;s &#177; 6.38 &#181;s per loop (mean &#177; std. dev. of 7 runs, 1,000 loops each)

With types
39.8 &#181;s &#177; 274 ns per loop (mean &#177; std. dev. of 7 runs, 10,000 loops each)

There are other optimizations you can do, but this is the most important.

huangapple
  • 本文由 发表于 2023年6月29日 04:05:33
  • 转载请务必保留本文链接:https://go.coder-hub.com/76576389.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定