How to pass a numpy array through a function when the function contains conditional if statements?


Question

I have the following code:

import numpy as np
from astropy.cosmology import FlatLambdaCDM
import matplotlib.pyplot as plt

cosmopar = FlatLambdaCDM(H0=67.8, Om0=0.3)

td_min = 0.1
d = -1

def t_L(z_arb):
    return cosmopar.lookback_time(z_arb).value

def t_d(z_f, z_m):
    return t_L(z_f) - t_L(z_m)

def P_t(z_f, z_m):
    if (td_min < t_d):
        return t_d**d
    else:
        return 0

Now, if I define a NumPy array zf_trial1 = np.linspace(0,30,100) and try to pass it through the function with P_t(zf_trial1,3), I get the following error:

"The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"

I understand why this error occurs: the if statement needs a single True or False, but comparing an array with several elements against td_min gives one result per element, with some elements satisfying the condition and others failing it. However, I am not sure how to fix this. All I want to do is pass each element of the NumPy array zf_trial1 through P_t(z_f,z_m).
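A minimal sketch of the behaviour described above, using made-up delay values in place of t_d: comparing an array with a scalar is element-wise and yields a boolean array, and an if statement cannot reduce that array to a single True/False.

```python
import numpy as np

td_min = 0.1
t_delay = np.array([0.05, 0.5, 2.0])  # made-up delay times standing in for t_d

# Comparing an array with a scalar is element-wise: one boolean per element.
mask = td_min < t_delay
print(mask.tolist())  # [False, True, True]

# An `if` needs a single True/False, so testing the whole array raises ValueError.
try:
    if mask:
        pass
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

Calling mask.any() or mask.all() would silence the error, but it answers a different question (whether any or every element passes), so it does not give the per-element result P_t needs.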

I tried np.vectorize(), but it doesn't seem to work well: the results look haywire, and when I plot them, the graph differs from the one I get when I input values into P_t manually and plot those. Here is what I tried:

Pt_vector = np.vectorize(P_t)
Pt_res = Pt_vector(zf_trial1,3)

plt.scatter(zf_trial1,Pt_res)
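One likely cause of the haywire plot, offered as an assumption since the question does not show the actual output: np.vectorize infers the output dtype by calling the function on the first element unless otypes is given. Assuming P_t really uses t_d(z_f, z_m), the first element zf_trial1[0] = 0 gives t_d(0, 3) < 0 (lookback time grows with redshift), so P_t returns the integer 0 and every later result is truncated to an integer. The function p_toy below is a hypothetical stand-in that takes a delay time directly, so the sketch runs without astropy.

```python
import numpy as np

td_min = 0.1
d = -1

def p_toy(t_delay):
    """Hypothetical scalar stand-in for P_t that takes a delay time directly."""
    if td_min < t_delay:
        return t_delay ** d
    else:
        return 0  # an int, exactly like the question's `return 0`

# The first element fails the condition, as zf_trial1[0] = 0 would with z_m = 3.
t_vals = np.array([-1.0, 0.5, 2.0, 4.0])

# Without otypes, np.vectorize infers the output dtype from the first call,
# which returns the int 0, so every result is cast to an integer.
bad = np.vectorize(p_toy)(t_vals)

# Declaring a float output type keeps the fractional values.
good = np.vectorize(p_toy, otypes=[float])(t_vals)

print(bad.tolist())   # [0, 2, 0, 0]
print(good.tolist())  # [0.0, 2.0, 0.5, 0.25]
```

For the question's code, the corresponding change would be np.vectorize(P_t, otypes=[float]), or returning 0.0 instead of 0 in the else branch.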

Answer 1

Score: 0

I'm a bit confused, since P_t takes z_f and z_m as arguments but never uses them (the comparison uses the function t_d itself rather than t_d(z_f, z_m)), so I don't really see how it is meant to be used. I'll try to answer the question in a general way.

You can use np.where inside your function to apply the condition element-wise, like so:

import numpy as np

def ex(arr, max_value):
    """
    arr: np.array
    max_value: int

    All values in `arr` below `max_value`
    are raised to the power `d` (a global).
    Values at or above `max_value` are set to `0`.
    """
    return np.where(arr < max_value, arr**d, 0)

d = 2
ex(np.arange(0, 10), 5)  # array([ 0,  1,  4,  9, 16,  0,  0,  0,  0,  0])

Or you can simply use a list comprehension:

z_m = 3  # the z_m value used in the question
a = np.array([P_t(v, z_m) for v in zf_trial1])
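Applied to the question's case, both approaches from this answer can be sketched as follows. This assumes P_t was meant to use t_d(z_f, z_m); the helpers p_scalar and p_array are hypothetical and take delay times directly so the sketch runs without astropy. The array version uses a boolean mask rather than np.where, so the negative power is evaluated only where the condition holds.

```python
import numpy as np

td_min = 0.1  # minimum delay time, as in the question
d = -1        # power-law index, as in the question

def p_scalar(t_delay):
    """Hypothetical scalar P_t taking a delay time; returns a float in both branches."""
    if td_min < t_delay:
        return t_delay ** d
    return 0.0

def p_array(t_delay):
    """Hypothetical element-wise P_t: t_delay**d where t_delay > td_min, else 0."""
    t_delay = np.asarray(t_delay, dtype=float)
    mask = t_delay > td_min
    result = np.zeros_like(t_delay)
    # Evaluate the power only on the masked elements, so zero or negative
    # delays never reach t_delay**d with d < 0.
    result[mask] = t_delay[mask] ** d
    return result

delays = np.array([-2.0, 0.0, 0.05, 0.5, 4.0])  # made-up delay times

# List comprehension over the scalar version, as in the answer.
looped = np.array([p_scalar(t) for t in delays])
# Vectorized version, with no Python-level loop.
vectorized = p_array(delays)

print(looped.tolist())      # [0.0, 0.0, 0.0, 2.0, 0.25]
print(vectorized.tolist())  # [0.0, 0.0, 0.0, 2.0, 0.25]
```

With astropy installed, the question's curve could then be computed as p_array(t_d(zf_trial1, 3)), since astropy's lookback_time accepts array input; that step is not exercised here. The one-liner np.where(t > td_min, t**d, 0) gives the same values, but it evaluates t**d for every element first, so a zero delay triggers a divide-by-zero RuntimeWarning.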

huangapple
  • This article was published on July 18, 2023 at 16:30:53
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76710871.html