import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None

VERSION = "1.0"


@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = "cyberagent/open-calm-1b"
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False


args = DefaultArgs()
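
# Note: this Space uses the DefaultArgs values directly; an argparse-based CLI
# (hence the argparse import) could override these when running locally.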


def load_model(model_dir):
    # device_map="auto" lets accelerate place the weights on the available
    # device(s); avoid calling .to() afterwards, since moving a model loaded
    # with a device_map can raise an error.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map="auto",
        torch_dtype=torch.float32,
    )
    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name_or_path)
logging.info("Finished loading model")
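

# Custom streamer: during generation, transformers calls put() with each new
# chunk of token ids and end() once generation finishes. Decoded text is
# accumulated in generated_text so the UI loop below can poll it.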
class Streamer(BaseStreamer):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError
        # Move to CPU before converting; the ids arrive on the model's device.
        t = [int(x) for x in t.cpu().numpy()]
        text = self.tokenizer.decode(t, skip_special_tokens=True)

        # The first call receives the prompt tokens; record and skip them.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True
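

# Streaming generator wired to the submit button: generation runs in a
# background thread and partial text is yielded to the output textbox while
# tokens are produced.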
def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)
    print(log)

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

    # Cap the request so prompt + generation stays within the context window.
    max_possible_new_tokens = model.config.max_position_embeddings - len(input_ids.squeeze(0))
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    streamer = Streamer(tokenizer=tokenizer)
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        streamer=streamer,
        # max_length=4096,
        # top_k=100,
        # top_p=0.9,
        # num_return_sequences=2,
        # num_beams=2,
    ))
    thr.start()

    # Poll the streamer and yield partial text until generation has ended;
    # the threaded model.generate call above is the only generation pass.
    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text

    log.update(dict(
        generation=gen,
        version=VERSION,
        time=str(arrow.now("+09:00"))))
    logging.info(log)
    yield gen
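

# Logs a user rating together with the prompt, generation, and sampling
# parameters; it is not wired to a UI event in the layout below.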
def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    log.update(dict(
        time=str(arrow.now("+09:00")),
        version=VERSION,
    ))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)
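
# Build the Gradio Blocks UI: hyperparameter controls on the left, prompt
# input and streamed output on the right.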
with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown(f"# {args.hf_model_name_or_path.split('/')[-1]} playground")

    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("hyper parameters")

                # default params (no_repeat_ngram_size is kept hidden)
                do_sample = gr.Checkbox(True, label="Do Sample", visible=True)
                no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False)

                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                )
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.05, label="temperature",
                )
                repetition_penalty = gr.Slider(
                    1, 1.5, value=1.2, step=0.05, label="frequency penalty",
                )

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="入力",  # "Input"
                    placeholder="AIによって私達の暮らしは、",  # prompt prefix: "Through AI, our daily lives..."
                    interactive=True,
                    lines=5,
                    value="AIによって私達の暮らしは、",
                )

            with gr.Box():
                with gr.Row():
                    btn_stop = gr.Button(value="キャンセル", variant="secondary")  # "Cancel"
                    btn_submit = gr.Button(value="実行", variant="primary")  # "Run"

            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="応答",  # "Response"
                    lines=5,
                    value="",
                )

    # event handling
    inputs = [textbox_prompt] + gr_params
    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)
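
# Queue requests so streamed outputs update incrementally; concurrency_count=2
# lets two generations run at the same time (Gradio 3.x queue API).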
gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)