Why can't my tf.tensor_scatter_nd_add do the same as torch's scatter_add_?
Question
new_means = tf.tensor_scatter_nd_add(new_means, indices=repeat(buckets, "n -> n d", d=dim), updates=samples)
Assume new_means.shape = [3, 4], indices.shape = [4, 4] and updates.shape = [4, 4].
The code above returns this error:
Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,4] updates: [4,4].
Even when I give the two arrays the same shape, it still returns a similar error.
The equivalent call works fine with PyTorch's scatter_add.
I don't understand why, or how to achieve the same result in TensorFlow as in PyTorch.
Can you help me, please?
I tried reading the official documentation, but some of its requirements are confusing. How can I get the same effect as scatter_add?
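For reference, PyTorch's documentation gives the dim=0 rule of scatter_add_ as `self[index[i][j]][j] += src[i][j]` for every position (i, j) of `index` (the docs show the 3-D form; this is its 2-D case). A minimal pure-Python model of that rule, using a hypothetical helper name `torch_scatter_add_dim0`, makes the target behaviour explicit and can serve as a check for any TensorFlow port:

```python
def torch_scatter_add_dim0(self_rows, index, src):
    """Hypothetical helper: pure-Python model of torch's self.scatter_add_(0, index, src).

    For every position (i, j) of index: out[index[i][j]][j] += src[i][j].
    Repeated targets accumulate, which is what distinguishes scatter_add
    from a plain scatter.
    """
    out = [list(row) for row in self_rows]  # copy instead of modifying in place
    for i, index_row in enumerate(index):
        for j, target_row in enumerate(index_row):
            out[target_row][j] += src[i][j]
    return out


# Example: output of shape [3, 4], index and src of shape [2, 4]
means = [[0] * 4 for _ in range(3)]
index = [[0, 1, 2, 0], [1, 1, 0, 2]]
src = [[1, 2, 3, 4], [10, 20, 30, 40]]
print(torch_scatter_add_dim0(means, index, src))
# [[1, 0, 30, 4], [10, 22, 0, 0], [0, 0, 3, 40]]
```

Note that row index 1 receives two contributions in column 1 (2 + 20 = 22); a TensorFlow equivalent has to sum such duplicates too.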
Answer 1
Inspired by a similar question and its solution:
https://stackoverflow.com/questions/63185202/
I solved my own problem by flattening the original data to 1-D, which satisfies the shape requirements of tf.tensor_scatter_nd_add.
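The index arithmetic behind this flattening can be checked in NumPy before writing it in TensorFlow: element (r, c) of a row-major array with C columns sits at flat position r * C + c, so the rule `self[index[i][j]][j] += src[i][j]` becomes a 1-D scatter-add at `index[i][j] * C + j`. The sketch below (hypothetical name `scatter_add_dim0_flat`) uses `np.add.at`, which, like tf.tensor_scatter_nd_add, accumulates repeated positions instead of overwriting them:

```python
import numpy as np


def scatter_add_dim0_flat(tensor, index, updates):
    """Hypothetical helper: torch-style scatter_add_ (dim=0) via a 1-D scatter-add."""
    tensor = np.asarray(tensor)
    index = np.asarray(index)
    n, m = index.shape
    num_cols = tensor.shape[1]
    # Flat row-major position of (index[i][j], j) is index[i][j] * num_cols + j
    flat_idx = index * num_cols + np.arange(m)[np.newaxis, :]
    flat = tensor.reshape(-1).copy()  # work on a 1-D copy of the output
    # np.add.at accumulates repeated positions instead of overwriting them
    np.add.at(flat, flat_idx.reshape(-1), np.asarray(updates)[:n, :m].reshape(-1))
    return flat.reshape(tensor.shape)


out = scatter_add_dim0_flat(np.zeros((3, 4)),
                            [[0, 1, 2, 0], [1, 1, 0, 2]],
                            [[1, 2, 3, 4], [10, 20, 30, 40]])
print(out.tolist())
# [[1.0, 0.0, 30.0, 4.0], [10.0, 22.0, 0.0, 0.0], [0.0, 0.0, 3.0, 40.0]]
```

The multiplier is the column count of the output, not of `index`; the two coincide in the question (both have 4 columns) but can differ in general.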
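import tensorflow as tf
from einops import repeat


def scatter_add(tensor, indices, updates):
    """
    Emulate torch's tensor.scatter_add_(0, indices, updates) for 2-D inputs.
    Because of the shape requirements of tf.tensor_scatter_nd_add, everything
    is first reshaped to one dimension.
    """
    original_tensor = tensor
    indices = tf.cast(indices, tf.int32)
    num_cols = original_tensor.shape[-1]
    # Column offset j for every position (i, j): [[0, 1, ..., m-1], ...]
    indices_add = tf.range(0, indices.shape[-1])
    indices_add = repeat(indices_add, "n -> d n", d=indices.shape[0])
    # Flat row-major position of (indices[i][j], j) in the output tensor;
    # the stride is the output's column count, not the column count of indices
    indices = indices * num_cols
    indices += indices_add
    tensor = tf.reshape(tensor, shape=[-1])
    indices = tf.reshape(indices, shape=[-1, 1])
    updates = tf.reshape(updates, shape=[-1])
    scatter = tf.tensor_scatter_nd_add(tensor, indices, updates)
    # Restore the original 2-D shape (tf.squeeze would also drop size-1 dims)
    scatter = tf.reshape(scatter, shape=tf.shape(original_tensor))
    return scatter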
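An alternative that avoids flattening is to give tf.tensor_scatter_nd_add full (row, column) coordinates. The function requires `updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:]`; with coordinate pairs of shape [n, m, 2] and a 2-D output, updates must therefore have shape [n, m], which is exactly the shape of torch's `src`. A minimal sketch, assuming eager execution and 2-D inputs (the name `scatter_add_dim0` is hypothetical):

```python
import tensorflow as tf


def scatter_add_dim0(tensor, index, updates):
    """Hypothetical helper: torch-style tensor.scatter_add(0, index, updates) for 2-D inputs.

    Instead of flattening, every element (i, j) of `index` is turned into the
    coordinate pair (index[i][j], j), giving indices of shape [n, m, 2].
    """
    tensor = tf.convert_to_tensor(tensor)
    index = tf.convert_to_tensor(index, dtype=tf.int32)
    n, m = index.shape  # static shape; assumes eager mode / known dims
    # Column coordinate j for every position (i, j): shape [n, m]
    cols = tf.tile(tf.range(m)[tf.newaxis, :], [n, 1])
    # Full coordinates (index[i][j], j): shape [n, m, 2]
    nd_indices = tf.stack([index, cols], axis=-1)
    # torch reads only the leading [n, m] block of src
    updates = tf.convert_to_tensor(updates, dtype=tensor.dtype)[:n, :m]
    # Repeated coordinates are summed, matching scatter_add_
    return tf.tensor_scatter_nd_add(tensor, nd_indices, updates)


means = tf.zeros([3, 4])
out = scatter_add_dim0(means, [[0, 1, 2, 0], [1, 1, 0, 2]],
                       [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
print(out.numpy().tolist())
# [[1.0, 0.0, 30.0, 4.0], [10.0, 22.0, 0.0, 0.0], [0.0, 0.0, 3.0, 40.0]]
```

Compared with the flattening approach, this keeps the output shape untouched and needs no einops; the cost is an extra [n, m, 2] index tensor.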