How can I test if a jitted Jax function creates a new tensor or a view?

Question

I have some basic code like this:

import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]

Here are my test tensors:

key = jax.random.PRNGKey(758493)
key1, key2, key3 = jax.random.split(key, 3)  # use a separate key for each random draw
in1 = jax.random.uniform(key1, shape=(15, 5, 3))
in2 = jax.random.uniform(key2, shape=(10, 5, 3))
indices = jax.random.choice(key3, 25, (25,), replace=False)

And here is the Jaxpr of the function:

{ lambda ; a:i32[25] b:f32[15,5,3] c:f32[10,5,3]. let
    d:f32[25,5,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[25] f:f32[15,5,3] g:f32[10,5,3]. let
          h:f32[15,5,3] = xla_call[
            call_jaxpr={ lambda ; i:f32[15,5,3]. let  in (i,) }
            name=atleast_1d
          ] f
          j:f32[10,5,3] = xla_call[
            call_jaxpr={ lambda ; k:f32[10,5,3]. let  in (k,) }
            name=atleast_1d
          ] g
          l:f32[25,5,3] = concatenate[dimension=0] h j
          m:bool[25] = lt e 0
          n:i32[25] = add e 25
          o:i32[25] = select_n m e n
          p:i32[25,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(25, 1)
          ] o
          q:f32[25,5,3] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 5, 3)
            unique_indices=False
          ] l p
        in (q,) }
      name=concat_permute
    ] a b c
  in (d,) }

It seems to create a new tensor from my permutation array, but I'm not sure. Is there a clearer way to tell whether this operation creates a new tensor or not?

I tried jax.make_jaxpr and looked at the result, but I'm still not sure how to interpret it.
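
One way to look below the jaxpr is to inspect what JAX hands to XLA and what XLA actually compiles. This is a sketch assuming a reasonably recent JAX version, where a jitted function has .lower(...) and the result has .as_text() and .compile(); it reuses the question's function and shapes, with the random keys split. Keep in mind that neither jaxprs nor HLO have a notion of "view": every op produces a new value, and whether memory is reused is decided later by XLA's buffer assignment. What the text does show is which ops run, here a concatenate followed by a gather.

```python
import jax
import jax.numpy as jnp


@jax.jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]


key_a, key_b, key_idx = jax.random.split(jax.random.PRNGKey(758493), 3)
in1 = jax.random.uniform(key_a, shape=(15, 5, 3))
in2 = jax.random.uniform(key_b, shape=(10, 5, 3))
indices = jax.random.choice(key_idx, 25, (25,), replace=False)

# Lower the jitted function for these argument shapes/dtypes, then compile it.
lowered = concat_permute.lower(indices, in1, in2)
compiled = lowered.compile()

# Module handed to XLA, before XLA's own optimization passes.
stablehlo_text = lowered.as_text()
# HLO after XLA's optimization passes (fusion etc.) for the current backend.
optimized_hlo_text = compiled.as_text()

print(stablehlo_text)
print(optimized_hlo_text)
```

In the optimized HLO, the concatenate and the gather may end up inside a single fusion, in which case the concatenated intermediate is never materialized on its own.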

Answer 1

Score: 0

The short answer is no: the output of your function will not share memory with the array allocated for tensor.
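
A quick empirical check, assuming single-device arrays (for example on CPU), is to compare the device buffer address of the output with those of the inputs using jax.Array.unsafe_buffer_pointer(). The intermediate tensor only exists inside the compiled computation, so the observable question is whether the output aliases any input buffer; without buffer donation it should not, and the snippet should print False.

```python
import jax
import jax.numpy as jnp


@jax.jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]


key_a, key_b, key_idx = jax.random.split(jax.random.PRNGKey(758493), 3)
in1 = jax.random.uniform(key_a, shape=(15, 5, 3))
in2 = jax.random.uniform(key_b, shape=(10, 5, 3))
indices = jax.random.choice(key_idx, 25, (25,), replace=False)

out = concat_permute(indices, in1, in2)
out.block_until_ready()

# Raw device addresses of the underlying buffers (single-device arrays only).
input_ptrs = {x.unsafe_buffer_pointer() for x in (indices, in1, in2)}
out_ptr = out.unsafe_buffer_pointer()

print("output aliases an input buffer:", out_ptr in input_ptrs)
```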

In XLA, an array is represented by a uniformly strided buffer. When you select values at arbitrary (here, random) positions in an array, the result cannot in general be expressed as a uniformly strided view over the input buffer, so it has to be written into a new buffer.
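
NumPy makes the same distinction directly visible, so here is an analogy in NumPy rather than JAX (JAX arrays do not expose views to the user). Basic slicing with a step keeps a uniform stride and returns a view of the same buffer, while indexing with an integer array (a permutation, as in the question) always returns a copy.

```python
import numpy as np

# Same shape as the concatenated tensor in the question, float32.
base = np.arange(25 * 5 * 3, dtype=np.float32).reshape(25, 5, 3)

# Basic slicing: every second row is still a uniform stride over the same
# buffer, so NumPy returns a view.
every_other = base[::2]

# Integer-array ("advanced") indexing: NumPy always copies, because an
# arbitrary list of row indices cannot in general be described by one stride.
rng = np.random.default_rng(0)
perm = rng.permutation(25)
permuted = base[perm]

print(np.shares_memory(base, every_other))  # True
print(np.shares_memory(base, permuted))     # False
print(every_other.strides, base.strides)    # (120, 12, 4) (60, 12, 4)
```

The view only changes the row stride from 60 bytes (one row of 5 x 3 float32 values) to 120 bytes; no single stride can describe the rows of base in a random order.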

huangapple
  • Posted on 2023-06-15 14:03:55
  • Original link: https://go.coder-hub.com/76479539.html