tf_rep.export_graph(tf_model_path): KeyError: 'input.1'


Question

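I am trying to convert an ONNX model to TFLite, but I get an error when executing the line `tf_rep.export_graph(tf_model_path)`. This question has been asked on Stack Overflow before, but none of the answers provided a definitive solution.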

Installed requirements: `tensorflow 2.12.0`, `onnx 1.14.0`, `onnx-tf 1.10.0`, `Python 3.10.12`

```python
import torch
import onnx
import tensorflow as tf
import onnx_tf
from torchvision.models import resnet50

# Load the PyTorch ResNet50 model
pytorch_model = resnet50(pretrained=True)
pytorch_model.eval()

# Export the PyTorch model to ONNX format
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
onnx_model_path = 'resnet50.onnx'
torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

# Load the ONNX model
onnx_model = onnx.load(onnx_model_path)

# Convert the ONNX model to TensorFlow format
tf_model_path = 'resnet50.pb'
onnx_model = onnx.load(onnx_model_path)
from onnx_tf.backend import prepare
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)  # ERROR
```

Error:

```
WARNING:absl:`input.1` is not a valid tf.function parameter name. Sanitizing to `input_1`.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
6 tf_model_path = 'resnet50'
7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)

35 frames
/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py in tf__conv(cls, node, input_dict, transpose)
17 do_return = False
18 retval_ = ag__.UndefinedReturnValue()
---> 19 x = ag__.ld(input_dict)[ag__.ld(node).inputs[0]]
20 x_rank = ag__.converted_call(ag__.ld(len), (ag__.converted_call(ag__.ld(x).get_shape, (), None, fscope),), None, fscope)
21 x_shape = ag__.converted_call(ag__.ld(tf_shape), (ag__.ld(x), ag__.ld(tf).int32), None, fscope)

KeyError: in user code:

File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__ *
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op *
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle *
return ver_handle(node, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11 *
return cls.conv(node, kwargs["tensor_dict"])
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv *
x = input_dict[node.inputs[0]]

KeyError: 'input.1'
```

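Answer 1

**Score**: 2

The problem is the input name in the ONNX model:

```python
import onnx

onnx_model = onnx.load(onnx_model_path)
print("Model Inputs: ", [inp.name for inp in onnx_model.graph.input])
```

Model Inputs: ['input.1']

TensorFlow cannot use `input.1` as a `tf.function` parameter name, so it sanitizes it to `input_1` (see the warning in the question). The Conv handler in onnx-tf, however, still looks up its input tensor under the original name `input.1`, which raises the `KeyError`. Renaming the input to `input_1` in the ONNX model itself avoids the mismatch. The following code does that: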

```python
import onnx
from onnx import helper

onnx_model = onnx.load(onnx_model_path)

# Define a mapping from old names to new names
name_map = {"input.1": "input_1"}

# Initialize a list to hold the new inputs
new_inputs = []

# Iterate over the inputs and change their names if needed
for inp in onnx_model.graph.input:
    if inp.name in name_map:
        # Create a new ValueInfoProto with the new name
        new_inp = helper.make_tensor_value_info(name_map[inp.name],
                                                inp.type.tensor_type.elem_type,
                                                [dim.dim_value for dim in inp.type.tensor_type.shape.dim])
        new_inputs.append(new_inp)
    else:
        new_inputs.append(inp)

# Clear the old inputs and add the new ones
onnx_model.graph.ClearField("input")
onnx_model.graph.input.extend(new_inputs)

# Go through all nodes in the model and replace the old input name with the new one
for node in onnx_model.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')
```

After renaming, the model input reads:

Model Inputs: ['input_1']
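The answer hard-codes `{"input.1": "input_1"}`. For models with several inputs, a small helper could build `name_map` automatically. This is a sketch under an assumption: replacing every non-word character with `_` (and prefixing `_` when the result starts with a digit or is a Python keyword) reproduces the `input.1` to `input_1` renaming seen in the warning, but TensorFlow's own sanitization rule may differ in edge cases. `build_name_map` is a hypothetical name, not part of onnx or onnx-tf.

```python
import keyword
import re


def build_name_map(names):
    """Map each name that is not a valid Python identifier to a sanitized one.

    Assumption: this approximates TensorFlow's tf.function parameter-name
    sanitization ('input.1' -> 'input_1'); it is not TensorFlow's actual code.
    """
    taken = set(names)  # names already in use, so replacements never collide
    name_map = {}
    for name in names:
        if name.isidentifier() and not keyword.iskeyword(name):
            continue  # already usable as a parameter name, leave it alone
        # Replace every non-word character (e.g. '.') with '_'
        candidate = re.sub(r"\W", "_", name)
        if not candidate or candidate[0].isdigit() or keyword.iskeyword(candidate):
            candidate = "_" + candidate
        # Append a numeric suffix if the sanitized name is already taken
        base, suffix = candidate, 1
        while candidate in taken:
            candidate = f"{base}_{suffix}"
            suffix += 1
        taken.add(candidate)
        name_map[name] = candidate
    return name_map


print(build_name_map(["input.1"]))  # {'input.1': 'input_1'}
```

It would be fed the current input names, e.g. `name_map = build_name_map([inp.name for inp in onnx_model.graph.input])`, after which the renaming loops from the answer can run unchanged.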

With the renamed model, `export_graph` now completes without error:

```python
import onnx
from onnx_tf.backend import prepare

onnx_model_path = 'resnet50-new.onnx'
onnx_model = onnx.load(onnx_model_path)

tf_model_path = 'resnet50'
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
```
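In onnx-tf 1.x, `export_graph` writes a TensorFlow SavedModel directory rather than a `.tflite` file. Since the original goal is TFLite, here is a minimal sketch of the remaining step, assuming the standard `tf.lite.TFLiteConverter.from_saved_model` API available in TensorFlow 2.12. The helper name `convert_saved_model_to_tflite` and its `allow_tf_ops` flag are hypothetical, introduced only for illustration.

```python
def convert_saved_model_to_tflite(saved_model_dir, tflite_path, allow_tf_ops=False):
    """Convert the SavedModel written by tf_rep.export_graph() into a .tflite file.

    Hypothetical helper; it wraps tf.lite.TFLiteConverter.from_saved_model.
    """
    # Imported here so the helper can be defined even where TensorFlow is absent
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    if allow_tf_ops:
        # Fall back to TensorFlow (Flex) ops for anything without a TFLite builtin
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
    tflite_model = converter.convert()  # serialized flatbuffer as bytes
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    return tflite_path
```

For this model, `convert_saved_model_to_tflite('resnet50', 'resnet50.tflite')` would read the directory written by `export_graph`. Enabling `allow_tf_ops` makes the file larger and requires the Select TF ops (Flex) runtime on device, so it is off by default; a plain ResNet50 will most likely convert with builtin ops only.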

huangapple
  • Published on 2023-08-05 06:14:57
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76839366.html