Optimise this function — numpy broadcasting issue.

Question

I have a function contains that checks, for a given 2D array u, whether the box [min, max] contains each row of u. It needs to reshape u when necessary, but the number of values in u will always be a multiple of d (possibly zero).

I'm using the following snippet. The function is called thousands of times. Can faster code be written? If so, any tips on how?

import numpy as np

def contains(u, min, max, dim, strict=True):
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)

# Usage examples:
d = 4
min = np.random.uniform(size=d)*1/2
max = np.random.uniform(size=d)*1/2+1/2
u1 = np.random.uniform(size=d)
u2 = np.random.uniform(size=(100, d))
u3 = u2[np.repeat(False, 100)]

contains(u1, min, max, d) # should return a boolean array of shape (1,)
contains(u2, min, max, d) # shape (100,)
contains(u3, min, max, d) # shape (0,)
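The shape expectations in the comments can be checked directly with a few assertions. This sketch repeats the question's function with min/max renamed to lo/hi (an assumption, made only to avoid shadowing Python's builtins) and uses a seeded np.random.default_rng so the check is reproducible. It shows that reshape(-1, dim) turns a single flat point into one row and keeps an empty selection at zero rows.

```python
import numpy as np


def contains(u, lo, hi, dim, strict=True):
    # Same logic as the question's function; min/max renamed to lo/hi
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > lo) & (u < hi), axis=1)
    return np.all((u >= lo) & (u <= hi), axis=1)


rng = np.random.default_rng(0)  # seeded only to make the check reproducible
d = 4
lo = rng.uniform(size=d) / 2
hi = rng.uniform(size=d) / 2 + 1 / 2
u1 = rng.uniform(size=d)            # one point given as a flat array of length d
u2 = rng.uniform(size=(100, d))     # 100 points, one per row
u3 = u2[np.repeat(False, 100)]      # empty selection, shape (0, d)

for u, expected in [(u1, (1,)), (u2, (100,)), (u3, (0,))]:
    out = contains(u, lo, hi, d)
    assert out.dtype == bool and out.shape == expected
```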

Answer 1

Score: 3

(EDITED to fix the timing-measurement issue raised by @max9111 in the comments, and to include a numexpr-based solution.)

The bottleneck would eventually be the np.all() call.
This can be sped up with Numba by writing the check as an explicit loop that stops examining a row as soon as one coordinate falls outside the box:

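import numpy as np
import numba as nb


@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    # Start with every row marked as inside the box (np.bool8 is deprecated)
    result = np.ones(n, dtype=np.bool_)
    for i in range(n):
        for j in range(m):
            # Non-strict bounds, matching contains_np below; stop at the
            # first coordinate outside [a_arr[j], b_arr[j]]
            if not a_arr[j] <= arr[i, j] <= b_arr[j]:
                result[i] = False
                break
    return result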

This is compared to the NumPy solution:

import numpy as np


def contains_np(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    # Broadcasting compares every row with the bounds; np.all reduces each row
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)

This is slightly simplified relative to your approach: I omitted the dim and strict parameters, since dim is redundant (it can be inferred from the size of a_arr or b_arr) and strict does not add much to the solution, although it can easily be reintroduced.
I also assume that the input is always already a NumPy array.
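As a sketch of how the omitted pieces could be put back: the strict flag can select the comparison operator once, and np.asarray (unlike np.array, which copies by default) returns the input unchanged when it is already a NumPy array, so lists are still accepted without paying for a copy in the common case. The name contains_np_strict is hypothetical.

```python
import operator

import numpy as np


def contains_np_strict(u, a_arr, b_arr, strict=True):
    # np.asarray does not copy when u is already an ndarray (np.array does)
    u = np.asarray(u)
    # dim is inferred from the box bounds, as in contains_np
    u = u.reshape(-1, a_arr.size)
    # Choose strict (<) or non-strict (<=) comparisons once
    lt = operator.lt if strict else operator.le
    return np.all(lt(a_arr, u) & lt(u, b_arr), axis=1)
```

Dispatching through operator.lt / operator.le keeps a single return expression; an if/else with two explicit expressions, as in the question, is equally valid.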

The NumPy solution can also be modified to use numexpr, which gives a third approach. This adds some calling overhead, but could speed up the element-wise computation, e.g.:

import numpy as np
import numexpr as ne


def contains_ne(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    # numexpr looks up arr, a_arr and b_arr in the calling frame
    result = ne.evaluate('(arr >= a_arr) & (arr <= b_arr)')
    return np.all(result, axis=1)

The three functions were benchmarked (the benchmark plot is not reproduced here).

The results show that the Numba solution is consistently the fastest. By contrast, using numexpr does not appear to be beneficial over the range of parameters explored.

(full benchmark available here)
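A minimal harness for running this kind of comparison on your own machine is sketched below; it is not the answer's benchmark code. It repeats contains_np so it is self-contained and adds contains_cols, a hypothetical NumPy-only variant (not part of the answer's benchmark) that updates one length-n mask column by column instead of allocating (n, m) boolean temporaries. Whether that variant is faster depends on n, m and the NumPy version, so treat it as something to measure rather than a recommendation. The Numba and numexpr versions are left out so the sketch runs with NumPy alone.

```python
import timeit

import numpy as np


def contains_np(arr, a_arr, b_arr):
    # The answer's vectorized version (non-strict bounds)
    m = a_arr.size
    arr = arr.reshape(-1, m)
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)


def contains_cols(arr, a_arr, b_arr):
    # Hypothetical variant: one length-n mask, updated in place per column
    m = a_arr.size
    arr = arr.reshape(-1, m)
    mask = np.ones(arr.shape[0], dtype=bool)
    for j in range(m):
        col = arr[:, j]
        mask &= (col >= a_arr[j])
        mask &= (col <= b_arr[j])
    return mask


rng = np.random.default_rng(0)
d = 4
a_arr = rng.uniform(size=d) / 2
b_arr = rng.uniform(size=d) / 2 + 1 / 2
calls = 200

for n in (0, 1, 100, 10_000):
    arr = rng.uniform(size=(n, d))
    # The variants must agree before their timings mean anything
    assert np.array_equal(contains_np(arr, a_arr, b_arr),
                          contains_cols(arr, a_arr, b_arr))
    for f in (contains_np, contains_cols):
        best = min(timeit.repeat(lambda: f(arr, a_arr, b_arr),
                                 number=calls, repeat=3))
        print(f"n={n:>6} {f.__name__:>14}: {best / calls * 1e6:.2f} us per call")
```

If you add contains_nb to the loop, call it once beforehand so that Numba's one-off compilation on the first call is not counted in the timings.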

Answer 2

Score: 2

Try this to speed it up (read more here):

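import numpy as np
from numba import jit

@jit(nopython=True)
def contains(u, min, max, dim, strict=True):
    # u is assumed to already be a NumPy array (np.array(u) is dropped here)
    rows = u.reshape(-1, dim)
    if strict:
        inside = (rows > min) & (rows < max)
    else:
        inside = (rows >= min) & (rows <= max)
    # Numba's np.all does not accept axis=, so reduce each row in a loop
    result = np.empty(rows.shape[0], dtype=np.bool_)
    for i in range(rows.shape[0]):
        result[i] = np.all(inside[i])
    return result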

huangapple
  • Posted on 7 January 2020, 01:24:14
  • When reposting, please keep the link to this article: https://go.coder-hub.com/59616390.html