How to skip weights init when loading pretrained transformers model?


Question

I need to find out how to load a pretrained transformer model without initializing the weights first (to save time and memory).

  1. I saw this code example, but it is not elegant:
    import torch
    from transformers import AutoModelForCausalLM

    # no-op initializer used to temporarily replace the real ones
    def skip(tensor, *args, **kwargs):
        return tensor

    saved_inits = (torch.nn.init.kaiming_uniform_,
                   torch.nn.init.uniform_,
                   torch.nn.init.normal_)  # preserving
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=args.model_path)

    (torch.nn.init.kaiming_uniform_,
     torch.nn.init.uniform_,
     torch.nn.init.normal_) = saved_inits  # restoring
    
  2. For nn.Module subclasses there is torch.nn.utils.skip_init, but it won't work with AutoModelForCausalLM (a sketch of that API follows right after this list).
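
For reference, here is a minimal sketch of how torch.nn.utils.skip_init is normally used on a plain nn.Module (a toy Linear layer, not code from the question), i.e. the behaviour that does not carry over to from_pretrained:

    import torch

    # skip_init constructs the module on the meta device and then materializes
    # empty (uninitialized) parameters, so the usual init kernels never run
    layer = torch.nn.utils.skip_init(torch.nn.Linear, 10, 5)
    print(layer.weight.shape)  # torch.Size([5, 10]); the values are uninitialized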

Quest: find a way to skip weight initialization in AutoModelForCausalLM (or any similar transformers class), either through some standard wrapper or parameter.

Answer 1

Score: 1

The answer was suggested in the comment by cronoik:

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.model_path,
        low_cpu_mem_usage=True,
    )

I tested it with Llama 30B and saw roughly a 3x speedup in loading time, though no reduction in memory use.
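
For anyone who wants to reproduce the comparison, here is a rough timing sketch (the checkpoint path is a placeholder; absolute numbers will depend on model size and disk speed):

    import time
    from transformers import AutoModelForCausalLM

    def timed_load(path, **kwargs):
        # load the checkpoint and report wall-clock time
        start = time.perf_counter()
        model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
        print(f"{kwargs or 'default'}: {time.perf_counter() - start:.1f}s")
        return model

    timed_load("path/to/checkpoint")                          # default init path
    timed_load("path/to/checkpoint", low_cpu_mem_usage=True)  # skips the redundant init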

Posted by huangapple on 2023-05-29. Original link: https://go.coder-hub.com/76356591.html