Understanding how JAX tracers vs. static values work
Question
I'm new to JAX and am trying to write a simple program in which, at some point, I need to call a SciPy method. I then want to take the derivative of the result.

The code doesn't run and gives me an error. The code and the error are below. I've read the JAX documentation a couple of times but couldn't figure out how to write the code correctly.
```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import scipy
key = random.PRNGKey(1)
size = 3
x = random.uniform(key, (size, size), dtype=jnp.float32)
def error_func(x):
dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
error = jnp.sum(jnp.array(dists))
return error
error_diff = grad(error_func)
print(error_func(x))
print(error_diff(x))
```

And I get the following error:

```
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
3.2158318
Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 646, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 722, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2179, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/core.py", line 598, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559 0.3129729 0.12388372]
 [0.548188 0.8851279 0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188 , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559 0.3129729 0.12388372]
 [0.548188 0.8851279 0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188 , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```
Answer 1

Score: 0
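JAX transformations only work on JAX functions, not on NumPy or SciPy functions. This is discussed briefly at the link shown in the error message above. Under `grad`, the argument `x` is replaced by a tracer. Your traceback shows `cdist` calling `np.asarray(XA)` on that tracer, and that call is what raises `TracerArrayConversionError`. Calling `error_func(x)` directly succeeds (it prints `3.2158318`) because in that case `x` is still a concrete array. If you want to use `grad` and other JAX transformations, you need to write your logic using JAX operations, not operations from `numpy`, `scipy`, or other non-JAX-compatible libraries.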
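You can reproduce this difference on a small scale. In the following minimal sketch, the function names are hypothetical and only `jax` and `numpy` are assumed to be installed. Both functions compute the same sum of squares, and both work on a concrete array. Under `jax.grad`, only the `jnp` version can be traced. The NumPy version fails at `np.asarray`, in the same way `scipy.spatial.distance.cdist` fails in the question's traceback.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sum_sq_numpy(x):
    # np.asarray calls x.__array__(); under grad, x is a tracer and refuses
    return np.sum(np.asarray(x) ** 2)

def sum_sq_jax(x):
    # Only jnp operations, so JAX can trace and differentiate this
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)

# Outside any transformation, x is a concrete array, so both versions work
value_numpy = sum_sq_numpy(x)
value_jax = sum_sq_jax(x)

# Under grad, x becomes a tracer: the jnp version differentiates fine (2 * x)
grad_jax = jax.grad(sum_sq_jax)(x)

# ...while the NumPy version raises the same error as in the question
try:
    jax.grad(sum_sq_numpy)(x)
    numpy_grad_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_grad_failed = True

print(value_numpy, value_jax)
print(grad_jax)
print(numpy_grad_failed)
```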
JAX does not currently include any wrappers of `scipy.spatial.distance` (though some are in progress, see #16147), so the best option would be to write the code yourself. Fortunately, `cdist` is pretty straightforward:
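```python
import jax.numpy as jnp
from jax import grad, random

def cdist(x, y, metric='euclidean'):
    assert x.ndim == y.ndim == 2
    if metric != 'euclidean':
        raise NotImplementedError(f"{metric=}")
    # Broadcast to (len(x), len(y), n_features), then sum over features
    return jnp.sqrt(jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))

def error_func(x):
    dists = cdist(x, x, metric='euclidean')
    error = jnp.sum(dists)
    return error

# Same input as in the question
key = random.PRNGKey(1)
x = random.uniform(key, (3, 3), dtype=jnp.float32)

error_diff = grad(error_func)
print(error_func(x))
# 3.2158318
print(error_diff(x))
# [[nan nan nan]
#  [nan nan nan]
#  [nan nan nan]]
```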
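You'll notice that the gradient is `nan` everywhere. This is the expected result. The diagonal of `cdist(x, x)` is zero because it holds each row's distance to itself, and `grad(jnp.sqrt)(0.0)` diverges (returns infinity). The chain rule multiplies that infinity by the zero derivative of the squared difference, and `0.0 * inf` is by definition `nan`. Every row of `x` appears in one of these diagonal terms, so the `nan` reaches every entry of the gradient.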
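The answer stops at explaining the `nan`. The following is a sketch only and is not part of the original answer. It first checks the two facts the explanation relies on. It then applies a commonly used "double `jnp.where`" pattern: zero squared distances are swapped for a placeholder before `jnp.sqrt`, and the result is masked back to zero afterwards. The helper name `cdist_safe` is hypothetical, and the input reuses the question's `PRNGKey(1)` setup.

```python
import jax.numpy as jnp
from jax import grad, random

# The two facts behind the nan: sqrt has an infinite slope at 0,
# and multiplying that infinity by zero gives nan
print(grad(jnp.sqrt)(0.0))  # inf
print(0.0 * jnp.inf)        # nan

def cdist_safe(x, y):
    """Euclidean cdist with finite gradients at zero distance (hypothetical helper)."""
    # Squared distances, shape (len(x), len(y))
    sq = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    is_zero = sq == 0.0
    # Inner where: give sqrt a placeholder of 1.0 where sq is zero, so its
    # local derivative there is 0.5 instead of infinity
    safe_sq = jnp.where(is_zero, 1.0, sq)
    # Outer where: return 0.0 at those positions. The sqrt branch then gets a
    # zero cotangent there, and 0 * 0.5 == 0, so no nan is produced
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))

def error_func(x):
    return jnp.sum(cdist_safe(x, x))

# Same input as in the question
key = random.PRNGKey(1)
x = random.uniform(key, (3, 3), dtype=jnp.float32)

print(error_func(x))        # unchanged: the diagonal contributes 0 either way
print(grad(error_func)(x))  # finite values instead of nan
```

The inner `jnp.where` is the important part. If you only mask the output, `jnp.sqrt` still sees the zeros. Its infinite derivative is then multiplied by the zero cotangent from the masked-out branch, and the `nan` comes back.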