Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from peft import PeftModel | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
model_name = "rinna/japanese-gpt-neox-3.6b" | |
peft_name = "minoD/GOMESS" | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
device_map="cpu", | |
) | |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
model = PeftModel.from_pretrained( | |
model, | |
peft_name, | |
device_map="cpu", | |
) | |
# プロンプトテンプレートの準備にカテゴリを追加 | |
def generate_prompt(data_point, category=None): | |
category_part = f"### カテゴリ:\n{category}\n\n" if category else "" | |
result = f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 入力:\n{data_point['input']}\n\n### 回答:\n" if data_point["input"] else f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 回答:\n" | |
result = result.replace('\n', '<NL>') | |
return result | |
def generate(instruction, input=None, category=None, maxTokens=256): | |
# 推論 | |
prompt = generate_prompt({'instruction':instruction, 'input':input}, category) | |
input_ids = tokenizer(prompt, | |
return_tensors="pt", | |
truncation=True, | |
add_special_tokens=False).input_ids | |
outputs = model.generate( | |
input_ids=input_ids, | |
max_new_tokens=maxTokens, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.75, | |
top_k=40, | |
no_repeat_ngram_size=2, | |
) | |
outputs = outputs[0].tolist() | |
# EOSトークンにヒットしたらデコード完了 | |
if tokenizer.eos_token_id in outputs: | |
eos_index = outputs.index(tokenizer.eos_token_id) | |
decoded = tokenizer.decode(outputs[:eos_index]) | |
# レスポンス内容のみ抽出 | |
sentinel = "### 回答:" | |
sentinelLoc = decoded.find(sentinel) | |
if sentinelLoc >= 0: | |
result = decoded[sentinelLoc+len(sentinel):] | |
return result.replace("<NL>", "\n") # <NL>→改行 | |
else: | |
return 'Warning: Expected prompt template to be emitted. Ignoring output.' | |
else: | |
return 'Warning: no <eos> detected ignoring output' | |
# 既存のgenerate関数を使用しますが、print文を削除し、結果を返すように変更します。 | |
import gradio as gr | |
# generate関数をGradio用に調整します。入力とカテゴリは固定されます。 | |
def generate_for_gradio(instruction): | |
return generate(instruction, category="ES2Q", maxTokens=200) | |
# Gradioインターフェースを定義します。 | |
iface = gr.Interface( | |
fn=generate_for_gradio, | |
inputs=[ | |
gr.Textbox(lines=2, placeholder="ESの回答を入力してください") | |
], | |
outputs="text", | |
title="ESから質問を生成テスト", | |
description="エントリーシートから面接官が言いそうな質問を生成します。(精度:悪)" | |
) | |
iface.launch() |