How to save the gpt-2-simple model after training?

Question

I trained the gpt-2-simple chat bot model but I am unable to save it. It's important for me to be able to download the trained model from Colab, because otherwise I have to download the 355M base model each time (see the code below).

I tried various methods to save the trained model (such as gpt2.saveload.save_gpt2()), but none of them worked and I have run out of ideas.

My training code:
%tensorflow_version 2.x
!pip install gpt-2-simple

import gpt_2_simple as gpt2
import json

gpt2.download_gpt2(model_name="355M")

raw_data = '/content/drive/My Drive/data.json'

with open(raw_data, 'r') as f:
    df = json.load(f)

data = []
for x in df:
    for y in range(len(x['dialog'])-1):
        a = '[BOT] : ' + x['dialog'][y+1]['text']
        q = '[YOU] : ' + x['dialog'][y]['text']
        data.append(q)
        data.append(a)

with open('chatbot.txt', 'w') as f:
    for line in data:
        try:
            f.write(line)
            f.write('\n')
        except:
            pass

file_name = "/content/chatbot.txt"

sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=100,
              save_every=100
              )

while True:
    ques = input("Question : ")
    inp = '[YOU] : ' + ques + '\n' + '[BOT] :'
    x = gpt2.generate(sess,
                      length=20,
                      temperature=0.6,
                      include_prefix=False,
                      prefix=inp,
                      nsamples=1,
                      )
Answer 1

Score: 1
The gpt-2-simple repository README.md links an example Colab notebook which states the following:

> Other optional-but-helpful parameters for gpt2.finetune:
>
> - restore_from: Set to fresh to start training from the base GPT-2, or set to latest to restart training from an existing checkpoint.
> - ...
> - run_name: subfolder within checkpoint to save the model. This is useful if you want to work with multiple models (will also need to specify run_name when loading the model).
> - overwrite: Set to True if you want to continue finetuning an existing model (w/ restore_from='latest') without creating duplicate copies.
The README.md also states that model checkpoints are stored in /checkpoint/run1 by default, and that one can pass a run_name parameter to finetune and load_gpt2 if you want to store/load multiple models in a checkpoint folder.
Putting this all together, you should be able to do the following to work from saved models instead of re-downloading each time:
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()

# To load existing model in default checkpoint dir from "run1"
gpt2.load_gpt2(sess)

# Or, to finetune existing model in default checkpoint dir from "run1"
gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=500,
              restore_from='latest',
              run_name='run1',
              overwrite=True,
              print_every=10,
              sample_every=100,
              save_every=500
              )
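Once the saved checkpoint has been loaded with load_gpt2, generation works the same way as right after finetuning. As a minimal sketch (the prompt format and sampling values here simply mirror the chat loop from the question and are not prescribed by the answer), note that generate also accepts a run_name argument if you keep more than one model in the checkpoint folder:

# Generate from the loaded checkpoint; pass run_name if it differs from the default 'run1'
ques = input("Question : ")
inp = '[YOU] : ' + ques + '\n' + '[BOT] :'
gpt2.generate(sess,
              run_name='run1',
              length=20,
              temperature=0.6,
              include_prefix=False,
              prefix=inp,
              nsamples=1,
              )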
See the source code for the load_gpt2() and finetune() functions for more specifics.
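Since the question is specifically about getting the finetuned model out of Colab so it survives the runtime being recycled, one option is to copy the checkpoint folder to Google Drive. Below is a minimal sketch, assuming the Google Drive helpers (mount_gdrive, copy_checkpoint_to_gdrive, copy_checkpoint_from_gdrive) used in the gpt-2-simple example Colab notebook; the run name 'run1' matches the training code above:

import gpt_2_simple as gpt2

# In the Colab session where training ran:
gpt2.mount_gdrive()                               # mount Google Drive at /content/drive
gpt2.copy_checkpoint_to_gdrive(run_name='run1')   # archive checkpoint/run1 and copy it to Drive

# In a later, fresh Colab session:
gpt2.mount_gdrive()
gpt2.copy_checkpoint_from_gdrive(run_name='run1') # restore checkpoint/run1 from Drive

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')             # loads the finetuned weights; no 355M re-download needed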