Deactivate PyTorch Lightning 模型在预测期间的日志记录

huangapple go评论64阅读模式
英文:

Deativate Pytorch Lightning Module Model Logging During Prediction

问题

我正在尝试使用FastAPI提供PyTorch Forecasting模型。在启动时,我使用以下代码从检查点加载模型:

model = BaseModel.load_from_checkpoint(model_path)
model.eval()

尽管预测结果看起来正常,但每次进行预测后,lightning_logs文件夹中都会生成一个新版本,其中包含每次预测后存储的新文件中的超参数。我在进行预测时使用以下代码:

raw_predictions = model.predict(df, mode="raw", return_x=True)

如何在提供模型进行预测时停止记录日志?

英文:

I am trying to serve a Pytorch Forecasting model using FastAPI. I am loading the model from a checkpoint using the following code on startup:

model = BaseModel.load_from_checkpoint(model_path)

model.eval()

Although the predictions do come up fine, every time there's a new version generated in the lightining_logs folder with the hyperparameters stored in a new file after each prediction. I use the following code for the predictions:

raw_predictions = model.predict(df, mode="raw", return_x=True)

How can I stop logging when I serve the model for predictions?

答案1

得分: 1

Hi heres what i normally do

  1. 通常我会将模型保存为普通的pt文件,PyTorch Lightning与PyTorch完全兼容(当然,你需要将它从LightningModule重新设计为普通的nn.Module类)
  2. 将模型保存为ONNX模型
from model import Model
import pytorch_lightning as pl
import torch

model: pl.LightningModule = Model()
torch.save(model.state_dict(), 'weights.pt')

# 或保存为ONNX
torch.onnx.export(model, (inputs), fname)
英文:

Hi heres what i normally do

  1. Save as a normal pt file pytorch lighthning is fully compatible with pytorch (of course you have to redesign from a LightningModule to a normal nn.Module class)
  2. Save as onnx model
from model import Model
import pytorch_lightning as pl
import torch

model:pl.LightningModule = Model()
torch.save(model.state_dict(), 'weights.pt')

# Or save to onnx
torch.onnx.export(model, (inputs), fname))

答案2

得分: 0

以下是您要翻译的内容:

有人在GitHub上发布了答案,与我在进行了大量阅读后发现它的时间差不多。对我来说,这并不那么明显:

trainer_kwargs={'logger': False}

对于我的问题中的代码,预测部分将变成:

raw_predictions = model.predict(df, mode="raw", return_x=False, trainer_kwargs=dict(accelerator="cpu|gpu", logger=False))
英文:

Someone posted the answer on GitHub around the same time I discovered it after doing lots of reading. It's not that evident, at least for me:

trainer_kwargs={'logger':False}

In the case of the code in my question the prediction part would turn into:

raw_predictions = model.predict(df, mode="raw", return_x=False, trainer_kwardgs=dict(accelarator="cpu|gpu", logger=False))

huangapple
  • 本文由 发表于 2023年5月26日 09:34:16
  • 转载请务必保留本文链接:https://go.coder-hub.com/76337140.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定