GOMESS_Bot / app.py
minoD's picture
changed sentences
9adf5ee
raw
history blame
2.89 kB
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "rinna/japanese-gpt-neox-3.6b"
peft_name = "minoD/GOMESS"
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = PeftModel.from_pretrained(
model,
peft_name,
device_map="cpu",
)
# プロンプトテンプレートの準備にカテゴリを追加
def generate_prompt(data_point, category=None):
category_part = f"### カテゴリ:\n{category}\n\n" if category else ""
result = f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 入力:\n{data_point['input']}\n\n### 回答:\n" if data_point["input"] else f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 回答:\n"
result = result.replace('\n', '<NL>')
return result
def generate(instruction, input=None, category=None, maxTokens=256):
# 推論
prompt = generate_prompt({'instruction':instruction, 'input':input}, category)
input_ids = tokenizer(prompt,
return_tensors="pt",
truncation=True,
add_special_tokens=False).input_ids
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=maxTokens,
do_sample=True,
temperature=0.7,
top_p=0.75,
top_k=40,
no_repeat_ngram_size=2,
)
outputs = outputs[0].tolist()
# EOSトークンにヒットしたらデコード完了
if tokenizer.eos_token_id in outputs:
eos_index = outputs.index(tokenizer.eos_token_id)
decoded = tokenizer.decode(outputs[:eos_index])
# レスポンス内容のみ抽出
sentinel = "### 回答:"
sentinelLoc = decoded.find(sentinel)
if sentinelLoc >= 0:
result = decoded[sentinelLoc+len(sentinel):]
return result.replace("<NL>", "\n") # <NL>→改行
else:
return 'Warning: Expected prompt template to be emitted. Ignoring output.'
else:
return 'Warning: no <eos> detected ignoring output'
# 既存のgenerate関数を使用しますが、print文を削除し、結果を返すように変更します。
import gradio as gr
# generate関数をGradio用に調整します。入力とカテゴリは固定されます。
def generate_for_gradio(instruction):
return generate(instruction, category="ES2Q", maxTokens=200)
# Gradioインターフェースを定義します。
iface = gr.Interface(
fn=generate_for_gradio,
inputs=[
gr.Textbox(lines=2, placeholder="ESの回答を入力してください")
],
outputs="text",
title="ESから質問を生成テスト",
description="エントリーシートから面接官が言いそうな質問を生成します。(精度:悪)"
)
iface.launch()