当前位置: 首页> 科技> 能源 > gpt-2语言模型训练

gpt-2语言模型训练

时间:2025/7/10 1:27:47来源:https://blog.csdn.net/m0_37570494/article/details/141433225 浏览次数:2次

一、通过下载对应的语言模型数据集 

1.1 根据你想让回答的内容,针对性下载对应的数据集,我下载的是个医疗问答数据集

1.2 针对你要用到的字段信息进行处理,然后把需要处理的数据丢给模型去训练,这个模型我是直接从GPT2的网站下载下来的依赖的必要文件截图如下:

二、具体代码样例实现:

import os
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, TextDataset, \DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForCausalLM# 读取CSV文件
data_path = '内科500.csv'  # 替换为你的CSV文件路径
df = pd.read_csv(data_path, encoding='ISO-8859-1')# 将数据集转换为适合训练的格式
def preprocess_dialogues(df):conversations = []for index, row in df.iterrows():department = row['department']title = row['title']ask = row['ask']answer = row['answer']# 将每条问答对转换为连续的对话context = f"科室: {department}\n问题: {title}\n提问: {ask}\n回答: {answer}\n"conversations.append(context)return conversationsconversations = preprocess_dialogues(df)# 保存对话数据到文本文件
train_file_path = 'train_data.txt'
with open(train_file_path, 'w', encoding='utf-8') as file:for conversation in conversations:file.write(conversation + '\n')# 加载预训练模型和tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-model')
model = GPT2LMHeadModel.from_pretrained('./gpt2-model')# 准备数据集
def load_dataset(file_path, tokenizer, block_size=128):return TextDataset(tokenizer=tokenizer,file_path=file_path,block_size=block_size)train_dataset = load_dataset(train_file_path, tokenizer)# 数据整理器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm=False
)# 训练参数
training_args = TrainingArguments(output_dir='./results',overwrite_output_dir=True,num_train_epochs=3,per_device_train_batch_size=4,save_steps=10_000,save_total_limit=2,resume_from_checkpoint=True  # 从检查点恢复训练
)# 创建Trainer
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset
)last_checkpoint = None
if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):last_checkpoint = training_args.output_dir
# 开始训练
trainer.train(resume_from_checkpoint=last_checkpoint)
关键字:gpt-2语言模型训练

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: