Update app.py
Browse files
app.py
CHANGED
@@ -80,3 +80,36 @@ def show_data():
# Show the test data through a Gradio interface.
# NOTE(review): launch() blocks the main thread by default, so any code
# appended after this call (e.g. the fine-tuning section below) may never
# execute — confirm this ordering is intentional.
demo = gr.Interface(
    fn=show_data,
    inputs=None,
    outputs="text",
    title="数据集测试",
)
demo.launch()
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Load the pretrained model and tokenizer.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# BUG FIX: GPT-2's tokenizer ships without a pad token, so calling it with
# padding="max_length" (as tokenize_function does below) raises
# "ValueError: Asking to pad but the tokenizer does not have a padding token".
# Reuse the EOS token as the pad token — the standard workaround for GPT-2.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
90 |
+
|
91 |
+
# 数据集预处理
|
92 |
+
def tokenize_function(examples):
|
93 |
+
return tokenizer(examples["text"], padding="max_length", truncation=True)
# Tokenize every split of the dataset in batches.
# NOTE(review): `dataset` is defined earlier in the file (not visible here) —
# presumably a datasets.DatasetDict with "train"/"test" splits; verify.
tokenized_datasets = dataset.map(function=tokenize_function, batched=True)
# Hyper-parameters for the fine-tuning run.
# NOTE(review): recent transformers versions renamed `evaluation_strategy`
# to `eval_strategy` — confirm against the pinned library version.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
)
# Fine-tune the model.
# BUG FIX: tokenize_function produces only input_ids/attention_mask — no
# "labels" column — so Trainer's forward pass returns no loss and training
# fails with "The model did not return a loss...". For causal LM training,
# DataCollatorForLanguageModeling(mlm=False) copies input_ids into labels.
from transformers import DataCollatorForLanguageModeling

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)

trainer.train()