Jax: generating random numbers under **JIT**

Question
I have a setup where I need to generate some random number that is consumed by `vmap` and then by `lax.scan` later on:

```py
def generate_random(key: Array, upper_bound: int, lower_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....
```
But attempting to convert `num` from a `jax.Array` to an `int32` causes a `ConcretizationTypeError`.
This can be reproduced through this minimal example:

```py
@jax.jit
def t():
    return jnp.zeros((1,)).item().astype(int)

o = t()
o
```
JIT requires that all operations work on JAX types. But `vmap` uses JIT implicitly, and I would prefer to keep it for performance reasons.
My Attempt
This was my hacky attempt:
```py
@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
    key, subkey = jax.random.split(key)
    random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
    return random_number.astype(int)

def react_forward(key: Array, input: Array) -> Array:
    k = get_rand_num(key, 1, MAX_ITERS)
    # forward pass the model without tracking grads
    intermediate_array = jax.lax.stop_gradient(model(input, k))  # THIS LINE ERRORS OUT
    ...
    return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
```
This involves creating `batch_size` (`a.shape[0]`) subkeys, one per batch element in `vmap`, and thereby getting random numbers. But it doesn't work, because `k` is being cast from `jax.Array -> int`.
But making these changes:

```diff
- k = get_rand_num(key, 1, MAX_ITERS)
+ k = 5  # any hardcoded int
```

Works perfectly. Clearly, the sampling is causing the problem here...
Clarifications
To not make this into an X-Y problem I'll clearly define what I want precisely:
I'm implementing a version of stochastic depth; basically, my `model`'s forward pass can accept a `depth: int` at runtime, which is the length of a `scan` run internally - specifically, the `xs = jnp.arange(depth)` for the `scan`.

I want my architecture to flexibly adapt to different depths. Therefore, at training time, I need a way to produce pseudorandom numbers that would equal the `depth`.
So I require a function that, on every call to it (such is the case in `vmap`), returns a different number, sampled within some bound: `depth ∈ [1, max_iters]`.

The function has to be `jit`-able (an implicit requirement of `vmap`) and has to produce an `int` - as that's what's fed into `jnp.arange` later. (Workarounds that directly get `generate_random` to produce an `Array` of `jnp.arange(depth)` without converting to a static value might be possible.)

> (I have no idea honestly how others do this; this seems like a common enough want, especially if one's dealing with sampling during train time)
I've attached the error traceback generated by my "hacky solution attempt" if that helps...
```
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-32-d6ff062f5054> in <cell line: 16>()
     14 a = jnp.zeros((300, 32)).astype(int)
     15 rndm_keys = jax.random.split(key, a.shape[0])
---> 16 jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

    [... skipping hidden 3 frame]

4 frames
<ipython-input-32-d6ff062f5054> in react_forward(key, input)
      8   k = get_rand_num(key, 1, MAX_ITERS)
      9   # forward pass the model without tracking grads
---> 10   intermediate_array = jax.lax.stop_gradient(model(input, iters_to_do=k))
     11   # n-k passes, but track the gradient this time
     12   return model(input, MAX_ITERS - k, intermediate_array)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in __call__(self, input, iters_to_do, prev_thought)
     71     #interim_thought = self.main_block(interim_thought)
     72
---> 73     interim_thought = self.iterate_for_steps(interim_thought, iters_to_do, x)
     74
     75     return self.out_head(interim_thought)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in iterate_for_steps(self, interim_thought, iters_to_do, x)
     56       return self.main_block(interim_thought), None
     57
---> 58     final_interim_thought, _ = jax.lax.scan(loop_body, interim_thought, jnp.arange(iters_to_do))
     59     return final_interim_thought
     60

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
   2286   util.check_arraylike("arange", start)
   2287   if stop is None and step is None:
-> 2288     start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
   2289   else:
   2290     start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in concrete_or_error(force, val, context)
   1379       return force(val.aval.val)
   1380     else:
-> 1381       raise ConcretizationTypeError(val, context)
   1382   else:
   1383     return force(val)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
This BatchTracer with object id 140406974192336 was created on line:
  <ipython-input-32-d6ff062f5054>:8 (react_forward)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
Really appreciate you helping me out here. Cheers!
Answer 1 (score: 1)
The random numbers you are generating are traced; in other words, they are known only at runtime (your mental model of "runtime" should be "operations running on the XLA device").

Python integers, on the other hand, are static; in other words, they must be defined at compile time (your mental model of "compile time" should be "operations that happen before runtime values are known").

With this framing, it's clear that you cannot convert a traced value to a static Python integer within `jit`, `vmap`, or any other JAX transform, because static values must be known before the traced values are determined. Where this comes up in your minimal example is in the call to `.item()`, which attempts to cast a (traced) JAX array to a (static) Python scalar.
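To make the traced-vs-static distinction concrete, here is a small sketch (mine, not from the original answer) of the same conversion succeeding eagerly and failing under `jit`:

```python
import jax
import jax.numpy as jnp

def concretize(x):
    # int() demands a concrete Python scalar from x.
    return int(x)

# Outside any transform the value is concrete, so this works.
print(concretize(jnp.int32(3)))  # 3

# Under jit, x is an abstract tracer, so the conversion raises a
# tracer-conversion error (a relative of ConcretizationTypeError).
try:
    jax.jit(concretize)(jnp.int32(3))
except jax.errors.JAXTypeError:
    print("cannot turn a tracer into a Python int")
```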
You can fix this by avoiding this cast. Here is a new version of your function that returns a zero-dimensional integer array, which is how JAX encodes an integer scalar at runtime:
```py
@jax.jit
def t():
    return jnp.zeros((1,)).astype(int).reshape(())
```
That said, the fact that you are so concerned with creating an integer from an array makes me think that your `model` function requires its second argument to be static, and unfortunately the above won't help you in that case. For the reasons discussed above, it is impossible to convert a traced value within a JAX transformation into a static value.
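As a hedged illustration (not part of the original answer): a traced 0-d integer flows through `vmap` without trouble, so long as it is only used as a runtime value and never as a shape or static argument. The sketch below is a variant of the poster's `get_rand_num` with the static conversion removed:

```python
import jax
import jax.numpy as jnp

def get_rand_num(key, lower, upper):
    # Returns a traced 0-d int32 array, not a Python int.
    return jax.random.randint(key, shape=(), minval=lower, maxval=upper)

def forward(key, x):
    k = get_rand_num(key, 1, 5)
    # Fine: k participates as a runtime value, not as a shape.
    return x * k

keys = jax.random.split(jax.random.PRNGKey(0), 8)
xs = jnp.ones((8, 3))
out = jax.vmap(forward)(keys, xs)
print(out.shape)  # (8, 3)
```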
Edit: the issue you're running into is the fact that JAX arrays must have static shapes. In your code, you're generating random integers at runtime and attempting to pass them to `jnp.arange`, which would result in a dynamically-shaped array. It is not possible to execute such code within transformations like `jit` or `vmap`.

Fixing this usually involves writing your code in a way that supports the dynamic computation size (for example, creating a padded array of a maximum size, or using `jax.lax.fori_loop` in place of `jax.lax.scan`).