Returning a distribution object from a jittable function

Question

I want to create a jittable function that outputs a distrax distribution object. For instance:

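import distrax
import jax
import jax.numpy as jnp

def f(x):
    # Build a categorical distribution from array-valued logits.
    dist = distrax.Categorical(logits=jnp.sin(x))
    return dist

jit_f = jax.jit(f)
a = jnp.array([1, 2, 3])
dist = jit_f(a)  # fails with UnexpectedTracerError under distrax v0.1.2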

Currently this code gives me the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 628, in cache_miss
    out = tree_unflatten(out_pytree_def, out_flat)
  File "F:\jax_env\lib\site-packages\jax\_src\tree_util.py", line 75, in tree_unflatten
    return treedef.unflatten(leaves)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\jittable.py", line 40, in tree_unflatten
    obj = cls(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\distrax\_src\distributions\categorical.py", line 60, in __init__
    self._logits = None if logits is None else math.normalize(logits=logits)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\math.py", line 72, in normalize
    return jax.nn.log_softmax(logits, axis=-1)
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 618, in cache_miss
    keep_unused=keep_unused))
  File "F:\jax_env\lib\site-packages\jax\core.py", line 2031, in call_bind_with_continuation
    top_trace = find_top_trace(args)
  File "F:\jax_env\lib\site-packages\jax\core.py", line 1122, in find_top_trace
    top_tracer._assert_live()
  File "F:\jax_env\lib\site-packages\jax\interpreters\partial_eval.py", line 1486, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[3] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at <stdin>:1 traced for jit.
------------------------------
The leaked intermediate value was created on line <stdin>:2 (f).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<stdin>:1 (<module>)
<stdin>:2 (f)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

I thought that using dist = jax.block_until_ready(dist) inside f could fix the problem, but it doesn't.

Answer 1

Score: 1

This looks like the bug in distrax v0.1.2 reported in https://github.com/deepmind/distrax/issues/162. It was fixed by https://github.com/deepmind/distrax/pull/177, which is part of the distrax v0.1.3 release.
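The traceback shows where the failure happens. The exception is raised after tracing, while JAX rebuilds the returned object from its leaves. At that point `tree_unflatten` in distrax's `jittable.py` calls `cls(*args, **kwargs)`. That call re-runs `Categorical.__init__` and therefore `jax.nn.log_softmax`, which then trips over the leaked `float32[3]` tracer created on line 2 of `f` (the `jnp.sin(x)` logits).

The sketch below illustrates the general pattern under stated assumptions. It is not distrax's code or the content of the PR above. `ToyCategorical` is a hypothetical class that can be returned from `jax.jit` because:

* its array is stored as a pytree child rather than as static aux data;
* its `tree_unflatten` rebuilds the object without calling `__init__`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyCategorical:
    """Hypothetical stand-in for a distribution object (not distrax code)."""

    def __init__(self, logits):
        # Normalise once at construction time, as a real distribution might.
        self.logits = jax.nn.log_softmax(logits, axis=-1)

    def probs(self):
        return jnp.exp(self.logits)

    def tree_flatten(self):
        # The array is a child, so jit treats it as a traced output;
        # no array-valued data is kept in the static aux data (None here).
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without re-running __init__, so no JAX operation is
        # applied to the leaves while the output pytree is reconstructed.
        obj = object.__new__(cls)
        (obj.logits,) = children
        return obj


@jax.jit
def f(x):
    return ToyCategorical(logits=jnp.sin(x))


dist = f(jnp.array([1, 2, 3]))
print(dist.probs())
```

Skipping `__init__` in `tree_unflatten` is a deliberate choice. During unflattening, JAX may pass leaves that are tracers or placeholder objects rather than concrete arrays, so constructor logic that runs JAX operations is best kept out of that path.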

To fix the issue, you should update to distrax v0.1.3 or later.
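You might want to confirm which release is installed before and after upgrading, for example after running `pip install --upgrade "distrax>=0.1.3"`. A small standard-library check can compare the installed version against 0.1.3.

This is a minimal sketch with some assumptions to note:

* the helper names `parse_release` and `has_tracer_fix` are hypothetical;
* the parser is deliberately simplified.

If the `packaging` library is available, `packaging.version.Version` handles pre-release and local version tags properly.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Per the answer above, the fix (PR #177) first shipped in distrax v0.1.3.
MIN_FIXED = (0, 1, 3)


def parse_release(v: str) -> tuple:
    """Return the leading numeric release segment, e.g. '0.1.3' -> (0, 1, 3).

    Simplified on purpose: suffixes such as 'rc1' or '.dev0' are dropped,
    so a pre-release compares equal to the final release.
    """
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()


def has_tracer_fix(v: str) -> bool:
    # Tuple comparison orders releases: (0, 1, 2) < (0, 1, 3) <= (0, 2).
    return parse_release(v) >= MIN_FIXED


try:
    installed = version("distrax")
    print(f"distrax {installed}: includes fix = {has_tracer_fix(installed)}")
except PackageNotFoundError:
    print("distrax is not installed")
```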

huangapple
  • Posted on 2023-03-07 22:45:35
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75663449.html