`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`

Question

I'm trying to use the Temporal Fusion Transformer from the pytorch_forecasting module, but the `trainer.fit` method raises this error: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`. I'm just replicating this article from Towards Data Science. Reference: https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91#:~:text=T%20emporal%20F%20usion%20T,dynamics%20of%20multiple%20time%20sequences.

Answer 1

Score: 5

There has been an update to the pytorch-forecasting requirements, and PyTorch Lightning is no longer imported as `lightning.pytorch` but as `pytorch_lightning`.

Changing this import in pytorch-forecasting's basemodel.py solved the issue for me.
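
One way to confirm this kind of mismatch is to check which Lightning package the model class actually inherits from. The following is a minimal sketch, assuming that `LightningModule` is defined under a module path whose top-level package is either `lightning` or `pytorch_lightning`. The helper name `lightning_base_packages` is hypothetical and not part of either library.

```python
def lightning_base_packages(model_cls):
    """Return the top-level package of every LightningModule base of model_cls.

    Hypothetical diagnostic helper: it only inspects class metadata and does
    not import Lightning itself.
    """
    found = []
    for base in model_cls.__mro__:
        if base.__name__ == "LightningModule":
            # e.g. "lightning.pytorch.core.module" -> "lightning"
            found.append(base.__module__.split(".")[0])
    return found


# Example usage (requires pytorch-forecasting to be installed):
# from pytorch_forecasting import TemporalFusionTransformer
# print(lightning_base_packages(TemporalFusionTransformer))
```

If the package reported for the model differs from the package your `Trainer` was imported from, the `trainer.fit` type check is likely to fail with the error above.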

Answer 2

Score: 2

Soffies' answer is basically correct. The root cause is that classes imported from `lightning.pytorch` are incompatible with those imported from `pytorch_lightning`.
pytorch_forecasting itself uses `import lightning.pytorch` extensively.
Therefore, all you need to do is replace `pytorch_lightning` with `lightning.pytorch` in your own code.
For context, I ran into this issue after updating to torch 2.
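
The replacement described in this answer can be sketched as a small text rewrite. This is only an illustration, assuming your code refers to the package by its plain name `pytorch_lightning`. The helper `use_lightning_pytorch` and the sample imports in `before` are hypothetical and not taken from the tutorial.

```python
import re

# Match the package name only as a whole identifier, so longer names that
# merely contain it (for example my_pytorch_lightning_utils) are left alone.
_OLD_PACKAGE = re.compile(r"(?<![\w.])pytorch_lightning\b")


def use_lightning_pytorch(source: str) -> str:
    """Rewrite references to pytorch_lightning into lightning.pytorch."""
    return _OLD_PACKAGE.sub("lightning.pytorch", source)


# Sample imports as they might appear in a training script (hypothetical).
before = (
    "import pytorch_lightning as pl\n"
    "from pytorch_lightning.callbacks import EarlyStopping\n"
)
after = use_lightning_pytorch(before)
print(after)
```

After the change, the `Trainer` and callbacks come from the same package that pytorch_forecasting uses for its models, which is what the type check in `trainer.fit` expects.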

Posted by huangapple on April 13, 2023 at 15:51:59
When reposting, please keep the link to this article: https://go.coder-hub.com/76002944.html