Is there a way to embed non-TF functions into a tf.Keras model graph as a SavedModel signature?

Question

I want to add preprocessing functions and methods to the model graph as a SavedModel signature.

Example:

import tensorflow as tf

# suppose we have a Keras model
# ...

# define the function I want to add to the model graph
@tf.function
def process(model, img_path):
    # do some preprocessing using different libs and modules...

    outputs = {"preds": model.predict(preprocessed_img)}
    return outputs

# save the model with a custom signature
tf.saved_model.save(new_model, dst_path,
                    signatures={"process": process})

Alternatively, we could use tf.Module here, roughly as sketched below. However, the problem is that I cannot embed custom (non-TF) functions into the saved model graph.
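For reference, the tf.Module variant I have in mind looks roughly like the following sketch (ServingModule, serve and the TF-based decoding/resizing are made-up illustrations); it only helps when the preprocessing can be expressed with TensorFlow ops, which is exactly what I cannot do:

import tensorflow as tf

# Sketch of the tf.Module variant (names are illustrative). The catch:
# only TensorFlow ops inside `serve` are traced into the SavedModel graph,
# so preprocessing done with other libraries is not captured.
class ServingModule(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def serve(self, img_path):
        img = tf.io.read_file(img_path)
        img = tf.io.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, (224, 224)) / 255.0
        return {"preds": self.model(img[tf.newaxis, ...])}

module = ServingModule(new_model)
tf.saved_model.save(module, dst_path, signatures={"process": module.serve})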

Is there any way to do that?


Answer 1

Score: 1

I think you slightly misunderstand the purpose of the save_model method in TensorFlow.

As per the documentation, the intent is to provide a method that serialises the model's graph so that it can be loaded with load_model afterwards.

The model returned by load_model is a tf.Module instance with all of its methods and attributes. What you want instead is to serialise the whole prediction pipeline.
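As a minimal sketch of that round trip (the toy model and the /tmp path are just for illustration):

import tensorflow as tf

# Minimal sketch of the save/load round trip described above.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

tf.saved_model.save(model, "/tmp/demo_saved_model")

restored = tf.saved_model.load("/tmp/demo_saved_model")
print(type(restored))                    # a tf.Module-like trackable object
print(list(restored.signatures.keys()))  # typically ['serving_default']

# Only the ops that were traced into the graph come back; the Python
# preprocessing code around the original model does not.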

To be honest, I'm not aware of a good way to do that. What you can do, however, is serialise your preprocessing parameters with a different mechanism, for example pickle or whatever the framework you use provides, and write a class on top of that which does something like the following:

from tensorflow.keras.models import load_model


class MyModel:
    def __init__(self, model_path, preprocessing_path):
        self.model = load_model(model_path)
        # load_preprocessing is a placeholder for however you chose to
        # deserialise the preprocessing step (e.g. with pickle, see below)
        self.preprocessing = load_preprocessing(preprocessing_path)

    def predict(self, img_path):
        # preprocess first, then run the Keras model on the result
        return self.model.predict(self.preprocessing(img_path))
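For example, assuming the preprocessing step was pickled as a plain callable, the load_preprocessing placeholder and the usage could look like this (file names are made up):

import pickle

# One possible implementation of the load_preprocessing placeholder,
# assuming the preprocessing step was pickled as a callable.
def load_preprocessing(preprocessing_path):
    with open(preprocessing_path, "rb") as f:
        return pickle.load(f)

# Illustrative usage:
pipeline = MyModel("model.keras", "preprocessing.pkl")
predictions = pipeline.predict("some_image.jpg")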
