Understanding how JAX's tracers vs. static values work

Question

I'm new to JAX and am trying to write some simple code in which, at some point, I need to call a SciPy method and then take the derivative of the result.

The code doesn't run and gives me an error. The code and the error are shown below. I have read the JAX documentation a couple of times but couldn't figure out how to write the code correctly.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
import scipy

key = random.PRNGKey(1)
size = 3
x = random.uniform(key, (size, size), dtype=jnp.float32)

def error_func(x):
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
    error = jnp.sum(jnp.array(dists))
    return error

error_diff = grad(error_func)
print(error_func(x))
print(error_diff(x))
```

And I get the following error:

```
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
3.2158318
Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 646, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 722, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/api.py", line 2179, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
  File "/home/sattarian/.local/lib/python3.9/site-packages/jax/_src/core.py", line 598, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559 0.3129729 0.12388372]
 [0.548188 0.8851279 0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188 , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 26, in <module>
    print(error_diff(x))
  File "/mnt/d/OneDrive - UW-Madison/opttest/jaxtest6.py", line 15, in error_func
    dists = scipy.spatial.distance.cdist(x, x, metric='euclidean')
  File "/home/sattarian/.local/lib/python3.9/site-packages/scipy/spatial/distance.py", line 2909, in cdist
    XA = np.asarray(XA)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[0.7551559 0.3129729 0.12388372]
 [0.548188 0.8851279 0.30576992]
 [0.82008433 0.95633745 0.3566252 ]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[0.7551559 , 0.3129729 , 0.12388372],
       [0.548188 , 0.8851279 , 0.30576992],
       [0.82008433, 0.95633745, 0.3566252 ]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,3]), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
```

Answer 1

Score: 0

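JAX transformations only work on JAX functions, not on NumPy or SciPy functions (this is discussed briefly at the link shown in the error message above). Under grad, your function is called with a tracer object rather than a concrete array, and SciPy's cdist fails as soon as it calls np.asarray on that tracer. If you want to use grad and other JAX transformations, you need to write your logic using JAX operations, not operations from NumPy, SciPy, or other libraries that are not JAX-compatible.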
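
To make the tracer behaviour concrete, here is a minimal sketch (not part of the original answer). The helper names `jax_only` and `calls_numpy` are made up for illustration. The first function uses only `jax.numpy` and differentiates fine; the second forces a NumPy conversion and is expected to raise the same `TracerArrayConversionError` seen in the question.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_types = []  # records the type of the argument each time jax_only runs

def jax_only(x):
    # Hypothetical helper: uses only jax.numpy operations.
    seen_types.append(type(x).__name__)
    return jnp.sum(x ** 2)

def calls_numpy(x):
    # Hypothetical helper: np.asarray needs concrete values,
    # which a tracer cannot supply during differentiation.
    return jnp.sum(jnp.asarray(np.asarray(x)) ** 2)

x = jnp.arange(3.0)

jax_only(x)                      # plain call: x is a concrete JAX array
grad_ok = jax.grad(jax_only)(x)  # under grad: x is replaced by a tracer

try:
    jax.grad(calls_numpy)(x)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

print(seen_types)    # two different type names: concrete array, then tracer
print(grad_ok)       # gradient of sum(x**2), i.e. 2 * x
print(numpy_failed)
```

The exact tracer class name printed depends on the JAX version, so the sketch only relies on it differing from the concrete array type.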

JAX does not currently include any wrappers of scipy.spatial.distance (though there are some in progress, see #16147), so the best option would be to write the code yourself. Fortunately, cdist is pretty straightforward:

```python
import jax.numpy as jnp
from jax import grad

def cdist(x, y, metric='euclidean'):
    assert x.ndim == y.ndim == 2
    if metric != 'euclidean':
        raise NotImplementedError(f"{metric=}")
    # Broadcast to shape (n, m, d), then reduce over the feature axis.
    return jnp.sqrt(jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))

def error_func(x):
    dists = cdist(x, x, metric='euclidean')
    error = jnp.sum(dists)
    return error

error_diff = grad(error_func)
# x is the array defined in the question.
print(error_func(x))
# 3.2158318
print(error_diff(x))
# [[nan nan nan]
#  [nan nan nan]
#  [nan nan nan]]
```

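You'll notice that the gradient is nan everywhere. This is the expected result: the diagonal of cdist(x, x) is exactly zero, grad(jnp.sqrt)(0.0) diverges (returns infinity), and the chain rule multiplies that infinity by the zero derivative of the squared difference. By definition, 0.0 * inf is nan, and because every row of x contributes to a diagonal entry, the nan reaches every element of the gradient.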
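
If finite gradients are needed, one common workaround is to keep the square root away from exact zeros with a pair of `jnp.where` calls. The sketch below is an assumption-laden illustration, not part of the original answer: `safe_cdist` is a hypothetical name, and it simply defines the gradient of a zero distance to be zero.

```python
import jax
import jax.numpy as jnp

def safe_cdist(x, y):
    # Hypothetical variant of cdist that avoids sqrt(0) under grad.
    sq = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    is_zero = sq == 0.0
    # Replace zeros with a dummy value before sqrt so its derivative is finite.
    safe_sq = jnp.where(is_zero, 1.0, sq)
    # Mask the dummy entries back to zero in the output.
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_sq))

def error_func(x):
    return jnp.sum(safe_cdist(x, x))

key = jax.random.PRNGKey(1)
x = jax.random.uniform(key, (3, 3), dtype=jnp.float32)

sqrt_grad_at_zero = jax.grad(jnp.sqrt)(0.0)  # the infinite derivative behind the nan
g = jax.grad(error_func)(x)

print(sqrt_grad_at_zero)
print(g)
```

Both `where` calls matter: masking only the output would still evaluate the derivative of `sqrt` at zero and multiply it by a zero cotangent, which reproduces the nan.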

huangapple
  • Posted on June 13, 2023, 06:56:28
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76460770.html