Understanding how JAX's tracer vs static work

Question

I'm new to JAX and am trying to write some simple code with it. At one point I need to call a SciPy function, and then I want to take the derivative of the result.

The code doesn't run and raises an error. The code and the error are shown below. I have read the JAX documentation a couple of times but couldn't figure out how to write the code correctly.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import scipy
key = random.PRNGKey(1)

size = 3
x = random.uniform(key, (size, size), dtype=jnp.float32)

def error_func(x):
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
    error = jnp.sum(jnp.array(dists))
    return error

error_diff = grad(error_func)

print(error_func(x))
print(error_diff(x))
```

And I get the following error:

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
3.2158318
Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 646, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 722, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2179, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/core.py", line 598, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559  0.3129729  0.12388372]
 [0.548188   0.8851279  0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188  , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559  0.3129729  0.12388372]
 [0.548188   0.8851279  0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188  , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Answer 1

Score: 0

JAX transformations only work on JAX functions, not NumPy or SciPy functions (this is discussed briefly at the link shown in the error message above). If you want to use grad and other JAX transformations, you need to write your logic using JAX operations, not operations from NumPy, SciPy, or other libraries that are not JAX-compatible.
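To make that concrete, here is a minimal sketch that is not part of the original answer. The function names `jax_only` and `mixes_numpy` are hypothetical and chosen only for illustration. Under `grad`, the argument is a tracer rather than a concrete array, so `jnp` operations work while `np.asarray` raises the same `TracerArrayConversionError` seen in the question.

```python
import jax
import jax.numpy as jnp
import numpy as np

def jax_only(x):
    # jnp operations accept tracers, so grad can differentiate this.
    return jnp.sum(jnp.sin(x))

def mixes_numpy(x):
    # np.asarray asks the tracer for a concrete NumPy array, which fails
    # while the function is being traced by grad.
    return jnp.sum(jnp.asarray(np.asarray(x)))

x = jnp.arange(3.0)

# Works: the gradient of sum(sin(x)) is cos(x).
print(jax.grad(jax_only)(x))

# Fails during tracing with the error from the question.
try:
    jax.grad(mixes_numpy)(x)
except jax.errors.TracerArrayConversionError as err:
    print("raised:", type(err).__name__)
```

Calling `mixes_numpy(x)` directly, outside of `grad`, succeeds because `x` is then an ordinary array. That matches the question, where `error_func(x)` printed a value and only `error_diff(x)` failed.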

JAX does not currently include any wrappers of scipy.spatial.distance (though there are some in progress, see #16147), so the best option would be to write the code yourself. Fortunately, cdist is pretty straightforward:

def cdist(x, y, metric='euclidean'):
  assert x.ndim == y.ndim == 2
  if metric != 'euclidean':
    raise NotImplementedError(f"{metric=}")
  # Broadcast to shape (len(x), len(y), dim), then reduce over the last axis.
  return jnp.sqrt(jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))

def error_func(x):
    dists = cdist(x, x, metric='euclidean')
    error = jnp.sum(dists)
    return error

error_diff = grad(error_func)

print(error_func(x))
# 3.2158318

print(error_diff(x))
# [[nan nan nan]
#  [nan nan nan]
#  [nan nan nan]]

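You'll notice that the gradient is nan everywhere. This is the expected result: the distance from each point to itself is 0.0, grad(jnp.sqrt)(0.0) diverges (returns infinity), and 0.0 * inf is nan by definition. Every row of x contributes to one of these zero diagonal entries, so the nan propagates to every element of the gradient.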
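If a finite gradient is needed, one common workaround is the "double where" pattern sketched below. It is not part of the original answer. The name `safe_cdist` is hypothetical, and the sketch assumes it is acceptable to treat the gradient of a zero distance as 0.

```python
import jax.numpy as jnp
from jax import grad, random

def safe_cdist(x, y):
    # Squared pairwise distances via broadcasting: shape (len(x), len(y)).
    sq = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    is_zero = sq == 0
    # Inner where: keep sqrt away from exactly 0 so its derivative stays finite.
    safe_sq = jnp.where(is_zero, 1.0, sq)
    # Outer where: put the true distance of 0 back for those entries.
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))

def error_func(x):
    return jnp.sum(safe_cdist(x, x))

x = random.uniform(random.PRNGKey(1), (3, 3), dtype=jnp.float32)
g = grad(error_func)(x)
print(g)  # finite values instead of nan
```

Both `where` calls are needed. With only the outer one, the backward pass would still evaluate the derivative of sqrt at 0 and multiply the resulting inf by 0, which gives nan again.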

huangapple
  • Posted on June 13, 2023 at 06:56:28
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76460770.html