How to make a function a valid JAX type?

Question

When I pass an object created by the following function into a jax.lax.scan function:

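import jax
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam
from numpyro.infer.util import initialize_model

def logdensity_create(model, centeredness=None, varname=None):
    # Optionally reparameterize site `varname` (0 = non-centered, 1 = centered).
    if centeredness is not None:
        model = reparam(model, config={varname: LocScaleReparam(centered=centeredness)})

    # With dynamic_args=True, potential_fn_gen(*args) returns the potential function.
    init_params, potential_fn_gen, *_ = initialize_model(jax.random.PRNGKey(0), model, dynamic_args=True)
    # The log density is the negative potential energy (a plain Python lambda).
    logdensity = lambda position: -potential_fn_gen()(position)
    initial_position = init_params.z
    return (logdensity, initial_position)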
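To reproduce the failure without NumPyro, the sketch below swaps in a hypothetical standard-normal log density for the model-derived one and carries the bare callable through jax.lax.scan. The step function and the carry layout are assumptions made for illustration; the point is only that a plain Python function inside the scan carry is rejected with a TypeError.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the lambda returned by logdensity_create.
logdensity = lambda position: -0.5 * jnp.sum(position ** 2)

def step(carry, _):
    # Hypothetical loop body: evaluate the log density and move the position.
    fn, x = carry
    return (fn, x + 0.1), fn(x)

try:
    jax.lax.scan(step, (logdensity, jnp.zeros(2)), None, length=3)
    error = None
except TypeError as err:
    # The bare lambda is a pytree leaf that scan cannot turn into an array type.
    error = err

print(error)
```

The exception is a TypeError of the same kind as the one reported below.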

I get the following error when passing logdensity to an iterative function built with jax.lax.scan:

TypeError: Value <function ….<locals>.logdensity_create.<locals>.<lambda> at 0x13fca7d80> with type <class 'function'> is not a valid JAX type

How can I resolve this error?
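Why this happens: before tracing, jit and lax.scan flatten their inputs into pytree leaves and need every leaf to be an array-like value. A minimal sketch with hypothetical stand-ins for the returned pair (the site name "x" and the dict layout assumed for initial_position are illustrative) shows that the lambda is not a pytree container, so it is kept as a leaf of its own:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for (logdensity, initial_position); the real
# initial_position comes from init_params.z and is assumed to be a dict of arrays.
logdensity = lambda position: -0.5 * jnp.sum(position["x"] ** 2)
initial_position = {"x": jnp.zeros(3)}

# Flatten the pair the same way JAX transformations do before tracing.
leaves = jax.tree_util.tree_leaves((logdensity, initial_position))

# The function shows up as a leaf next to the array; it has no array type.
print([type(leaf).__name__ for leaf in leaves])
```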

Answer 1

Score: 2

I would probably do this via jax.tree_util.Partial, which wraps callables in a PyTree for compatibility with jit and other transformations:

logdensity = jax.tree_util.Partial(lambda position: -potential_fn_gen()(position))
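A minimal sketch of this fix inside a scan loop, again assuming a hypothetical standard-normal log density in place of -potential_fn_gen()(position) and a hypothetical step function. Partial stores the wrapped function as static pytree metadata and treats only bound arguments as leaves, so the wrapped callable can travel through the scan carry:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

# Hypothetical stand-in for -potential_fn_gen()(position), wrapped in Partial.
logdensity = Partial(lambda position: -0.5 * jnp.sum(position ** 2))

def step(carry, _):
    fn, x = carry
    # Calling a Partial calls the wrapped function; return the same Partial
    # so the carry keeps an identical pytree structure across iterations.
    return (fn, x + 0.1), fn(x)

(fn_out, x_final), values = jax.lax.scan(step, (logdensity, jnp.zeros(2)), None, length=3)
print(values)
```

Arrays that should be traced can be bound as arguments, e.g. Partial(f, scale), in which case they become leaves of the pytree instead of static metadata.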

huangapple
  • This article was published on July 10, 2023, 23:18:51
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76655153.html