Why can't my tf.tensor_scatter_nd_add do the same as torch's scatter_add_?


Question

new_means = tf.tensor_scatter_nd_add(new_means, indices=repeat(buckets, "n -> n d", d=dim), updates=samples)

Assume new_means.shape = [3, 4], indices.shape = [4, 4] and updates.shape = [4, 4].
The code above returns this error:

Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,4] updates: [4,4].

Even when I give the two arrays the same shape, it still returns a similar error.

But the same call works fine with PyTorch's scatter_add_.
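For reference, PyTorch's `new_means.scatter_add_(0, index, samples)` adds `samples[i][j]` into `new_means[index[i][j]][j]`: the index tensor has the same shape as `samples` and only supplies the row. The sketch below reproduces that rule in NumPy for the 2-D, `dim=0` case, assuming `index` and `src` have the same shape. `torch_style_scatter_add_dim0` is a hypothetical helper name, and `np.add.at` is used because it accumulates repeated indices instead of overwriting them.

```python
import numpy as np


def torch_style_scatter_add_dim0(out, index, src):
    """Hypothetical NumPy sketch of out.scatter_add_(0, index, src) for 2-D arrays:
    out[index[i, j], j] += src[i, j] for every (i, j). Assumes index.shape == src.shape."""
    result = out.copy()
    # Column j for every element, broadcast against index's [n, d] shape
    cols = np.arange(index.shape[1])[np.newaxis, :]
    # np.add.at is unbuffered, so rows that appear several times in index accumulate
    np.add.at(result, (index, cols), src)
    return result


new_means = np.zeros((3, 4))
buckets = np.array([0, 2, 0, 1])
index = np.repeat(buckets[:, np.newaxis], 4, axis=1)  # same layout as repeat(buckets, "n -> n d", d=4)
samples = np.arange(16, dtype=float).reshape(4, 4)
result = torch_style_scatter_add_dim0(new_means, index, samples)
```

Because every column of `index` repeats the same bucket id, this amounts to adding whole rows of `samples` into the rows of `new_means` selected by `buckets`.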

I don't understand why, or how to achieve the same result in TensorFlow as in PyTorch.

Can you help me please?

I tried reading the official documentation and found some of its requirements confusing. How can I get the same effect as scatter_add_?
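The error comes from how `tf.tensor_scatter_nd_add` reads `indices`: unlike `scatter_add_`, the last axis of `indices` is a coordinate into the leading axes of `tensor`, and `updates` must have shape `indices.shape[:-1] + tensor.shape[indices.shape[-1]:]`. With `indices` of shape [4, 4], TensorFlow treats each length-4 row as a coordinate into the rank-2 `new_means`, which is why the shape check against `updates` fails. Since the repeated index only selects rows, a minimal sketch of a direct equivalent, assuming `buckets` is an integer vector of row ids, passes `buckets[:, None]` (shape [n, 1]) instead. `scatter_add_rows` is a hypothetical name; `tf.math.unsorted_segment_sum` is shown as an alternative that gives the same result.

```python
import tensorflow as tf


def scatter_add_rows(means, buckets, samples):
    """Hypothetical helper: means[buckets[i]] += samples[i] for every i,
    matching means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)."""
    # Shape [n, 1]: each row is a one-component coordinate (a row of `means`)
    indices = tf.expand_dims(buckets, axis=-1)
    # updates must have shape indices.shape[:-1] + means.shape[1:] == [n, d];
    # duplicate row ids are summed
    return tf.tensor_scatter_nd_add(means, indices, samples)


new_means = tf.zeros([3, 4])
buckets = tf.constant([0, 2, 0, 1])
samples = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])

result = scatter_add_rows(new_means, buckets, samples)
# Alternative: sum samples per bucket, then add to the running means
alt = new_means + tf.math.unsorted_segment_sum(samples, buckets, num_segments=3)
```

This avoids flattening entirely, at the cost of assuming the index really is the same bucket id repeated across every column.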

Answer 1

Score: 0

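Inspired by a similar question and its solution at https://stackoverflow.com/questions/63185202/, I solved my own problem by flattening the original data to 1-D (to satisfy the requirements of the TensorFlow function).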

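import tensorflow as tf
from einops import repeat


def scatter_add(tensor, indices, updates):
    """
    Equivalent of torch's tensor.scatter_add_(0, indices, updates) for 2-D inputs:
    tensor[indices[i, j], j] += updates[i, j].
    tf.tensor_scatter_nd_add does not accept this index layout directly,
    so everything is first reshaped to one dimension.
    """
    original_shape = tf.shape(tensor)
    num_cols = tensor.shape[-1]

    # Column offsets [0, 1, ..., m-1], repeated for each of the n rows of indices
    indices_add = tf.range(0, indices.shape[-1], dtype=indices.dtype)
    indices_add = repeat(indices_add, "n -> d n", d=indices.shape[0])
    # Row-major flat position of element (indices[i, j], j) in the 2-D tensor
    indices = indices * num_cols + indices_add

    tensor = tf.reshape(tensor, shape=[-1])
    indices = tf.reshape(indices, shape=[-1, 1])
    updates = tf.reshape(updates, shape=[-1])

    scatter = tf.tensor_scatter_nd_add(tensor, indices, updates)
    # Restore the original shape directly; tf.squeeze would also drop any size-1 axis
    scatter = tf.reshape(scatter, shape=original_shape)

    return scatter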
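The flattening works because a row-major [k, d] array stores element (r, c) at flat position r * d + c, so `indices * d + column` turns each (bucket, column) pair into one scalar index that a 1-D scatter-add can use. A quick NumPy check of that index arithmetic, using the same example shapes as the question (an illustration under those assumptions, not the answer's TensorFlow code):

```python
import numpy as np

k, d = 3, 4
buckets = np.array([0, 2, 0, 1])
index = np.repeat(buckets[:, np.newaxis], d, axis=1)  # [n, d], like repeat(buckets, "n -> n d", d=d)
samples = np.arange(16, dtype=float).reshape(4, d)

# Row-major flat position of (index[i, j], j) in a [k, d] array
flat_index = index * d + np.arange(d)[np.newaxis, :]

flat = np.zeros(k * d)
np.add.at(flat, flat_index.ravel(), samples.ravel())  # 1-D scatter-add; repeated positions accumulate
result = flat.reshape(k, d)
```

Here bucket 2 maps to flat positions 8 through 11, i.e. the third row once the result is reshaped back to [3, 4].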

huangapple
  • Posted on March 12, 2023 at 16:02:59
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75711783.html