英文:
How to make a function a valid jax type?
问题
当我将使用以下函数创建的对象传递到jax.lax.scan
函数中时:
def logdensity_create(model, centeredness=None, varname=None):
if centeredness is not None:
model = reparam(model, config={varname: LocScaleReparam(centered=centeredness)})
init_params, potential_fn_gen, *_ = initialize_model(jax.random.PRNGKey(0), model, dynamic_args=True)
logdensity = lambda position: -potential_fn_gen()(position)
initial_position = init_params.z
return (logdensity, initial_position)
我会得到以下错误(在将logdensity
传递给使用jax.lax.scan
创建的迭代函数时):
TypeError: Value .logdensity_create.. at 0x13fca7d80> with type is not a valid JAX type
如何解决这个错误?
英文:
When I pass an object created using the following function function into a jax.lax.scan function:
def logdensity_create(model, centeredness = None, varname = None):
if centeredness is not None:
model = reparam(model, config={varname: LocScaleReparam(centered= centeredness)})
init_params, potential_fn_gen, *_ = initialize_model(jax.random.PRNGKey(0),model,dynamic_args=True)
logdensity = lambda position: -potential_fn_gen()(position)
initial_position = init_params.z
return (logdensity, initial_position)
I get the following error (on passing the logdensity to an iterative function created using jax.lax.scan):
TypeError: Value .logdensity_create.. at 0x13fca7d80> with type is not a valid JAX type
How can I resolve this error?
答案1
得分: 2
我可能会通过 jax.tree_util.Partial
来实现这个,它可以将可调用对象包装成一个 PyTree,以便与 jit
和其他转换兼容:
logdensity = jax.tree_util.Partial(lambda position: -potential_fn_gen()(position))
英文:
I would probably do this via jax.tree_util.Partial
, which wraps callables in a PyTree for compatibility with jit
and other transformations:
logdensity = jax.tree_util.Partial(lambda position: -potential_fn_gen()(position))
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论