How to call a loaded TensorFlow model with custom params in the call method?
Question
I've been following the TensorFlow text generation tutorial, which includes two models, "MyModel" and "OneStep". "MyModel" is an RNN operating on vectorized strings; "OneStep" essentially wraps "MyModel" and operates on strings directly.
The tutorial saves and loads "OneStep", and I followed this successfully, but I now want to save and reload "MyModel". This is not done in the tutorial, and when I try to call the reloaded model with `return_state=True`, I get an error:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/2335414736.py in <module>
      1 # TODO: Loaded model gives an error
      2 for input_example_batch, target_example_batch in train_ds.take(1):
----> 3     example_batch_predictions, example_states = loaded_model(input_example_batch, False, None, return_state=True)
      4     print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
      5     print(example_states.shape, "   # (batch_size, rnn_units)")
/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in _call_attribute(instance, *args, **kwargs)
    662 
    663 def _call_attribute(instance, *args, **kwargs):
--> 664   return instance.__call__(*args, **kwargs)
    665 
    666 
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 
/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py in restored_function_body(*args, **kwargs)
    292         .format(_pretty_format_positional(args), kwargs,
    293                 len(saved_function.concrete_functions),
--> 294                 "\n\n".join(signature_descriptions)))
    295 
    296   concrete_function_objects = []
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (4 total):
    * Tensor("inputs:0", shape=(64, 113), dtype=int64)
    * False
    * None
    * True
  Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * False
    * None
    * False
  Keyword arguments: {}
Option 2:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * False
    * None
    * False
  Keyword arguments: {}
Option 3:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='inputs')
    * True
    * None
    * False
  Keyword arguments: {}
Option 4:
  Positional arguments (4 total):
    * TensorSpec(shape=(None, 113), dtype=tf.int64, name='input_1')
    * True
    * None
    * False
  Keyword arguments: {}
I think this is due to the custom parameter in the call method. Here is a minimal example that reproduces the problem:
```python
import tensorflow as tf


class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        # custom_param is never used; it only adds a Python argument to call().
        return self.dense(inputs)


model = CustomModel()
sample_inputs = tf.zeros((16, 30))
print('Sample inputs:', sample_inputs)
sample_outputs = model(sample_inputs)
print('Sample outputs:', sample_outputs)

# Saving traces call() with the default custom_param=False only.
model.save('saved_model')
loaded_model = tf.keras.models.load_model('saved_model')

# Fails: the loaded model has no traced function for custom_param=True.
sample_outputs_2 = loaded_model(sample_inputs, custom_param=True)
print('Sample outputs 2:', sample_outputs_2)
```
Calling the reloaded model with `custom_param` set to anything other than its default value always seems to fail.
Is this a bug or by design? How can I modify the model so that it returns just the output sequence while training, but returns the output sequence and state at inference time? This is so I can feed the state back into the model and generate additional characters during inference.
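For context, the failure can be reproduced without Keras. As far as I understand the SavedModel format, it stores only the concrete functions that were traced before saving, and a loaded function has no Python source to retrace from, so a Python argument value that was never traced has nothing to match. The minimal sketch below assumes TensorFlow 2.x; the `Scaler` module and its `double` flag are hypothetical stand-ins for the model and `custom_param`.

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module whose tf.function takes a Python bool flag."""

    @tf.function
    def __call__(self, x, double=False):
        # `double` is a plain Python bool, so each distinct value needs its own trace.
        if double:
            return x * 2.0
        return x


module = Scaler()
x = tf.ones((2, 3))
module(x)  # traces only the double=False variant

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)
restored = tf.saved_model.load(export_dir)

print(restored(x))  # matches the saved double=False trace, so this works
try:
    restored(x, double=True)  # no trace was saved for double=True
except ValueError as err:
    print("ValueError:", str(err).splitlines()[0])
```

The restored object can only replay what was traced, which is why the question's reloaded model accepts the default arguments but rejects new Python values.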
Answer 1
Score: 0
I figured it out. The solution was in the TensorFlow docs, albeit not very clearly.
With the above code, the loaded model is of type keras.saving.saved_model.load.CustomModel, which is not the same as the original type. To get back the original type, you need to do the following.
The CustomModel class needs the get_config and from_config methods.
```python
import tensorflow as tf


class CustomModel(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs, custom_param=False):
        return self.dense(inputs)

    def get_config(self):
        return {}  # Any parameters originally passed to __init__ should go here.

    @classmethod
    def from_config(cls, config):
        # Rebuild the original Python class, so call() runs as ordinary Python code.
        return cls(**config)
```
When loading the model, you need to pass the custom class in the `custom_objects` dictionary:
```python
loaded_model = tf.keras.models.load_model(
    'saved_model', custom_objects={'CustomModel': CustomModel})
```
Then, loaded_model is of type CustomModel and calling it with custom_param=True works.
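Applied to the original goal (return only the output sequence during training, and the sequence plus the state at inference), the same `get_config`/`from_config` pattern might look like the sketch below. `StatefulGRUModel` and its constructor arguments are hypothetical, loosely modeled on the tutorial's `MyModel`. The in-memory `from_config(get_config())` round trip followed by `set_weights` is only an approximation of what `load_model` does with `custom_objects`, used here so the example runs without writing to disk.

```python
import tensorflow as tf


class StatefulGRUModel(tf.keras.Model):
    """Hypothetical GRU model in the spirit of the tutorial's MyModel."""

    def __init__(self, vocab_size, embedding_dim, rnn_units, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can return them.
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = self.embedding(inputs)
        # With initial_state=None the GRU starts from a zero state.
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x)
        if return_state:
            return x, states
        return x

    def get_config(self):
        return {
            "vocab_size": self.vocab_size,
            "embedding_dim": self.embedding_dim,
            "rnn_units": self.rnn_units,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = StatefulGRUModel(vocab_size=20, embedding_dim=8, rnn_units=16)
batch = tf.random.uniform((2, 5), maxval=20, dtype=tf.int64)
model(batch)  # create the original model's weights

# Rebuild the Python class from its config, then copy the weights across.
clone = StatefulGRUModel.from_config(model.get_config())
clone(batch)  # create the clone's variables before copying weights
clone.set_weights(model.get_weights())

predictions, states = clone(batch, return_state=True)
print(tuple(predictions.shape))  # (2, 5, 20): (batch_size, sequence_length, vocab_size)
print(tuple(states.shape))       # (2, 16): (batch_size, rnn_units)
```

With a model saved to disk, the corresponding step would be `tf.keras.models.load_model('saved_model', custom_objects={'StatefulGRUModel': StatefulGRUModel})`, as in the answer above. Because the restored object is the real Python class, `return_state=True` and a `states` tensor fed back in are handled by ordinary Python control flow rather than by pre-traced signatures.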