Testac-1002 / app.py
Sakalti's picture
Update app.py
454b875 verified
raw
history blame
2.23 kB
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
from huggingface_hub import login
# Gradioで使うための関数
def start_training(write_token, repo_name):
# Hugging Face APIトークンでログイン
login(token=write_token)
# range3/cc100-jaデータセットをロード
dataset = load_dataset("Sakalti/Multilingal-sakalt-data")
# モデルとトークナイザーをロード
model_name = "rinna/japanese-gpt-neox-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# トレーニング引数の設定
training_args = TrainingArguments(
output_dir="./results", # 結果の保存先
num_train_epochs=2, # エポック数
per_device_train_batch_size=8, # バッチサイズ
per_device_eval_batch_size=8, # 評価バッチサイズ
warmup_steps=500, # ウォームアップステップ数
weight_decay=0.01, # 重みの減衰
logging_dir="./logs", # ログディレクトリ
)
# Trainerの設定
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
# トレーニングの実行
trainer.train()
# トレーニングが完了した後にモデルをHugging Face Hubにアップロード
model.push_to_hub(repo_name)
tokenizer.push_to_hub(repo_name)
return f"トレーニングが完了しました。モデルが'{repo_name}'にアップロードされました。"
# Gradioインターフェースを設定
interface = gr.Interface(
fn=start_training,
inputs=[
gr.Textbox(label="Hugging Face Write Token"),
gr.Textbox(label="Hugging Face リポジトリ名") # リポジトリパスの入力
],
outputs="text",
title="モデル トレーニング",
description="このボタンを押すと、指定したトークンでトレーニングが開始されます。"
)
# アプリの起動
interface.launch()