Jax: generating random numbers under **JIT**

Question

I have a setup where I need to generate a random number that is consumed by `vmap` and later by `lax.scan`:

```python
def generate_random(key: Array, lower_bound: int, upper_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....
```

But converting `num` from a `jax.Array` to a Python `int` raises a `ConcretizationTypeError`.

This can be reproduced with the following minimal example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def t():
  return jnp.zeros((1,)).item().astype(int)

o = t()
o
```

Under JIT, every operation has to act on JAX types rather than on concrete Python values.

But `vmap` traces the function much like `jit` does (the traceback below shows a `BatchTracer`), and I would prefer to keep it for performance reasons.


My Attempt

This was my hacky attempt:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import Array

@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
  key, subkey = jax.random.split(key)
  random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
  return random_number.astype(int)

def react_forward(key: Array, input: Array) -> Array:
  k = get_rand_num(key, 1, MAX_ITERS)
  # forward pass the model without tracking grads
  intermediate_array = jax.lax.stop_gradient(model(input, k)) # THIS LINE ERRORS OUT
  ...
  return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
```

This creates one subkey per example (`a.shape[0]` of them, i.e. the batch size), so that every example in the `vmap`ped batch gets its own random number.

But it doesn't work, because `k` has to be converted from a `jax.Array` to an `int`.

But making these changes:

```diff
- k = get_rand_num(key, 1, MAX_ITERS)
+ k = 5  # any hardcoded int
```

works perfectly, so the sampling is clearly what causes the problem here.
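Why the hardcoded `k` works can be seen in a minimal sketch of the scan-over-depth setup, assuming a hypothetical `block` function in place of the model's main block. A Python `int` depth is known at trace time, so `jnp.arange(depth)` has a static shape and the scan compiles, even when the function is called under `vmap`. Declaring `depth` in `static_argnums` makes that explicit, at the cost of one compilation per distinct depth value.

```python
from functools import partial

import jax
import jax.numpy as jnp

def block(h: jax.Array) -> jax.Array:
    # Hypothetical stand-in for the model's main block.
    return jnp.tanh(h) + 1.0

@partial(jax.jit, static_argnums=1)
def iterate_for_steps(h: jax.Array, depth: int) -> jax.Array:
    def loop_body(h, _):
        return block(h), None
    # depth is a Python int, so jnp.arange(depth) has a shape known at trace time.
    h, _ = jax.lax.scan(loop_body, h, jnp.arange(depth))
    return h

x = jnp.zeros((300, 32))
# vmap over the batch only; the hardcoded depth is closed over as a Python int.
out = jax.vmap(lambda h: iterate_for_steps(h, 5))(x)
print(out.shape)  # (300, 32)
```

The same call fails as soon as `depth` comes from a traced computation such as `jax.random.randint`, because a traced value can never become a static argument.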


Clarifications

To avoid turning this into an XY problem, here is precisely what I want:

I'm implementing a version of stochastic depth: my model's forward pass accepts a `depth: int` at runtime, which sets the length of a scan run internally (specifically, the scan uses `xs = jnp.arange(depth)`).

I want my architecture to adapt flexibly to different depths, so at training time I need a way to sample the depth pseudorandomly.

So I need a function that returns a different number on every call (as is the case under `vmap`), sampled within a bound: `depth ∈ [1, max_iters]`.

The function has to be jittable (an implicit requirement of `vmap`) and has to produce an `int`, since that is what gets fed into `jnp.arange` later. (Workarounds in which `generate_random` directly produces the array `jnp.arange(depth)`, without converting to a static value, might also be possible.)
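The sampling half of this requirement is easy to meet. A minimal sketch, assuming a hypothetical `MAX_ITERS = 8` and a 300-example batch like the one in the attempt: with one key per example, `jax.random.randint` under `jit` and `vmap` yields one `int32` scalar per example. Note that `maxval` is exclusive, so the attempt's `maxval=MAX_ITERS` never returns `MAX_ITERS` itself; `MAX_ITERS + 1` gives the inclusive range `[1, MAX_ITERS]`. What cannot be done is turning those traced scalars into Python `int`s.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 8  # hypothetical upper bound on the depth

def sample_depth(key: jax.Array) -> jax.Array:
    # Returns a zero-dimensional int32 array (a traced scalar), not a Python int.
    # maxval is exclusive, so MAX_ITERS + 1 makes depth = MAX_ITERS reachable.
    return jax.random.randint(key, shape=(), minval=1, maxval=MAX_ITERS + 1)

key = jax.random.PRNGKey(0)
rndm_keys = jax.random.split(key, 300)  # one key per example in the batch

depths = jax.jit(jax.vmap(sample_depth))(rndm_keys)
print(depths.shape, depths.dtype)  # (300,) int32 under JAX's default 32-bit mode
```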

> (Honestly, I have no idea how others do this; it seems like a common enough need, especially when sampling during training.)

Here is the traceback from my "hacky attempt", in case it helps:

```text
---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
<ipython-input-32-d6ff062f5054> in <cell line: 16>()
     14 a = jnp.zeros((300, 32)).astype(int)
     15 rndm_keys = jax.random.split(key, a.shape[0])
---> 16 jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

    [... skipping hidden 3 frame]

<ipython-input-32-d6ff062f5054> in react_forward(key, input)
      8   k = get_rand_num(key, 1, MAX_ITERS)
      9   # forward pass the model without tracking grads
---> 10   intermediate_array = jax.lax.stop_gradient(model(input, iters_to_do=k))
     11   # n-k passes, but track the gradient this time
     12   return model(input, MAX_ITERS - k, intermediate_array)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in __call__(self, input, iters_to_do, prev_thought)
     71       #interim_thought = self.main_block(interim_thought)
     72 
---> 73     interim_thought = self.iterate_for_steps(interim_thought, iters_to_do, x)
     74 
     75     return self.out_head(interim_thought)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in iterate_for_steps(self, interim_thought, iters_to_do, x)
     56         return self.main_block(interim_thought), None
     57 
---> 58     final_interim_thought, _ = jax.lax.scan(loop_body, interim_thought, jnp.arange(iters_to_do))
     59     return final_interim_thought
     60 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
   2286     util.check_arraylike("arange", start)
   2287     if stop is None and step is None:
-> 2288       start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
   2289     else:
   2290       start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in concrete_or_error(force, val, context)
   1379       return force(val.aval.val)
   1380     else:
-> 1381       raise ConcretizationTypeError(val, context)
   1382   else:
   1383     return force(val)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
This BatchTracer with object id 140406974192336 was created on line:
  <ipython-input-32-d6ff062f5054>:8 (react_forward)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

Really appreciate you helping me out here. Cheers!

Answer 1

Score: 1

The random numbers you are generating are traced; in other words, they are known only at runtime (your mental model of "runtime" should be "operations running on the XLA device").

Python integers, on the other hand, are static; in other words, they must be defined at compile time (your mental model of "compile time" should be "operations that happen before runtime values are known").

With this framing, it's clear that you cannot convert a traced value to a static Python integer within jit, vmap or any other JAX transform, because static values must be known before the traced values are determined. Where this comes up in your minimal example is in the call to .item(), which attempts to cast a (traced) JAX array to a (static) Python scalar.
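The difference can be seen directly. In this minimal sketch (the function name is illustrative), converting an array to a Python `int` works on a concrete array outside any transform, but the same conversion inside `jit` raises `ConcretizationTypeError`, because at trace time `x` is an abstract tracer that carries only a shape and dtype. Under `vmap` the argument is a `BatchTracer` and fails the same way, as the traceback in the question shows.

```python
import jax
import jax.numpy as jnp

def to_python_int(x):
    return int(x)  # needs the concrete value of x

# Outside any transform, x is a concrete array: this works.
print(to_python_int(jnp.int32(3)))  # 3

# Inside jit, x is a tracer with shape int32[] but no value yet.
try:
    jax.jit(to_python_int)(jnp.int32(3))
except jax.errors.ConcretizationTypeError:
    print("ConcretizationTypeError")
```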

You can fix this by avoiding this cast. Here is a new version of your function that returns a zero-dimensional integer array, which is how JAX encodes an integer scalar at runtime:

```python
import jax
import jax.numpy as jnp

@jax.jit
def t():
  return jnp.zeros((1,)).astype(int).reshape(())
```

That said, the fact that you are so concerned with creating an integer from an array makes me think that your model function requires its second argument to be static, and unfortunately the above won't help you in that case. For the reasons discussed above it is impossible to convert a traced value within a JAX transformation into a static value.


Edit: the issue you're running into is the fact that JAX arrays must have static shapes. In your code, you're generating random integers at runtime, and attempting to pass them to jnp.arange, which would result in a dynamically-shaped array. It is not possible to execute such code within transformations like jit or vmap.
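A minimal sketch of that constraint, assuming a hypothetical static bound `MAX_ITERS = 8`: `jnp.arange(k)` with a traced `k` is rejected because its output shape would depend on the value of `k`, whereas `jnp.arange(MAX_ITERS)` has a static shape and can be compared with the traced `k` to get a mask of the first `k` steps.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 8  # hypothetical static upper bound

@jax.jit
def dynamic_arange(k):
    # The output shape would depend on the value of k: not allowed under jit.
    return jnp.arange(k)

@jax.jit
def masked_steps(k):
    steps = jnp.arange(MAX_ITERS)  # static shape (MAX_ITERS,)
    return steps, steps < k        # True for the first k steps only

try:
    dynamic_arange(jnp.int32(3))
except jax.errors.ConcretizationTypeError:
    print("dynamic shape rejected")

steps, active = masked_steps(jnp.int32(3))
print(active)  # True for steps 0, 1, 2; False afterwards
```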

Fixing this usually involves writing your code in a way that supports the dynamic computation size (for example, creating a padded array of a maximum size, or using `jax.lax.fori_loop` in place of `jax.lax.scan`).
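Both options can be sketched as follows, assuming a hypothetical `block` function standing in for the model's main block and a hypothetical `MAX_ITERS = 8`. The padded version always scans `MAX_ITERS` steps and uses `jnp.where` to leave the state untouched once `step >= depth`; the `fori_loop` version takes the traced `depth` directly as its upper bound, which JAX lowers to a `while_loop`. Both run under `jit` and `vmap` with a different random depth per example.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 8  # hypothetical static maximum depth

def block(h):
    # Hypothetical stand-in for the model's main block.
    return jnp.tanh(h) + 1.0

def iterate_masked(h, depth):
    """Padded approach: always scan MAX_ITERS steps, freeze h after `depth`."""
    def body(h, step):
        return jnp.where(step < depth, block(h), h), None
    h, _ = jax.lax.scan(body, h, jnp.arange(MAX_ITERS))
    return h

def iterate_fori(h, depth):
    """fori_loop with a traced upper bound (lowered to a while_loop)."""
    return jax.lax.fori_loop(0, depth, lambda i, h: block(h), h)

def forward(key, x):
    # depth is a traced int32[] scalar in [1, MAX_ITERS] (maxval is exclusive).
    depth = jax.random.randint(key, (), 1, MAX_ITERS + 1)
    return iterate_masked(x, depth), iterate_fori(x, depth), depth

keys = jax.random.split(jax.random.PRNGKey(0), 4)
xs = jnp.zeros((4, 3))
out_masked, out_fori, depths = jax.jit(jax.vmap(forward))(keys, xs)

# The masked version stays reverse-mode differentiable with a traced depth.
grad = jax.grad(lambda x: iterate_masked(x, jnp.int32(3)).sum())(jnp.zeros(3))
```

The trade-off: the masked scan always pays for `MAX_ITERS` steps but keeps a fixed trip count, so `jax.grad` works through it. The `fori_loop` with a traced bound avoids some of that work, though under `vmap` it still runs until the largest depth in the batch, and as a `while_loop` it supports forward-mode but not reverse-mode differentiation. The question's two-phase pass (`k` steps under `stop_gradient`, then `MAX_ITERS - k` tracked steps) could plausibly be written in the masked style too, by selecting between `jax.lax.stop_gradient(block(h))` and `block(h)` with `jnp.where(step < k, ...)`; that variant is not shown here.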

huangapple
  • Posted on 5 August 2023 at 02:56:45
  • When reposting, please keep this link: https://go.coder-hub.com/76838537.html