TensorFlow: Shape of 3D tensor (of an image) in filter() has None


Question


I am following [this tutorial][1] on [tensorflow.org][2].

I have a folder _images_ with two subfolders, _cat_ and _dog_. Following the tutorial above, I am trying to convert .jpg and .png images to features (NumPy arrays) for modeling.

Problem

After processing the images into tensors, I found that some images were converted to tensors with shape `(128, 128, 4)`. So I added a condition to filter out such tensors. This logic works when I loop over each tensor explicitly with a `for` loop, but the same logic does not work inside `filter()`.

Please help me fix this `filter()` call. I went through the [`filter()`][3] documentation and could not find a solution.

Source code

```python
import tensorflow as tf
import os

print("TensorFlow version:", tf.__version__)


def process_image(file_path_tensor):
    parts = tf.strings.split(file_path_tensor, os.sep)
    label = parts[-2]
    image = tf.io.read_file(file_path_tensor)
    image = tf.image.decode_jpeg(image)
    image = tf.image.resize(image, [128, 128])
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = image / 255
    return image, label


def check_shape(x, y):
    print("\nShape received in filter():", x.shape)
    d1, d2, d3 = x.shape
    return d3 == 3


images_ds = tf.data.Dataset.list_files("./images/*/*", shuffle=True)

file_path = next(iter(images_ds))
image, label = process_image(file_path)
print("Shape:", image.shape)
print("Class label:", label.numpy().decode())

# ETL pipeline.
X_y_tensors = (
    images_ds
    .map(process_image)   # Extract and Transform
    .filter(check_shape)  # Filter
    .as_numpy_iterator()  # Load
)

print("\nTechnique 1:")
print("Final X count:", len(list(X_y_tensors)))

X_y_tensors = images_ds.map(process_image)
count = 0
for x, y in X_y_tensors:
    d1, d2, d3 = x.shape
    if d3 > 3:
        continue
    count += 1

print("\nTechnique 2:")
print("Final X count:", count)
```

Output

```
TensorFlow version: 2.6.0
Shape: (128, 128, 3)
Class label: cat
Shape received in filter(): (128, 128, None)
Technique 1:
Final X count: 0
Technique 2:
Final X count: 123
```

As can be seen:

1. The count is 0 when _Technique 1_ is used to filter the tensors, because the tensor shape received in `filter()` is `(128, 128, None)`.
2. The count is 123 (the image count after filtering) when _Technique 2_ is used.

I do not think [this][5] is the cause, since I am **not using batches**.

[Full code with dataset][4]
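
The `(128, 128, None)` shape described in item 1 can be reproduced without the original dataset. The following is a minimal sketch. It substitutes two synthetic in-memory PNGs (one RGB, one RGBA) for the _cat_/_dog_ files. It also uses `tf.io.decode_png` in place of `decode_jpeg`, because its default `channels=0` behaves the same way: it keeps whatever channel count the file has. The variable names are illustrative only. The dataset's `element_spec` is the static shape that functions passed to `map()`/`filter()` see while `tf.data` traces them. Eager iteration, as in Technique 2, yields concrete shapes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the image files: one RGB and one RGBA PNG,
# each 4 x 4 pixels, encoded in memory.
encoded = tf.stack([
    tf.io.encode_png(np.zeros((4, 4, 3), dtype=np.uint8)),
    tf.io.encode_png(np.zeros((4, 4, 4), dtype=np.uint8)),
])
dataset = tf.data.Dataset.from_tensor_slices(encoded)

# No `channels` argument: each image keeps the channel count stored in the
# file, so the channel dimension is not known until the data is decoded.
decoded = dataset.map(lambda data: tf.io.decode_png(data))

# Static shape seen by functions passed to map()/filter() when tf.data traces them.
traced_shape = decoded.element_spec.shape

# Eager iteration (the question's Technique 2) yields concrete shapes.
eager_shapes = [tuple(image.shape) for image in decoded]

print("Shape while tracing:", traced_shape)
print("Shapes while iterating:", eager_shapes)
```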

[1]: https://www.tensorflow.org/guide/data#preprocessing_data
[2]: https://www.tensorflow.org/
[3]: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter
[4]: https://github.com/DheemanthBhat/tensorflow-image-pipeline
[5]: https://stackoverflow.com/questions/58331837/filter-data-in-tensorflow

# Answer 1

**Score**: 0
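
The problem is not related to `filter()`; it is caused by `decode_jpeg()`. Set the [`channels`](https://www.tensorflow.org/api_docs/python/tf/io/decode_jpeg#args) parameter of `tf.io.decode_jpeg()` to 3. Three indicates the RGB color channels of an image.

Without `channels`, the decoder keeps however many channels each file has, so the channel count is unknown when `tf.data` traces `check_shape()`. The static shape is therefore `(128, 128, None)`, and `d3 == 3` evaluates to the Python value `False` once, at trace time, which drops every element. With `channels=3`, the 4-channel images are decoded as 3-channel RGB instead of being filtered out. This is why both techniques in the output below count 129 images rather than 123.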
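
The following is a minimal sketch comparing the question's static check with two alternatives. It assumes two synthetic in-memory PNGs (one RGB, one RGBA) instead of files. It uses `tf.io.decode_png` in place of `decode_jpeg`, since both treat `channels=0` and `channels=3` the same way. `static_check`, `dynamic_check` and `count` are hypothetical helper names. The static check drops everything when the channel dimension is unknown. Reading the channel count at run time with `tf.shape()` keeps only the RGB image. Decoding with `channels=3` lets the static check pass for both images, because the RGBA image is converted to RGB.

```python
import numpy as np
import tensorflow as tf

# Hypothetical in-memory test data: one RGB and one RGBA PNG (4 x 4 pixels).
encoded = tf.stack([
    tf.io.encode_png(np.zeros((4, 4, 3), dtype=np.uint8)),
    tf.io.encode_png(np.zeros((4, 4, 4), dtype=np.uint8)),
])
dataset = tf.data.Dataset.from_tensor_slices(encoded)


def static_check(image):
    # Mirrors the question's check_shape(): reads the *static* channel
    # dimension while tf.data traces the function. If that dimension is
    # None, the result is the constant False for every element.
    return tf.constant(image.shape[-1] == 3)


def dynamic_check(image):
    # Reads each element's channel count at run time.
    return tf.equal(tf.shape(image)[-1], 3)


def count(ds):
    # Eagerly counts the elements that survive the pipeline.
    return sum(1 for _ in ds)


unknown_channels = dataset.map(lambda data: tf.io.decode_png(data))
forced_rgb = dataset.map(lambda data: tf.io.decode_png(data, channels=3))

counts = {
    "static check, channels unset": count(unknown_channels.filter(static_check)),
    "dynamic check, channels unset": count(unknown_channels.filter(dynamic_check)),
    "static check, channels=3": count(forced_rgb.filter(static_check)),
}
print(counts)
```

The `tf.shape()` variant is one option if 4-channel images should be excluded rather than converted. The answer itself proposes `channels=3`.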

Solution

```python
import tensorflow as tf
import os

print("TensorFlow version:", tf.__version__)


def process_image(file_path_tensor):
    parts = tf.strings.split(file_path_tensor, os.sep)
    label = parts[-2]
    image = tf.io.read_file(file_path_tensor)
    # channels=3 forces every image (including 4-channel ones) to RGB.
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [128, 128])
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = image / 255
    return image, label


def check_shape(x, y):
    print("\nShape received in filter():", x.shape)
    d1, d2, d3 = x.shape
    return d3 == 3


images_ds = tf.data.Dataset.list_files("./images/*/*", shuffle=True)

file_path = next(iter(images_ds))
image, label = process_image(file_path)
print("Shape:", image.shape)
print("Class label:", label.numpy().decode())

# ETL pipeline.
X_y_tensors = (
    images_ds
    .map(process_image)   # Extract and Transform
    .filter(check_shape)  # Filter
    .as_numpy_iterator()  # Load
)

print("\nTechnique 1:")
print("Final X count:", len(list(X_y_tensors)))

X_y_tensors = images_ds.map(process_image)
count = 0
for x, y in X_y_tensors:
    d1, d2, d3 = x.shape
    if d3 > 3:
        continue
    count += 1

print("\nTechnique 2:")
print("Final X count:", count)
```

Output

```
TensorFlow version: 2.6.0
Shape: (128, 128, 3)
Class label: cat
Shape received in filter(): (128, 128, 3)
Technique 1:
Final X count: 129
Technique 2:
Final X count: 129
```
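
The question's goal is to turn the images into NumPy features. The fixed `process_image()` can be tried end to end on a few synthetic files before running it on the real folder. This is a sketch under stated assumptions. The temporary directory, the `sample.png` file names and the pixel values are made up. It relies on `decode_jpeg()` accepting PNG data, as the question's own .png files already do. `filter(check_shape)` is omitted, because with `channels=3` it would keep every element.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def process_image(file_path_tensor):
    # The answer's process_image(), unchanged: channels=3 forces RGB output.
    parts = tf.strings.split(file_path_tensor, os.sep)
    label = parts[-2]
    image = tf.io.read_file(file_path_tensor)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [128, 128])
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = image / 255
    return image, label


# Hypothetical stand-in for ./images: one RGBA PNG under cat/, one RGB PNG under dog/.
root = tempfile.mkdtemp()
for class_name, channels in [("cat", 4), ("dog", 3)]:
    os.makedirs(os.path.join(root, class_name))
    pixels = np.full((32, 32, channels), 200, dtype=np.uint8)
    tf.io.write_file(os.path.join(root, class_name, "sample.png"),
                     tf.io.encode_png(pixels))

files = tf.data.Dataset.list_files(os.path.join(root, "*", "*"), shuffle=False)
pairs = list(files.map(process_image).as_numpy_iterator())

# Features and labels as NumPy arrays, ready for modeling.
X = np.stack([image for image, _ in pairs])
y = np.array([label.decode() for _, label in pairs])
print(X.shape, y)
```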

Posted by huangapple on 2023-02-18 22:03:38. Source: https://go.coder-hub.com/75493853.html