英文:
tf_rep.export_graph(tf_model_path): KeyError: 'input.1'
问题
我正在尝试将一个 `onnx` 模型转换为 `tflite`,在执行 `tf_rep.export_graph(tf_model_path)` 这一行时出现错误。这个问题之前在 Stack Overflow 上有人提出过,但没有人给出明确的解决方案。
已安装的依赖:`tensorflow 2.12.0`、`onnx 1.14.0`、`onnx-tf 1.10.0`、`Python 3.10.12`
import torch
import onnx
import tensorflow as tf
import onnx_tf
from onnx_tf.backend import prepare
from torchvision.models import resnet50

# Load the pretrained PyTorch ResNet50 model and switch it to inference mode
pytorch_model = resnet50(pretrained=True)
pytorch_model.eval()

# Export the PyTorch model to ONNX format.
# No `input_names` are passed, so the exporter chooses the graph input name
# itself ('input.1' for this model), which is what onnx-tf later trips over.
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
onnx_model_path = 'resnet50.onnx'
torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

# Load the ONNX model (once; it is reused for the conversion below)
onnx_model = onnx.load(onnx_model_path)

# Convert the ONNX model to TensorFlow format
tf_model_path = 'resnet50.pb'
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)  # ERROR: KeyError: 'input.1'
错误信息:
警告:`input.1`不是有效的tf.function参数名。正在更改为`input_1`。
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
6 tf_model_path = 'resnet50'
7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)
...
KeyError: 在用户代码中:
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__ *
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op *
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle *
return ver_handle(node, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11 *
return cls.conv(node, kwargs["tensor_dict"])
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv *
x = input_dict[node.inputs[0]]
KeyError: 'input.1'
<details>
<summary>英文:</summary>
I am trying to convert an `onnx` model to `tflite`, and I'm facing an error when executing the line `tf_rep.export_graph(tf_model_path)`. This question was asked on SO before, but none of the answers provided a definitive solution.
Requirements installed: `tensorflow: 2.12.0`, `onnx 1.14.0`, `onnx-tf 1.10.0`, `Python 3.10.12`
import torch
import onnx
import tensorflow as tf
import onnx_tf
from onnx_tf.backend import prepare
from torchvision.models import resnet50

# Load the pretrained PyTorch ResNet50 model and switch it to inference mode
pytorch_model = resnet50(pretrained=True)
pytorch_model.eval()

# Export the PyTorch model to ONNX format.
# No `input_names` are passed, so the exporter chooses the graph input name
# itself ('input.1' for this model), which is what onnx-tf later trips over.
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
onnx_model_path = 'resnet50.onnx'
torch.onnx.export(pytorch_model, dummy_input, onnx_model_path, opset_version=12, verbose=False)

# Load the ONNX model (once; it is reused for the conversion below)
onnx_model = onnx.load(onnx_model_path)

# Convert the ONNX model to TensorFlow format
tf_model_path = 'resnet50.pb'
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)  # ERROR: KeyError: 'input.1'
Error:
WARNING:absl:`input.1` is not a valid tf.function parameter name. Sanitizing to `input_1`.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-4-f35b83c104b8> in <cell line: 8>()
6 tf_model_path = 'resnet50'
7 tf_rep = prepare(onnx_model)
----> 8 tf_rep.export_graph(tf_model_path)
35 frames
/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py in tf__conv(cls, node, input_dict, transpose)
17 do_return = False
18 retval_ = ag__.UndefinedReturnValue()
---> 19 x = ag__.ld(input_dict)[ag__.ld(node).inputs[0]]
20 x_rank = ag__.converted_call(ag__.ld(len), (ag__.converted_call(ag__.ld(x).get_shape, (), None, fscope),), None, fscope)
21 x_shape = ag__.converted_call(ag__.ld(tf_shape), (ag__.ld(x), ag__.ld(tf).int32), None, fscope)
KeyError: in user code:
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend_tf_module.py", line 99, in __call__ *
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op *
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/handler.py", line 59, in handle *
return ver_handle(node, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv.py", line 15, in version_11 *
return cls.conv(node, kwargs["tensor_dict"])
File "/usr/local/lib/python3.10/dist-packages/onnx_tf/handlers/backend/conv_mixin.py", line 29, in conv *
x = input_dict[node.inputs[0]]
KeyError: 'input.1'
</details>
# 答案1
**得分**: 2
The problem was the name of the input parameter in the `onnx` model:
```python
import onnx

# Load the exported model and collect the names of its graph inputs.
onnx_model = onnx.load(onnx_model_path)
input_names = [value.name for value in onnx_model.graph.input]
print("Model Inputs: ", input_names)
```

> Model Inputs: ['input.1']
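Here `input.1` is not a valid `tf.function` parameter name. TensorFlow sanitizes it to `input_1`, but onnx-tf still looks the tensor up as `input.1`, which causes the `KeyError`. The input therefore has to be renamed to `input_1` in the ONNX model itself. The following code does that: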
import onnx

onnx_model = onnx.load(onnx_model_path)

# Map each ONNX tensor name that is not a valid TensorFlow identifier to a
# valid replacement. The warning/traceback show TensorFlow sanitizing
# "input.1" to "input_1" while onnx-tf's Conv handler still looks up
# "input.1", hence the KeyError.
name_map = {"input.1": "input_1"}

graph = onnx_model.graph

# Rename the declarations in place rather than rebuilding them with
# helper.make_tensor_value_info(..., [dim.dim_value ...]). Rebuilding that way
# turns symbolic (dim_param) dimensions into 0, turns an unknown-rank input
# into a scalar, and drops the other ValueInfo fields.
# Outputs and value_info are renamed too, so no declaration keeps a stale name.
for declared in (graph.input, graph.output, graph.value_info):
    for value in declared:
        if value.name in name_map:
            value.name = name_map[value.name]

# Point every node that consumes a renamed tensor at its new name.
# NOTE: subgraph attributes (If/Loop bodies) are not walked; extend this if
# the model contains control flow that references the renamed input.
for node in graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')
The renamed input now looks like:
> Model Inputs: ['input_1']
With the renamed model, the TensorFlow export now completes without error. Its output can then be converted to `tflite`:
import onnx
from onnx_tf.backend import prepare

# Input: the renamed ONNX model; output: the location handed to export_graph.
onnx_model_path = 'resnet50-new.onnx'
tf_model_path = 'resnet50'

# Build the TensorFlow representation of the renamed model and export it.
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
英文:
The problem was the name of the input parameter in the `onnx` model.
import onnx

# Load the exported model and collect the names of its graph inputs.
onnx_model = onnx.load(onnx_model_path)
input_names = [value.name for value in onnx_model.graph.input]
print("Model Inputs: ", input_names)
> Model Inputs: ['input.1']
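Here `input.1` is not a valid `tf.function` parameter name. TensorFlow sanitizes it to `input_1`, but onnx-tf still looks the tensor up as `input.1`, which causes the `KeyError`. The input therefore has to be renamed to `input_1` in the ONNX model itself. The following code does that: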
import onnx

onnx_model = onnx.load(onnx_model_path)

# Map each ONNX tensor name that is not a valid TensorFlow identifier to a
# valid replacement. The warning/traceback show TensorFlow sanitizing
# "input.1" to "input_1" while onnx-tf's Conv handler still looks up
# "input.1", hence the KeyError.
name_map = {"input.1": "input_1"}

graph = onnx_model.graph

# Rename the declarations in place rather than rebuilding them with
# helper.make_tensor_value_info(..., [dim.dim_value ...]). Rebuilding that way
# turns symbolic (dim_param) dimensions into 0, turns an unknown-rank input
# into a scalar, and drops the other ValueInfo fields.
# Outputs and value_info are renamed too, so no declaration keeps a stale name.
for declared in (graph.input, graph.output, graph.value_info):
    for value in declared:
        if value.name in name_map:
            value.name = name_map[value.name]

# Point every node that consumes a renamed tensor at its new name.
# NOTE: subgraph attributes (If/Loop bodies) are not walked; extend this if
# the model contains control flow that references the renamed input.
for node in graph.node:
    for i, input_name in enumerate(node.input):
        if input_name in name_map:
            node.input[i] = name_map[input_name]

# Save the renamed ONNX model
onnx.save(onnx_model, 'resnet50-new.onnx')
The renamed input now looks like:
> Model Inputs: ['input_1']
With the renamed model, the TensorFlow export now completes without error. Its output can then be converted to `tflite`:
import onnx
from onnx_tf.backend import prepare

# Input: the renamed ONNX model; output: the location handed to export_graph.
onnx_model_path = 'resnet50-new.onnx'
tf_model_path = 'resnet50'

# Build the TensorFlow representation of the renamed model and export it.
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论