How to save the gpt-2-simple model after training?

Question

I trained the gpt-2-simple chatbot model, but I am unable to save it. It's important for me to be able to download the trained model from Colab, because otherwise I have to download the 355M base model each time (see the code below).

I tried various methods to save the trained model (such as gpt2.saveload.save_gpt2()), but none of them worked and I am out of ideas.

My training code:

%tensorflow_version 2.x
!pip install gpt-2-simple

import gpt_2_simple as gpt2
import json

gpt2.download_gpt2(model_name="355M")

raw_data = '/content/drive/My Drive/data.json'

with open(raw_data, 'r') as f:
    df = json.load(f)

data = []

# Pair each user turn with the bot reply that follows it.
for x in df:
    for y in range(len(x['dialog'])-1):
        q = '[YOU] : ' + x['dialog'][y]['text']
        a = '[BOT] : ' + x['dialog'][y+1]['text']
        data.append(q)
        data.append(a)

with open('chatbot.txt', 'w') as f:
    for line in data:
        try:
            f.write(line)
            f.write('\n')
        except:
            pass

file_name = "/content/chatbot.txt"

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=100,
              save_every=100
              )

while True:
  ques = input("Question : ")
  inp = '[YOU] : '+ques+'\n'+'[BOT] :'
  x = gpt2.generate(sess,
                length=20,
                temperature = 0.6,
                include_prefix=False,
                prefix=inp,
                nsamples=1,
                )

Answer 1

Score: 1


The gpt-2-simple repository README.md links to an example Colab notebook, which states the following:

> Other optional-but-helpful parameters for gpt2.finetune:
> - restore_from: Set to fresh to start training from the base GPT-2, or set to latest to restart training from an existing checkpoint.
> - ...
> - run_name: subfolder within checkpoint to save the model. This is useful if you want to work with multiple models (will also need to specify run_name when loading the model)
> - overwrite: Set to True if you want to continue finetuning an existing model (w/ restore_from='latest') without creating duplicate copies.

The README.md also states that model checkpoints are stored in /checkpoint/run1 by default, and that you can pass a run_name parameter to finetune and load_gpt2 if you want to store/load multiple models in the checkpoint folder.

Putting this all together, you should be able to do the following to work from the saved model instead of re-downloading it each time:

import gpt_2_simple as gpt2

file_name = "/content/chatbot.txt"  # same training data file as in the question

sess = gpt2.start_tf_sess()

# To load the existing model "run1" from the default checkpoint dir
gpt2.load_gpt2(sess, run_name='run1')

# Or, to continue finetuning the existing "run1" model in the default checkpoint dir
gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='latest',
              run_name='run1',
              overwrite=True,
              print_every=10,
              sample_every=100,
              save_every=500
)
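
As a usage sketch (not part of the original answer): once the checkpoint has been loaded, you can generate from it directly by passing the same run_name; the sampling parameters below are only illustrative.

# Generate from the reloaded "run1" checkpoint; no re-finetuning needed.
text = gpt2.generate(sess,
                     run_name='run1',
                     prefix='[YOU] : hello\n[BOT] :',
                     length=20,
                     temperature=0.7,
                     nsamples=1,
                     return_as_list=True)[0]
print(text)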

See the source code for the load_gpt2() and finetune() functions for more specifics.
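
Since the original problem also mentions downloading the trained model from Colab: the example Colab notebook referenced above uses Google Drive helper functions that ship with gpt-2-simple. The following is a minimal sketch, assuming your installed version provides mount_gdrive, copy_checkpoint_to_gdrive and copy_checkpoint_from_gdrive (check dir(gpt2) if unsure); it only works inside Colab.

import gpt_2_simple as gpt2

# In the Colab session where training happened:
# archive checkpoint/run1 and copy it to your Google Drive.
gpt2.mount_gdrive()
gpt2.copy_checkpoint_to_gdrive(run_name='run1')

# In a later Colab session: copy the checkpoint back from Google Drive
# and load it, without re-downloading or re-finetuning the 355M base model.
gpt2.mount_gdrive()
gpt2.copy_checkpoint_from_gdrive(run_name='run1')

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')

Alternatively, you can zip the checkpoint/run1 folder and download the archive through the Colab file browser.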
