Optimise this function -- numpy broadcasting issue
Question
I have a function contains that checks, for a given 2D array u, whether the box [min, max] contains each row of u. I need it to reshape u if needed, but the number of values in u will always be a multiple of d (possibly zero).
I'm using the following snippet of code. This function runs thousands of times. Can faster code be produced? If so, any tips on how?
import numpy as np

def contains(u, min, max, dim, strict=True):
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)

# Usage examples:
d = 4
min = np.random.uniform(size=d)*1/2
max = np.random.uniform(size=d)*1/2+1/2
u1 = np.random.uniform(size=d)
u2 = np.random.uniform(size=(100,d))
u3 = u2[np.repeat(False,100)]
contains(u1,min,max,d) # should return a boolean array of shape (1,)
contains(u2,min,max,d) # shape (100,)
contains(u3,min,max,d) # shape (0,)
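One small change that needs no extra library (my own suggestion, not part of the question): np.array(u) copies its input by default, whereas np.asarray(u) returns u itself when it is already an ndarray, so the reshape can work on a view. The sketch below also renames min/max to the illustrative names lo/hi so Python's built-ins are not shadowed; whether the saved copy matters for your workload would need measuring.

```python
import numpy as np


def contains_fast(u, lo, hi, dim, strict=True):
    # np.asarray avoids the copy np.array(u) makes when u is already an ndarray
    u = np.asarray(u).reshape(-1, dim)
    if strict:
        # broadcasting compares every row of u against the 1-D bounds lo and hi
        return np.all((u > lo) & (u < hi), axis=1)
    return np.all((u >= lo) & (u <= hi), axis=1)


d = 4
rng = np.random.default_rng(0)
lo = rng.uniform(size=d) / 2
hi = rng.uniform(size=d) / 2 + 1 / 2
u1 = rng.uniform(size=d)
u2 = rng.uniform(size=(100, d))
u3 = u2[np.repeat(False, 100)]
print(contains_fast(u1, lo, hi, d).shape)  # (1,)
print(contains_fast(u2, lo, hi, d).shape)  # (100,)
print(contains_fast(u3, lo, hi, d).shape)  # (0,)
```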
Answer 1
Score: 3
(EDITED: to fix the timing measurement issue raised by @max9111 in the comments, and to include a numexpr-modified solution.)
The bottleneck would eventually be within the np.all() call.
This could be sped up with Numba as follows:
import numpy as np
import numba as nb

@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    # np.bool_ instead of np.bool8, a deprecated alias removed in NumPy 2.0
    result = np.ones(n, dtype=np.bool_)
    for i in range(n):
        for j in range(m):
            # stop checking a row as soon as one coordinate is outside the box
            if not a_arr[j] < arr[i, j] < b_arr[j]:
                result[i] = False
                break
    return result
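Because Numba compiles ordinary Python, the kernel above can be sanity-checked against a vectorised NumPy version even on a machine without Numba. In this sketch (contains_loop and contains_ref are hypothetical names), numba.njit is used when it can be imported; otherwise the same loop runs as plain, much slower Python.

```python
import numpy as np

try:
    from numba import njit  # optional: compile the loop if Numba is installed
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(func):
        return func


@njit
def contains_loop(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    result = np.ones(n, dtype=np.bool_)
    for i in range(n):
        for j in range(m):
            # early exit: one coordinate outside the open box rejects the row
            if not a_arr[j] < arr[i, j] < b_arr[j]:
                result[i] = False
                break
    return result


def contains_ref(arr, a_arr, b_arr):
    # vectorised NumPy reference using the same strict comparisons
    arr = arr.reshape(-1, a_arr.size)
    return np.all((arr > a_arr) & (arr < b_arr), axis=1)


rng = np.random.default_rng(1)
a = rng.uniform(size=4) / 2
b = a + 0.5
x = rng.uniform(size=(1000, 4))
print(np.array_equal(contains_loop(x, a, b), contains_ref(x, a, b)))  # True
```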
This can be compared with the NumPy solution:
import numpy as np

def contains_np(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    # strict comparisons, matching contains_nb and the question's default strict=True
    return np.all((arr > a_arr) & (arr < b_arr), axis=1)
which I simplified a bit compared with your approach. I have omitted the dim and strict parameters: dim is redundant, since it can be inferred from the size of a_arr or b_arr, while strict does not add much to the solution and can easily be reintroduced.
I also assume that the input is always a NumPy array already.
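Reintroducing strict, which the answer says is easy, could look like this (a sketch; the keyword name strict is taken from the question, the rest follows contains_np):

```python
import numpy as np


def contains_np(arr, a_arr, b_arr, strict=True):
    arr = arr.reshape(-1, a_arr.size)  # dim is inferred from the bounds
    if strict:
        inside = (arr > a_arr) & (arr < b_arr)  # open box
    else:
        inside = (arr >= a_arr) & (arr <= b_arr)  # closed box
    return np.all(inside, axis=1)


a = np.zeros(2)
b = np.ones(2)
pts = np.array([[0.5, 0.5], [0.0, 0.5], [1.0, 1.0]])
print(contains_np(pts, a, b))  # [ True False False]
print(contains_np(pts, a, b, strict=False))  # [ True  True  True]
```

Points on the boundary are the only ones whose answer depends on the flag.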
Also, the NumPy solution could be modified to use numexpr, which leads to a third approach. This has some calling overhead, but could speed up the computation, e.g.:
import numpy as np
import numexpr as ne

def contains_ne(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    # numexpr evaluates the element-wise test; the row reduction is still np.all
    result = ne.evaluate('(arr > a_arr) & (arr < b_arr)')
    return np.all(result, axis=1)
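numexpr's potential gain comes mainly from evaluating the whole expression without allocating full-size intermediate arrays (and from multithreading). Part of that temporary-array cost can also be cut in plain NumPy with the ufuncs' out= argument: (arr > a_arr) & (arr < b_arr) allocates three boolean (n, m) arrays, while the sketch below allocates two. This is my own variant (contains_fewer_temps is a hypothetical name) and it has not been benchmarked here.

```python
import numpy as np


def contains_fewer_temps(arr, a_arr, b_arr):
    arr = arr.reshape(-1, a_arr.size)
    inside = np.greater(arr, a_arr)  # temporary 1: lower-bound test
    upper = np.less(arr, b_arr)  # temporary 2: upper-bound test
    # write the AND back into `inside` instead of allocating a third array
    np.logical_and(inside, upper, out=inside)
    return np.all(inside, axis=1)


rng = np.random.default_rng(3)
a = rng.uniform(size=4) / 2
b = a + 0.5
x = rng.uniform(size=(500, 4))
print(np.array_equal(contains_fewer_temps(x, a, b),
                     np.all((x > a) & (x < b), axis=1)))  # True
```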
The following benchmarks can be obtained (plot not reproduced here): they show that the Numba solution is consistently the fastest, while numexpr does not seem to help for the range of parameters explored.
(full benchmark available here)
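Since the benchmark plot is not reproduced here, a minimal harness for redoing this kind of comparison might look like the following. It is my own sketch, not the benchmark code behind the answer; time_per_call and its parameters are hypothetical. The warm-up call keeps one-off costs, such as Numba compiling a jitted function on its first call, out of the measurement.

```python
import timeit

import numpy as np


def contains_np(arr, a_arr, b_arr):
    arr = arr.reshape(-1, a_arr.size)
    return np.all((arr > a_arr) & (arr < b_arr), axis=1)


def time_per_call(func, n_rows, d=4, repeat=5, number=100, seed=0):
    """Best-of-`repeat` time per call, in seconds, for an (n_rows, d) input."""
    rng = np.random.default_rng(seed)
    a = rng.uniform(size=d) / 2
    b = a + 0.5
    arr = rng.uniform(size=(n_rows, d))
    func(arr, a, b)  # warm-up call, e.g. triggers Numba compilation
    times = timeit.repeat(lambda: func(arr, a, b), repeat=repeat, number=number)
    return min(times) / number


for n_rows in (1, 100, 10_000):
    print(n_rows, time_per_call(contains_np, n_rows))
```

Passing contains_nb or contains_ne instead of contains_np times the other variants on identical inputs.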
Answer 2
Score: 2
Try this to speed it up (read more here):
import numpy as np
from numba import jit

@jit(nopython=True)
def contains(u, min, max, dim, strict=True):
    u = np.array(u).reshape(-1, dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)
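If I read Numba's documentation correctly, nopython mode supports np.all only without the axis argument, so the decorated function above may not compile as written. One way to avoid the axis reduction altogether (an alternative of mine, not from this answer; contains_cols, lo and hi are hypothetical names) is to loop over the d columns, which is cheap because d is small, and accumulate a row mask. The sketch below is plain NumPy; I have not checked whether this exact form compiles under Numba.

```python
import numpy as np


def contains_cols(u, lo, hi, dim, strict=True):
    u = np.asarray(u).reshape(-1, dim)
    result = np.ones(u.shape[0], dtype=np.bool_)
    for j in range(dim):  # dim iterations, each vectorised over all rows
        col = u[:, j]
        if strict:
            result &= (col > lo[j]) & (col < hi[j])
        else:
            result &= (col >= lo[j]) & (col <= hi[j])
    return result


rng = np.random.default_rng(4)
lo = rng.uniform(size=4) / 2
hi = lo + 0.5
u = rng.uniform(size=(200, 4))
print(np.array_equal(contains_cols(u, lo, hi, 4),
                     np.all((u > lo) & (u < hi), axis=1)))  # True
```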