重塑具有未知形状的张量,使用 tf.function。

huangapple go评论63阅读模式
英文:

Reshape tensors of unknown shape with tf.function

问题

以下是代码部分的翻译:

让我们假设在我的函数中我需要处理形状为[4,3,2,1]和[5,4,3,2,1]的输入张量我想以这样的方式重塑它们使最后两个维度交换例如变成[4,3,1,2]在即时执行模式下这很容易但当我尝试使用`@tf.function`包装我的函数时会出现以下错误

`OperatorNotAllowedInGraphError: 不允许在符号化的tf.Tensor上进行迭代AutoGraph确实转换了此函数这可能表明您正在尝试使用不受支持的功能。`

相关代码如下

```python
tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
    reshaped = tf.reshape(tensor, shape=[*tf.shape(tensor)[:-2], tf.shape(tensor)[-1], tf.shape(tensor)[-2]])
    return reshaped

logging.info(my_func())

看起来 TensorFlow 不喜欢[:-2]的表示方式,但我真的不知道我应该如何以一种优雅和易读的方式解决这个问题。


<details>
<summary>英文:</summary>

Let&#39;s say that in my function I have to deal with input tensors of shape [4,3,2,1] and [5,4,3,2,1]. I want to reshape them in such a way that the last two dimensions are swapped, e.g., to [4,3,1,2]. In eager mode this is easy, but when I try to wrap my function using `@tf.function` the following error is thrown:

`OperatorNotAllowedInGraphError: Iterating over a symbolic tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.`

The code in question is as follows:

    tensor = tf.random.uniform(shape=[4, 3, 2, 1])
    
        @tf.function
        def my_func():
            reshaped = tf.reshape(tensor, shape=[*tf.shape(tensor)[:-2], tf.shape(tensor)[-1], tf.shape(tensor)[-2]])
            return reshaped
    
        logging.info(my_func())

It looks like tensorflow does not like the `[:-2]` notation, but I don&#39;t really know how else I should solve this problem in an elegant and well-readable way.



</details>


# 答案1
**得分**: 1

在“图”模式下切片张量的方式不起作用,但我认为您可以使用`tf.transpose`:
```python
import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
  rank = tf.rank(tensor)
  some_magic = tf.concat([tf.zeros((rank - 2,), dtype=tf.int32), [1, -1]], axis=-1)
  reshaped = tf.transpose(tensor, perm = tf.range(rank) + some_magic)
  return reshaped

print(my_func().shape)
# (4, 3, 1, 2)
英文:

Slicing the tensor like that in Graph mode unfortunately does not work, but I think you can use tf.transpose:

import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
  rank = tf.rank(tensor)
  some_magic = tf.concat([tf.zeros((rank - 2,), dtype=tf.int32), [1, -1]], axis=-1)
  reshaped = tf.transpose(tensor, perm = tf.range(rank) + some_magic)
  return reshaped

print(my_func().shape)
# (4, 3, 1, 2)

huangapple
  • 本文由 发表于 2023年3月21日 00:10:47
  • 转载请务必保留本文链接:https://go.coder-hub.com/75792711.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定