Returning a distribution object from a jittable function
Question
I want to create a jittable function that outputs a distrax distribution object. For instance:
```python
import distrax
import jax
import jax.numpy as jnp

def f(x):
    dist = distrax.Categorical(logits=jnp.sin(x))
    return dist

jit_f = jax.jit(f)
a = jnp.array([1, 2, 3])
dist = jit_f(a)
```
Currently this code gives me the following error:
```text
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 628, in cache_miss
    out = tree_unflatten(out_pytree_def, out_flat)
  File "F:\jax_env\lib\site-packages\jax\_src\tree_util.py", line 75, in tree_unflatten
    return treedef.unflatten(leaves)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\jittable.py", line 40, in tree_unflatten
    obj = cls(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\distrax\_src\distributions\categorical.py", line 60, in __init__
    self._logits = None if logits is None else math.normalize(logits=logits)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\math.py", line 72, in normalize
    return jax.nn.log_softmax(logits, axis=-1)
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 618, in cache_miss
    keep_unused=keep_unused))
  File "F:\jax_env\lib\site-packages\jax\core.py", line 2031, in call_bind_with_continuation
    top_trace = find_top_trace(args)
  File "F:\jax_env\lib\site-packages\jax\core.py", line 1122, in find_top_trace
    top_tracer._assert_live()
  File "F:\jax_env\lib\site-packages\jax\interpreters\partial_eval.py", line 1486, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[3] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at <stdin>:1 traced for jit.
------------------------------
The leaked intermediate value was created on line <stdin>:2 (f).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<stdin>:1 (<module>)
<stdin>:2 (f)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```
I thought that using `dist = jax.block_until_ready(dist)` inside `f` could fix the problem, but it doesn't.
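The traceback shows how `jit` hands back a non-array object: `cache_miss` calls `tree_unflatten`, which in distrax's `Jittable` (v0.1.2 here) rebuilds the distribution by calling `cls(*args, **kwargs)`, and `Categorical.__init__` then runs `jax.nn.log_softmax`. The sketch below is a hypothetical illustration of the same mechanism using a made-up `ToyCategorical` class; it is not distrax's implementation and not necessarily what the distrax fix changed. It registers the class as a JAX pytree and, following the advice in the JAX pytree documentation, rebuilds instances in `tree_unflatten` without re-running `__init__`, so no JAX operation runs while the output is being unflattened.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyCategorical:
    """Hypothetical stand-in for a distribution; not distrax's Categorical."""

    def __init__(self, logits):
        # Normalisation happens in __init__, like the distrax frames in the traceback.
        self.logits = jax.nn.log_softmax(logits, axis=-1)

    def probs(self):
        return jnp.exp(self.logits)

    def tree_flatten(self):
        # Array leaves go in the children; aux data must be static (no tracers).
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so unflattening does not re-run log_softmax on the leaves.
        obj = object.__new__(cls)
        (obj.logits,) = children
        return obj


@jax.jit
def f(x):
    return ToyCategorical(logits=jnp.sin(x))


dist = f(jnp.array([1, 2, 3]))
print(dist.probs().sum())
```

Because `tree_unflatten` only reattaches the already-normalised leaves, returning the object from `jit` does not trigger any extra computation outside the trace.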
Answer 1
Score: 1
This looks like the bug in distrax v0.1.2 reported in https://github.com/deepmind/distrax/issues/162. This was fixed by https://github.com/deepmind/distrax/pull/177, which is part of the distrax v0.1.3 release.
To fix the issue, you should update to distrax v0.1.3 or later.
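A small sketch for applying this, assuming Python 3.8+ for `importlib.metadata`. The helper names `parse_release`, `distrax_has_jit_fix`, `f_logits`, and `make_dist` are hypothetical. `distrax_has_jit_fix` checks whether the installed distrax is at least v0.1.3. For environments that cannot upgrade, `f_logits`/`make_dist` show a common workaround: return only arrays from the jitted function and construct `distrax.Categorical` outside `jit`. This workaround has not been verified against v0.1.2 specifically; it simply avoids returning a distribution object across the `jit` boundary.

```python
import re
from importlib.metadata import PackageNotFoundError, version

import jax
import jax.numpy as jnp


def parse_release(v: str) -> tuple:
    """Turn "0.1.3" into (0, 1, 3); trailing suffixes such as "rc1" are dropped."""
    nums = []
    for piece in v.split(".")[:3]:
        m = re.match(r"\d+", piece)
        if m is None:
            break
        nums.append(int(m.group()))
    return tuple(nums)


def distrax_has_jit_fix(minimum=(0, 1, 3)) -> bool:
    """True if the installed distrax release is at least `minimum` (v0.1.3 per the answer)."""
    try:
        installed = version("distrax")
    except PackageNotFoundError:
        return False
    return parse_release(installed) >= minimum


@jax.jit
def f_logits(x):
    # Only an array crosses the jit boundary, so no distribution is unflattened here.
    return jnp.sin(x)


def make_dist(x):
    # Fallback for distrax < 0.1.3: build the distribution outside jit from concrete logits.
    import distrax  # local import keeps the helpers above usable without distrax installed
    return distrax.Categorical(logits=f_logits(x))
```

With distrax v0.1.3 or later, the original `jax.jit(f)` snippet from the question is expected to work as written; otherwise `make_dist(jnp.array([1, 2, 3]))` returns a `Categorical` built from jitted logits.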