How can I test whether a jitted JAX function creates a new tensor or a view?
Question
I have some basic code like this:
import jax
import jax.numpy as jnp
from jax import jit

@jit
def concat_permute(indices, in1, in2):
    tensor = jnp.concatenate([jnp.atleast_1d(in1), jnp.atleast_1d(in2)])
    return tensor[indices]
Here are my test tensors:
key = jax.random.PRNGKey(758493)
in1 = jax.random.uniform(key, shape=(15, 5, 3))
in2 = jax.random.uniform(key, shape=(10, 5, 3))
indices = jax.random.choice(key, 25, (25,), replace=False)
And here is the Jaxpr of the function:
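It was presumably printed like this (a sketch of the call; I mention trying jax.make_jaxpr below):

print(jax.make_jaxpr(concat_permute)(indices, in1, in2))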
{ lambda ; a:i32[25] b:f32[15,5,3] c:f32[10,5,3]. let
    d:f32[25,5,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[25] f:f32[15,5,3] g:f32[10,5,3]. let
          h:f32[15,5,3] = xla_call[
              call_jaxpr={ lambda ; i:f32[15,5,3]. let in (i,) }
              name=atleast_1d
            ] f
          j:f32[10,5,3] = xla_call[
              call_jaxpr={ lambda ; k:f32[10,5,3]. let in (k,) }
              name=atleast_1d
            ] g
          l:f32[25,5,3] = concatenate[dimension=0] h j
          m:bool[25] = lt e 0
          n:i32[25] = add e 25
          o:i32[25] = select_n m e n
          p:i32[25,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(25, 1)
          ] o
          q:f32[25,5,3] = gather[
            dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))
            fill_value=None
            indices_are_sorted=False
            mode=GatherScatterMode.PROMISE_IN_BOUNDS
            slice_sizes=(1, 5, 3)
            unique_indices=False
          ] l p
        in (q,) }
      name=concat_permute
    ] a b c
  in (d,) }
It seems to create a new tensor using my permutation array, but I'm not sure. Is there a clearer way to see whether this operation creates a new tensor or returns a view?
I tried jax.make_jaxpr and inspected the results, but I'm still not sure.
Answer 1
Score: 0
The short answer is: no, the output of your function will not share memory with the array allocated for tensor.
In XLA, an array is represented by a uniformly-strided buffer, and when you select random values from an array, the result cannot in general be constructed via uniform striding over a view of the input buffer.
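If you want to check this empirically, one option is to compare device buffer addresses. This is a sketch, not part of the original answer: it assumes a recent JAX version on a single device, where arrays expose unsafe_buffer_pointer() (it raises for sharded arrays), and the exact aliasing behavior may vary between versions and backends.

import jax.numpy as jnp

x = jnp.arange(12.0)
y = jnp.asarray(x)             # no copy is made here, so this may alias x's buffer
z = x[jnp.array([2, 0, 1])]    # a gather, which materializes a new buffer

# Compare the device addresses of the underlying buffers.
print(x.unsafe_buffer_pointer() == y.unsafe_buffer_pointer())  # expected True
print(x.unsafe_buffer_pointer() == z.unsafe_buffer_pointer())  # expected False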
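To go one level below the jaxpr, you can also inspect the compiled HLO, where XLA's buffer assignment (and hence any physical copies) is decided after optimization. This is a sketch using the jax.stages API of recent JAX versions; the jit-wrapped concat_permute exposes lower() directly, and method names may differ in older releases.

# Lower and compile for the concrete argument shapes, then print the optimized HLO.
compiled = concat_permute.lower(indices, in1, in2).compile()
print(compiled.as_text())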