“torch.load(model_dict.pt) Error It doesn’t load” 没有加载。

huangapple go评论81阅读模式
英文:

torch.load(model_dict.pt) Error It doesn't load

问题

error:

raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

发生了一个错误,该错误是在加载一个使用PyTorch训练的.pt文件时发生的。

英文:

code:

model.load_state_dict(torch.load(path+'model_dict.pt', map_location=device))
error:
    raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
TypeError: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForTokenClassification'>.

An error occurred while loading a pt file that was trained with PyTorch.

答案1

得分: 1

请查看这个链接 https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

似乎你没有正确保存检查点。从你收到的错误来看,我猜你这样做了

torch.save(model, some_path)

如果你尝试通过 model.load_state_dict() 加载它,你会得到这种类型的错误。

正确保存检查点的方法是这样的

torch.save(model.state_dict(), some_path)

然后可以这样加载和使用它

model_state_dict = torch.load(some_path)
model.load_state_dict(model_state_dict)

一般来说,通常不仅保存模型状态字典,还保存更多信息,比如迄今为止的训练轮数,如果有的话,还保存优化器状态字典,例如如果你在使用 ADAM 优化器。一般来说,你需要保存一切用于使用模型,甚至可能是恢复训练的信息。这就是官方文档也做的事情。这是类似于以下方式的内容

torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
            }, some_path)

# 然后你可以通过以下方式加载它
checkpoint = torch.load(some_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
英文:

Please have a look at this https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

It seems like you did not save the checkpoint correctly. From what you get as an error I suppose you did

torch.save(model, some_path)

If you try to load this via model.load_state_dict() you will get this kind of error.

How you save a checkpoint is by doing

torch.save(model.state_dict(), some_path)

This can then be loaded and uses

model_state_dict = torch.load(some_path)
model.load_state_dict(model_state_dict)

In general it often makes sense to not only save the model state dict, but more information like number of epochs so far, the optimizer state dict if it has one, e.g. if you are using ADAM. Generally everything you need to use the model and maybe also resume a training. This is what the official documentation also does. This is something along the line of

torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
            }, some_path)

# and then you can load it via
checkpoint = torch.load(some_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

huangapple
  • 本文由 发表于 2023年6月29日 12:47:26
  • 转载请务必保留本文链接:https://go.coder-hub.com/76578117.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定