JIT a JAX function that selects a function from a dictionary
Question
I have a JAX function that, given the order and the index, selects a polynomial from a pre-defined dictionary, as follows:
poly_dict = {
    (0, 0): lambda x, y, z: 1.,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x*x,
    (2, 1): lambda x, y, z: y*y,
    (2, 2): lambda x, y, z: z*z,
    (2, 3): lambda x, y, z: x*y,
    (2, 4): lambda x, y, z: y*z,
    (2, 5): lambda x, y, z: z*x
}
def poly_func(order: int, index: int):
    try:
        return poly_dict[(order, index)]
    except KeyError:
        print("`(order, index)` must be a key in `poly_dict`!")
        return
Now I want to jit poly_func(), but it gives the error TypeError: unhashable type: 'DynamicJaxprTracer'.
Moreover, if I just do
def poly_func(order: int, index: int):
    return poly_dict[(order, index)]
it still gives the same error. Is there a way to resolve this issue?
Answer 1
Score: 1
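There are two issues with your approach. First, traced values cannot be used to index into Python collections like dicts or lists: under jit, order and index are tracers, and a tracer is not hashable, which is where the TypeError comes from. Second, JIT-compiled functions can only return arrays (or pytrees of arrays), not functions.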
With that in mind, the function you propose cannot work as written under JAX transforms like JIT. You could instead store the functions in a list rather than a dict, and use jax.lax.switch to select which one to call at run time. Here's how it might look:
import jax

def get_index(order, index):
    # Map an (order, index) key to a flat list position (unique for the keys in poly_dict).
    return order * 5 + index

poly_list = 16 * [lambda x, y, z: 0.0]  # placeholder function for unused slots
for key, val in poly_dict.items():
    poly_list[get_index(*key)] = val

def eval_poly_func(order: int, index: int, args):
    ind = get_index(order, index)
    return jax.lax.switch(ind, poly_list, *args)

result = jax.jit(eval_poly_func)(2, 0, (5.0, 6.0, 7.0))
print(result)
# 25.0
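If order and index are always known when the function is called (plain Python ints rather than values computed inside traced code), another option is to mark them as static arguments. This is only a sketch of that alternative, not part of the approach above: `eval_poly_static` and `eval_jit` are hypothetical names, and the dict is repeated so the snippet runs on its own.

```python
import jax

poly_dict = {
    (0, 0): lambda x, y, z: 1.,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x*x,
    (2, 1): lambda x, y, z: y*y,
    (2, 2): lambda x, y, z: z*z,
    (2, 3): lambda x, y, z: x*y,
    (2, 4): lambda x, y, z: y*z,
    (2, 5): lambda x, y, z: z*x,
}

def eval_poly_static(order, index, x, y, z):
    # `order` and `index` are static, so they arrive as ordinary Python ints
    # and the dict lookup happens at trace time, not inside the compiled code.
    return poly_dict[(order, index)](x, y, z)

# Arguments 0 and 1 (order, index) are treated as compile-time constants.
eval_jit = jax.jit(eval_poly_static, static_argnums=(0, 1))

result = eval_jit(2, 3, 5.0, 6.0, 7.0)
print(result)
# 30.0
```

The trade-off is that each distinct (order, index) pair triggers its own compilation, and a key missing from poly_dict raises a KeyError at trace time. When the index is only known at run time, lax.switch remains the appropriate tool.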