How to call loaded Tensorflow model with custom params in call method?

Question


I've been following the TensorFlow text generation tutorial, which includes two models, `MyModel` and `OneStep`. `MyModel` is an RNN operating on vectorized strings; `OneStep` essentially wraps `MyModel` and operates on strings directly.

The tutorial saves and loads `OneStep`, and I followed that successfully, but I now want to save and reload `MyModel`. The tutorial doesn't do this, and when I call the reloaded model with `return_state=True`, I get an error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/2335414736.py in <module>
      1 # TODO: Loaded model gives an error
      2 for input_example_batch, target_example_batch in train_ds.take(1):
----> 3     example_batch_predictions, example_states = loaded_model(input_example_batch, False, None, return_state=True)
      4     print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
      5     print(example_states.shape, " # (batch_size, rnn_units)")

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
    662 
    663 def _call_attribute(instance, *args, **kwargs):
--> 664   return instance.__call__(*args, **kwargs)
    665 
    666 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884     with OptionalXlaContext(self._jit_compile):
--> 885       result = self._call(*args, **kwds)
    886 
    887     new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462         self._function_cache.missed.add(call_context_key)
-> 3463         graph_function = self._create_graph_function(args, kwargs)
   3464         self._function_cache.primary[cache_key] = graph_function
   3465 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py in restored_function_body(*args, **kwargs)
    292         .format(_pretty_format_positional(args), kwargs,
    293                 len(saved_function.concrete_functions),
--> 294                 "\n\n".join(signature_descriptions)))
    295 
    296   concrete_function_objects = []

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (4 total):
    * Tensor("inputs:0", shape=(64, 113), dtype=int64)
    * False
    * None
    * True
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * False
    * None
    * False
  Keyword arguments: {}

Option 2:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * False
    * None
    * False
  Keyword arguments: {}

Option 3:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * True
    * None
    * False
  Keyword arguments: {}

Option 4:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * True
    * None
    * False
  Keyword arguments: {}
```

I think this is due to the custom parameter in the `call` method. Here is a minimal example that reproduces the problem:

```python
import tensorflow as tf


class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)


model = CustomModel()
sample_inputs = tf.zeros((16, 30))
print('Sample inputs:', sample_inputs)
sample_outputs = model(sample_inputs)
print('Sample outputs:', sample_outputs)

model.save('saved_model')
loaded_model = tf.keras.models.load_model('saved_model')
sample_outputs_2 = loaded_model(sample_inputs, custom_param=True)  # fails after reloading
print('Sample outputs 2:', sample_outputs_2)
```

Calling the reloaded model with `custom_param` set to anything other than its default value always seems to fail.
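A plausible reading of the error above: every one of the four saved options has `False` in the slot where the call passes `True`. `tf.function` traces a separate concrete function for each combination of plain Python argument values it sees, and a SavedModel stores those traced functions, not the Python body of `call`. A reloaded model without its original class can therefore only dispatch to combinations that were traced before saving. The pure-Python toy below mimics that lookup; `ToyRestoredFunction`, `trace`, and the `(training, custom_param)` key are hypothetical illustrations, not TensorFlow APIs or internals.

```python
"""Toy model of why a restored tf.function rejects untraced Python arguments.

All names are hypothetical; this is NOT TensorFlow's implementation.
"""


class ToyRestoredFunction:
    """Dispatches only to "concrete functions" recorded at save time."""

    def __init__(self):
        # Maps a tuple of plain-Python argument values to a traced body.
        self._traced = {}

    def trace(self, python_args, body):
        # At save time, one entry is recorded per argument combination seen.
        self._traced[python_args] = body

    def __call__(self, tensor, *python_args):
        # At load time the original Python code is gone: only exact matches work.
        try:
            body = self._traced[python_args]
        except KeyError:
            options = sorted(self._traced)
            raise ValueError(
                f"Could not find matching function for {python_args}; "
                f"saved options: {options}"
            ) from None
        return body(tensor)


restored = ToyRestoredFunction()
# Assumption: the model was traced only with custom_param=False (the default),
# once with training=True and once with training=False.
restored.trace((True, False), lambda x: x * 2)
restored.trace((False, False), lambda x: x * 2)

print(restored(3, False, False))  # works: this combination was traced

try:
    restored(3, False, True)  # custom_param=True was never traced
except ValueError as err:
    print("ValueError:", err)
```

In this model, restoring the original Python class, so that `call` runs as ordinary Python again, sidesteps the lookup entirely.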

Is this a bug, or is it by design? How can I modify the model so that it returns only the output sequence during training, but both the output sequence and the state at inference time? I need this so I can feed the state back into the model and generate additional characters during inference.

Answer 1

Score: 0


I figured it out. The solution is in the TensorFlow docs, although it isn't spelled out very clearly.

With the code above, the loaded model is of type `keras.saving.saved_model.load.CustomModel`, which is not the original class. To get the original class back, do the following.

The `CustomModel` class needs `get_config` and `from_config` methods:

```python
class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)

    def get_config(self):
        return {}  # Any parameters originally passed to __init__ should go here.

    @classmethod
    def from_config(cls, config):
        return cls(**config)
```

When loading the model, pass the custom class in the `custom_objects` dictionary:

```python
loaded_model = tf.keras.models.load_model('saved_model', custom_objects={'CustomModel': CustomModel})
```

`loaded_model` is then of type `CustomModel`, and calling it with `custom_param=True` works.
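The comment in `get_config` ("Any parameters originally passed to `__init__` should go here") matters once the constructor takes arguments, as the tutorial's `MyModel` does. The pure-Python sketch below illustrates the round trip that `get_config`/`from_config` enable: a stored class name is resolved through a `custom_objects`-style mapping, and the stored config is passed back to the constructor, so the Python `call` (including a `return_state` branch) exists again. `ToyModel`, `revive`, the constructor parameters, and their values are illustrative assumptions; weights are not modeled, and this is not Keras's actual loading code.

```python
"""Toy get_config/from_config round trip (illustrative, not Keras internals)."""


class ToyModel:
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        # Hypothetical constructor arguments, loosely modeled on MyModel.
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units

    def call(self, inputs, return_state=False):
        # Stand-in forward pass: the point is that this Python branch
        # still exists after rebuilding the object through from_config.
        outputs = [x + 1 for x in inputs]
        state = len(inputs)
        return (outputs, state) if return_state else outputs

    def get_config(self):
        # Everything __init__ needs to rebuild an equivalent object.
        return {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "rnn_units": self.rnn_units,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def revive(saved, custom_objects):
    """Rebuild an object from {"class_name", "config"} via a class mapping."""
    cls = custom_objects.get(saved["class_name"])
    if cls is None:
        # The real SavedModel loader does not raise here; it falls back to a
        # generic revived class that only wraps the saved traced functions.
        raise LookupError(f"unknown class {saved['class_name']!r}")
    return cls.from_config(saved["config"])


# Arbitrary example values.
original = ToyModel(vocab_size=66, embedding_dim=256, rnn_units=1024)
saved = {"class_name": "ToyModel", "config": original.get_config()}

restored = revive(saved, custom_objects={"ToyModel": ToyModel})
print(type(restored).__name__, restored.get_config())
print(restored.call([1, 2, 3], return_state=True))
```

If `get_config` returned `{}` here, `from_config` would call `ToyModel()` with no arguments and fail, which is why constructor parameters belong in the config.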

huangapple
  • Published on 2023-02-14 06:32:42
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75441813.html