Apply tf.ensure_shape for multiple outputs.

Question

I have this code:

    import tensorflow as tf
    import numpy as np

    def scale(X, a=-1, b=1, dtype='float32'):
        if a > b:
            a, b = b, a
        xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
        xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
        X = (X - xmin) / (xmax - xmin)
        scaled = X * (b - a) + a
        return scaled, xmin, xmax

    a = np.ones((10, 20, 20, 2))
    dataset = tf.data.Dataset.from_tensor_slices(a)
    data = dataset.map(lambda x: tf.py_function(scale,
                                                [x],
                                                (tf.float32, tf.float32, tf.float32)))

Up to this point everything is fine. Inspecting data gives:

    data
    <MapDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.float32)>

Now I have to use tf.ensure_shape to set the shapes.

If, for example, the scale function returned only one value, scaled, I would do:

    data = data.map(lambda x: tf.ensure_shape(x, [10, 20, 20, 2]))

How do I do this now that I have three output values?

I want to be able to use all the results of the scale function (the scaled values, xmin and xmax), which is why I am doing all this. If there is another way, I am not aware of it.

Answer 1

Score: 1

If the goal is just to turn the unknown shapes into known ones, I think you can use tf.reshape on each of the three outputs.

    import tensorflow as tf

    def scale(X, a=-1, b=1, dtype='float32'):
        # Swap the bounds if they were given in the wrong order.
        if a > b:
            a, b = b, a
        xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
        xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
        X = (X - xmin) / (xmax - xmin)
        scaled = X * (b - a) + a
        return scaled, xmin, xmax

    a = tf.random.uniform(shape=[10, 20, 20, 2], minval=1, maxval=5)
    dataset = tf.data.Dataset.from_tensor_slices(a)
    dataset = dataset.map(
        lambda x: tf.py_function(
            scale,
            [x],
            (tf.float32, tf.float32, tf.float32))
    )

    def set_shape(x, y, z):
        # The map function receives the three outputs as separate arguments.
        # The -1 adds a leading axis whose size is inferred.
        x = tf.reshape(x, [-1, 20, 20, 2])
        y = tf.reshape(y, [1])
        z = tf.reshape(z, [1])
        return x, y, z

    dataset = dataset.map(set_shape)

    # Iterate the reshaped dataset (the original snippet used an undefined name, data).
    a, b, c = next(iter(dataset))
    a.shape, b.shape, c.shape
    # (TensorShape([1, 20, 20, 2]), TensorShape([1]), TensorShape([1]))
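
The same pattern should also work with tf.ensure_shape, which is what the question asked for. The following is only a sketch under a few assumptions: the helper name with_shapes is made up for illustration, the input is cast to float32 inside scale so that float64 input would not cause a dtype mismatch, and the per-element shape is taken to be [20, 20, 2] because from_tensor_slices slices along the first axis (so [10, 20, 20, 2] would not match a single element). Unlike tf.reshape, tf.ensure_shape does not change the data layout; it only checks the shape at run time and records it statically.

```python
import tensorflow as tf


def scale(X, a=-1, b=1, dtype="float32"):
    """Min-max scale X into [a, b]; also return the original min and max."""
    if a > b:
        a, b = b, a
    # Assumption: cast the input first so X, xmin and xmax share one dtype.
    X = tf.cast(X, dtype)
    xmin = tf.math.reduce_min(X)
    xmax = tf.math.reduce_max(X)
    X = (X - xmin) / (xmax - xmin)
    return X * (b - a) + a, xmin, xmax


def with_shapes(scaled, xmin, xmax):
    # Hypothetical helper: one tf.ensure_shape call per output.
    # Each element is one slice of the input, so its shape is [20, 20, 2];
    # the min and max are scalars, i.e. shape [].
    scaled = tf.ensure_shape(scaled, [20, 20, 2])
    xmin = tf.ensure_shape(xmin, [])
    xmax = tf.ensure_shape(xmax, [])
    return scaled, xmin, xmax


values = tf.random.uniform(shape=[10, 20, 20, 2], minval=1, maxval=5)
dataset = tf.data.Dataset.from_tensor_slices(values)
dataset = dataset.map(
    lambda x: tf.py_function(scale, [x], (tf.float32, tf.float32, tf.float32))
)
# The three outputs arrive as three separate arguments of the map function.
dataset = dataset.map(with_shapes)

# The element spec now carries static shapes instead of <unknown>.
print(dataset.element_spec)
```

Since scale uses only TensorFlow ops, it may also be possible to drop tf.py_function and map scale directly, in which case the shapes would be inferred without any extra step.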

huangapple
  • Posted on May 17, 2023, 22:14:34
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76273083.html