How to pass a NumPy array through a function when the function contains conditional if statements?
Question
I have the following code:
import numpy as np
from astropy.cosmology import FlatLambdaCDM
import matplotlib.pyplot as plt

cosmopar = FlatLambdaCDM(H0=67.8, Om0=0.3)
td_min = 0.1
d = -1

def t_L(z_arb):
    return cosmopar.lookback_time(z_arb).value

def t_d(z_f, z_m):
    return t_L(z_f) - t_L(z_m)

def P_t(z_f, z_m):
    if td_min < t_d:
        return t_d**d
    else:
        return 0
Now, if I define a NumPy array zf_trial1 = np.linspace(0,30,100) and try to pass it through the function with P_t(zf_trial1,3), the function raises the following error:
"The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"
I understand why this error appears. When the comparison with td_min in the if statement receives an array with several elements, some elements satisfy the condition and others fail it, so a single if cannot decide which branch to take. However, I am not sure how to fix this. All I want to do is pass each element of the NumPy array zf_trial1 through P_t(z_f,z_m).
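The point about the if statement can be checked with plain NumPy. In this minimal sketch the delay values are made up and stand in for t_d(z_f, z_m). The comparison itself works element-wise. Only if fails, because it cannot reduce the resulting boolean array to a single True or False.

```python
import numpy as np

td_min = 0.1
delays = np.array([-1.0, 0.05, 0.5, 2.0])  # made-up sample delays

# The comparison is element-wise and returns a boolean array.
mask = td_min < delays

# `if` needs one True/False, so NumPy refuses to pick one for the whole array.
message = ""
try:
    if td_min < delays:
        pass
except ValueError as err:
    message = str(err)

print(mask)
print(message)
```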
I tried the np.vectorize() function, but it does not seem to work well. The results look haywire: when I plot them, the graph differs from the one I get when I manually input values into P_t and plot those. What I tried is as follows:
Pt_vector = np.vectorize(P_t)
Pt_res = Pt_vector(zf_trial1,3)
plt.scatter(zf_trial1,Pt_res)
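One likely cause of the haywire np.vectorize output is sketched below under stated assumptions. Unless otypes is given, np.vectorize infers the output dtype by calling the function on the first element. Suppose the intended P_t calls t_d(z_f, z_m); as posted, it compares td_min with the function object t_d. Then the first element z_f = 0 gives a negative delay, because the lookback time at z = 0 is zero. P_t therefore returns the integer 0, and every later float result is truncated to an integer. The function toy_delay is a hypothetical stand-in for the astropy-based t_d, so the example runs on its own.

```python
import numpy as np

td_min = 0.1
d = -1

def toy_delay(z_f, z_m):
    # Hypothetical stand-in for t_d(z_f, z_m): a linear "delay" that, like
    # t_L(0) - t_L(3), is negative for the first sample z_f = 0.
    return 0.5 * (z_f - z_m)

def toy_P_t(z_f, z_m):
    delay = toy_delay(z_f, z_m)  # call the delay function, don't compare to it
    if td_min < delay:
        return delay ** d        # a float
    else:
        return 0                 # a Python int

z = np.linspace(0, 30, 100)

# Without otypes, the dtype comes from toy_P_t(0.0, 3), which returns the
# int 0, so all later float results are truncated to integers.
bad = np.vectorize(toy_P_t)(z, 3)

# Declaring the output type removes the guess.
good = np.vectorize(toy_P_t, otypes=[float])(z, 3)

print(bad.dtype.kind, good.dtype)
```

Under that assumption, np.vectorize(P_t, otypes=[float]) should make the vectorized results match the values obtained by calling P_t by hand. Alternatively, returning 0.0 instead of 0 in the else branch avoids the problem.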
Answer 1
Score: 0
I'm a bit confused because P_t takes z_f and z_m as arguments, but they aren't used in the function at all, so I don't really understand how it is meant to be used. I'll therefore try to answer the question in a general way.
You can use np.where inside your function to build a filter, like so:
import numpy as np

def ex(arr, max_value):
    """
    arr: np.array
    max_value: int
    All values in `arr` below `max_value`
    are raised to the power `d`.
    Values at or above `max_value` are set to `0`.
    """
    return np.where(arr < max_value, arr**d, 0)

d = 2
ex(np.arange(0, 10), 5)  # array([ 0,  1,  4,  9, 16,  0,  0,  0,  0,  0])
Or you can simply use a list comprehension, where z_m is the fixed second argument (3 in your example):
a = np.array([P_t(v, z_m) for v in zf_trial1])
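Below is a sketch of an array-friendly P_t. It assumes the intended comparison is on the delay t_d(z_f, z_m) rather than on the function object t_d. With the question's d = -1, the np.where pattern above has two snags. First, an integer array raises ValueError for negative integer powers. Second, np.where computes arr**d for every element before choosing, so a zero delay triggers a divide-by-zero warning. Indexing with a boolean mask computes the power only where td_min < delay. The names P_t_array and P_t_scalar are hypothetical, and the sample delays are made up. The list comprehension is included to show that both approaches agree.

```python
import numpy as np

td_min = 0.1
d = -1

def P_t_array(delay):
    """Array-friendly P_t: delay**d where td_min < delay, otherwise 0.

    `delay` stands in for t_d(z_f, z_m) evaluated on a whole array.
    """
    delay = np.asarray(delay, dtype=float)  # floats: ints can't take negative powers
    mask = td_min < delay                   # element-wise comparison, no `if`
    out = np.zeros_like(delay)
    out[mask] = delay[mask] ** d            # power computed only where valid
    return out

def P_t_scalar(delay):
    """Scalar version, shaped like the question's P_t."""
    if td_min < delay:
        return delay ** d
    return 0

delays = np.array([-1.5, 0.0, 0.05, 0.5, 2.0])  # made-up sample delays

via_mask = P_t_array(delays)
via_listcomp = np.array([P_t_scalar(t) for t in delays])

print(via_mask)
print(np.allclose(via_mask, via_listcomp))
```

In the question's setting, the delays could come from t_d(zf_trial1, 3) directly, since astropy's lookback_time accepts array input, so no per-element loop should be needed. np.array also upcasts the list comprehension's mix of 0 and floats to float64, which sidesteps the dtype guess np.vectorize makes.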