Jit a JAX function that selects a function from a dictionary

Question

I have a JAX function that, given the order and the index, selects a polynomial from a pre-defined dictionary, as follows:

poly_dict = {
    (0, 0): lambda x, y, z: 1.,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x*x,
    (2, 1): lambda x, y, z: y*y,
    (2, 2): lambda x, y, z: z*z,
    (2, 3): lambda x, y, z: x*y,
    (2, 4): lambda x, y, z: y*z,
    (2, 5): lambda x, y, z: z*x
}

def poly_func(order: int, index: int):
    
    try:
        return poly_dict[(order, index)]
    
    except KeyError:
        print("`(order, index)` must be a key in `poly_dict`!")
        return

Now I want to jit poly_func(), but doing so raises TypeError: unhashable type: 'DynamicJaxprTracer'. Moreover, if I reduce the function to just

def poly_func(order: int, index: int):
    
    return poly_dict[(order, index)]

it still gives the same error. Is there a way to resolve this issue?

Answer 1

Score: 1

There are two issues with your approach: first, traced values cannot be used to index into Python collections like dicts or lists. Second, JIT-compiled functions can only return array values, not functions.

With that in mind, it is impossible to make the function you propose work with JAX transforms like JIT. But you could modify your approach to use a list rather than a dict of functions, and then use lax.switch to dynamically select which function you call. Here's how it might look:

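import jax

# poly_dict is the dictionary defined in the question.

def get_index(order, index):
  # Flatten (order, index) into one integer; unique for the keys in poly_dict.
  return order * 5 + index

poly_list = 16 * [lambda x, y, z: 0.0]  # placeholder for unused slots

for key, val in poly_dict.items():
  poly_list[get_index(*key)] = val


def eval_poly_func(order: int, index: int, args):
  ind = get_index(order, index)
  # lax.switch picks the branch at run time from a traced integer index.
  return jax.lax.switch(ind, poly_list, *args)

result = jax.jit(eval_poly_func)(2, 0, (5.0, 6.0, 7.0))
print(result)
# 25.0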
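
If order and index are always known when the function is traced (an assumption, not something the question states), another option is to mark them as static arguments so that the dictionary lookup runs as ordinary Python. A hedged sketch, with `eval_poly_static` as a hypothetical name and a shortened copy of `poly_dict`:

```python
from functools import partial

import jax

# Shortened copy of poly_dict from the question, for illustration only.
poly_dict = {
    (0, 0): lambda x, y, z: 1.0,
    (1, 0): lambda x, y, z: x,
    (2, 0): lambda x, y, z: x * x,
    (2, 3): lambda x, y, z: x * y,
}

# order and index are static: they arrive as plain Python ints during
# tracing, so the dict lookup is ordinary Python. Only x, y, z are traced.
@partial(jax.jit, static_argnums=(0, 1))
def eval_poly_static(order, index, x, y, z):
    return poly_dict[(order, index)](x, y, z)

static_result = eval_poly_static(2, 3, 5.0, 6.0, 7.0)
print(static_result)
```

Each new (order, index) pair triggers a fresh compilation, so this suits cases with only a few distinct pairs. lax.switch remains the better fit when the index itself has to be a traced value.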

huangapple
  • Published on 2023-05-07 07:31:10
  • Please keep the link to this article when reposting: https://go.coder-hub.com/76191616.html