英文:
Is there way to embed non-tf functions to a tf.Keras model graph as SavedModel Signature?
问题
我想将预处理函数和方法添加到模型图作为SavedModel签名。
示例:
# 假设我们有一个Keras模型
# ...
# 定义我要添加到模型图中的函数
@tf.function
def process(model, img_path):
# 使用不同的库和模块进行一些预处理...
outputs = {"preds": model.predict(preprocessed_img)}
return outputs
# 使用自定义签名保存模型
tf.saved_model.save(new_model, dst_path,
signatures={"process": process})
或者我们可以在这里使用 tf.Module
。然而,问题是我无法将自定义函数嵌入到保存的模型图中。
有没有办法可以做到这一点?
英文:
I want to add preprocessing functions and methods to the model graph as a SavedModel signature.
example:
# suppose we have a keras model
# ...
# defining the function I want to add to the model graph
@tf.function
def process(model, img_path):
# do some preprocessing using different libs. and modules...
outputs = {"preds": model.predict(preprocessed_img)}
return outputs
# saving the model with a custom signature
tf.saved_model.save(new_model, dst_path,
signatures={"process": process})
or we can use tf.Module
here. However, the problem is I can not embed custom functions into the saved model graph.
Is there any way to do that?
答案1
得分: 1
根据文档,你稍微误解了Tensorflow中save_model
方法的目的。
save_model
方法的意图是将模型的图形序列化,以便稍后可以使用load_model
进行加载。
由load_model
返回的模型是tf.Module
类,具有所有的方法和属性。相反,你想要序列化预测流程。
老实说,我不知道有什么好的方法来做到这一点,但你可以使用不同的方法来序列化你的预处理参数,例如pickle或你使用的框架提供的其他方法,并在其之上编写一个类,该类执行以下操作:
class MyModel:
def __init__(self, model_path, preprocessing_path):
self.model = load_model(model_path)
self.preprocessing = load_preprocessing(preprocessing_path)
def predict(self, img_path):
return self.model.predict(self.preprocessing(img_path))
英文:
I think you slightly misunderstand the purpose of save_model
method in Tensorflow.
As per the documentation the intent is to have a method which serialises the model's graph so that it can be loaded with load_model
afterwards.
The model returned by load_model
is a class of tf.Module
with all it's methods and attributes. Instead you want to serialise the prediction pipeline.
To be honest, I'm not aware of a good way to do that, however what you can do is to use a different method for serialisation of your preprocessing parameters, for example pickle or a different one, provided by the framework you use and write a class on top of that, which would do the following:
class MyModel:
def __init__(self, model_path, preprocessing_path):
self.model = load_model(model_path)
self.preprocessing = load_preprocessing(preprocessing_path)
def predict(self, img_path):
return self.model.predict(self.preprocessing(img_path))
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论