torch.load(model_dict.pt) error: it doesn't load

Question
code:

    model.load_state_dict(torch.load(path + 'model_dict.pt', map_location=device))

error:

    raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
    TypeError: Expected state_dict to be dict-like, got <class 'transformers.models.bert.modeling_bert.BertForTokenClassification'>.

This error occurs while loading a .pt file that was trained with PyTorch.
Answer 1

Score: 1
Please have a look at https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

It seems like you did not save the checkpoint correctly. Judging from the error you get, I suppose you did

    torch.save(model, some_path)

If you then try to load the file via model.load_state_dict(), you will get exactly this kind of error, because the file contains a whole pickled model object rather than a state dict.

The correct way to save a checkpoint is

    torch.save(model.state_dict(), some_path)

This can then be loaded and used like this:

    model_state_dict = torch.load(some_path)
    model.load_state_dict(model_state_dict)

In general it often makes sense to save not only the model state dict but also more information, such as the number of epochs trained so far and the optimizer state dict if the optimizer has one (e.g. if you are using Adam). In short, save everything you need to use the model and, ideally, to resume training. This is also what the official documentation does. It looks something like this:
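If the existing file already contains a whole pickled model, as the error message suggests, you may not need to retrain: a possible workaround, assuming the model class is importable in your environment, is to load the full object and pull the state dict out of it. The snippet below is a sketch with a stand-in nn.Linear model (in the question it would be a transformers BertForTokenClassification):

```python
import torch
import torch.nn as nn

# Stand-in model; the real file in the question holds a
# transformers BertForTokenClassification instead.
model = nn.Linear(4, 2)

# Simulate the problematic save: the whole model object was pickled.
torch.save(model, "model_dict.pt")

# weights_only=False is needed on newer PyTorch versions to
# unpickle a full nn.Module (only do this for files you trust).
loaded = torch.load("model_dict.pt", weights_only=False)

if isinstance(loaded, dict):
    # The file really holds a state dict: load it directly.
    model.load_state_dict(loaded)
else:
    # A full model was saved: extract its state dict instead.
    model.load_state_dict(loaded.state_dict())
```

Once loaded this way, you can re-save just the state dict with torch.save(loaded.state_dict(), some_path) so future loads work with model.load_state_dict().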
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, some_path)

    # and then you can load it via
    checkpoint = torch.load(some_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
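As a self-contained illustration of that checkpoint pattern, here is a minimal sketch with a toy linear model and Adam (the file name "checkpoint.pt" and the model shape are placeholders, not from the question):

```python
import torch
import torch.nn as nn

some_path = "checkpoint.pt"  # placeholder path

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epoch = 5

# Save everything needed to resume training in one dict.
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
            }, some_path)

# Later: rebuild fresh objects, then restore their states.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = torch.load(some_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
```

Note that load_state_dict() restores parameters into an already constructed model, which is why the model and optimizer are recreated before loading.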
Comments