yigiao commited on
Commit
3f4db92
·
verified ·
1 Parent(s): f7da1e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -80,3 +80,36 @@ def show_data():
80
  # 使用 Gradio 界面显示测试数据
81
  demo = gr.Interface(fn=show_data, inputs=None, outputs="text", title="数据集测试")
82
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # 使用 Gradio 界面显示测试数据
81
  demo = gr.Interface(fn=show_data, inputs=None, outputs="text", title="数据集测试")
82
  demo.launch()
83
+
84
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
85
+
86
+ # 加载预训练模型和分词器
87
+ model_name = "gpt2"
88
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
89
+ model = AutoModelForCausalLM.from_pretrained(model_name)
90
+
91
+ # 数据集预处理
92
+ def tokenize_function(examples):
93
+ return tokenizer(examples["text"], padding="max_length", truncation=True)
94
+
95
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
96
+
97
+ # 微调训练参数
98
+ training_args = TrainingArguments(
99
+ output_dir="./results",
100
+ evaluation_strategy="epoch",
101
+ learning_rate=5e-5,
102
+ per_device_train_batch_size=4,
103
+ num_train_epochs=3,
104
+ weight_decay=0.01,
105
+ )
106
+
107
+ # 微调
108
+ trainer = Trainer(
109
+ model=model,
110
+ args=training_args,
111
+ train_dataset=tokenized_datasets["train"],
112
+ eval_dataset=tokenized_datasets["test"],
113
+ )
114
+
115
+ trainer.train()