import gradio as gr import os import threading import arrow import time import argparse import logging from dataclasses import dataclass import torch import sentencepiece as spm from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation.streamers import BaseStreamer from huggingface_hub import hf_hub_download, login logger = logging.getLogger() logger.setLevel("INFO") gr_interface = None VERSION = "1.0" @dataclass class DefaultArgs: hf_model_name_or_path: str = "cyberagent/open-calm-1b" spm_model_path: str = None env: str = "dev" port: int = 7860 make_public: bool = False args = DefaultArgs() def load_model( model_dir, ): model = AutoModelForCausalLM.from_pretrained(args.hf_model_name_or_path, device_map="auto", torch_dtype=torch.float32) if torch.cuda.is_available(): model = model.to("cuda:0") return model logging.info("Loading model") model = load_model(args.hf_model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name_or_path) logging.info("Finished loading model") class Streamer(BaseStreamer): def __init__(self, tokenizer): self.tokenizer = tokenizer self.num_invoked = 0 self.prompt = "" self.generated_text = "" self.ended = False def put(self, t: torch.Tensor): d = t.dim() if d == 1: pass elif d == 2: t = t[0] else: raise NotImplementedError t = [int(x) for x in t.numpy()] text = self.tokenizer.decode(t, skip_special_tokens=True) if self.num_invoked == 0: self.prompt = text self.num_invoked += 1 return self.generated_text += text logging.debug(f"[streamer]: {self.generated_text}") def end(self): self.ended = True def generate( prompt, max_new_tokens, temperature, repetition_penalty, do_sample, no_repeat_ngram_size, ): log = dict(locals()) logging.debug(log) print(log) input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(model.device) max_possilbe_new_tokens = model.config.max_position_embeddings - len(input_ids.squeeze(0)) max_possilbe_new_tokens = min(max_possilbe_new_tokens, max_new_tokens) streamer = Streamer(tokenizer=tokenizer) thr = threading.Thread(target=model.generate, args=(), kwargs=dict( input_ids=input_ids, do_sample=do_sample, temperature=temperature, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, max_new_tokens=max_possilbe_new_tokens, streamer=streamer, # max_length=4096, # top_k=100, # top_p=0.9, # num_return_sequences=2, # num_beams=2, )) thr.start() gen_tokens = model.generate( input_ids=input_ids, do_sample=do_sample, temperature=temperature, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, max_new_tokens=max_possilbe_new_tokens, ) gen = tokenizer.decode(gen_tokens[0], skip_special_tokens=True) while not streamer.ended: time.sleep(0.05) yield streamer.generated_text # TODO: optimize for final few tokens gen = streamer.generated_text log.update(dict( generation=gen, version=VERSION, time=str(arrow.now("+09:00")))) logging.info(log) yield gen def process_feedback( rating, prompt, generation, max_new_tokens, temperature, repetition_penalty, do_sample, no_repeat_ngram_size, ): log = dict(locals()) log.update(dict( time=str(arrow.now("+09:00")), version=VERSION, )) logging.info(log) if gr_interface: gr_interface.close(verbose=False) with gr.Blocks() as gr_interface: with gr.Row(): gr.Markdown(f"# {args.hf_model_name_or_path.split('/')[-1]} playground") with gr.Row(): # left panel with gr.Column(scale=1): # generation params with gr.Box(): gr.Markdown("hyper parameters") # hidden default params do_sample = gr.Checkbox(True, label="Do Sample", visible=True) no_repeat_ngram_size = gr.Slider(0, 10, value=5, step=1, label="No Repeat Ngram Size", visible=False) # visible params max_new_tokens = gr.Slider( 128, min(512, model.config.max_position_embeddings), value=128, step=128, label="max tokens", ) temperature = gr.Slider( 0, 1, value=0.7, step=0.05, label="temperature", ) repetition_penalty = gr.Slider( 1, 1.5, value=1.2, step=0.05, label="frequency penalty", ) # grouping params for easier reference gr_params = [ max_new_tokens, temperature, repetition_penalty, do_sample, no_repeat_ngram_size, ] # right panel with gr.Column(scale=2): # user input block with gr.Box(): textbox_prompt = gr.Textbox( label="入力", placeholder="AIによって私達の暮らしは、", interactive=True, lines=5, value="AIによって私達の暮らしは、" ) with gr.Box(): with gr.Row(): btn_stop = gr.Button(value="キャンセル", variant="secondary") btn_submit = gr.Button(value="実行", variant="primary") # model output block with gr.Box(): textbox_generation = gr.Textbox( label="応答", lines=5, value="" ) # event handling inputs = [textbox_prompt] + gr_params click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True) btn_stop.click(None, None, None, cancels=click_event, queue=False) gr_interface.queue(max_size=32, concurrency_count=2) gr_interface.launch(server_port=args.port, share=args.make_public)