import sys
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

MODEL = "microsoft/Phi-3.5-mini-instruct"

if torch.cuda.is_available():
    device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# 4-bit NF4 quantization via bitsandbytes: weights are stored in 4-bit
# "normal float" (nf4) form and de-quantized to bfloat16 for compute, and the
# quantization constants are themselves quantized ("double quant") to save a
# little extra memory. bitsandbytes needs CUDA, so only apply it there.
quantization_config = None
if device == "cuda":
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
)


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    print(f"message: {message}")
    print(f"history: {history}")

    # ChatInterface (tuple format) passes history as [user, assistant] pairs.
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend(
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        # Stop tokens come from the model's own generation config. Hardcoding
        # IDs from another model family (e.g. Llama-3's 128001/128008/128009)
        # would never fire for Phi-3.5's vocabulary.
        streamer=streamer,
    )

    # generate() runs in a background thread so we can yield partial output as
    # the streamer produces it. model.generate already disables gradients
    # internally; a torch.no_grad() context here would not propagate to the
    # worker thread anyway.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    stream_chat,
    additional_inputs=[
        gr.Textbox(
            value=(
                "You are an ARM Assembly language decoder. You receive a line "
                "of Arm assembly and respond with a description of what the "
                "instruction does."
            ),
            label="System message",
        ),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # top_k and penalty are not exposed in the UI; stream_chat's defaults
        # (20 and 1.2) apply.
    ],
)

if __name__ == "__main__":
    demo.launch()
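
# --- Usage sketch ------------------------------------------------------------
# A minimal sketch of querying the running app programmatically with
# gradio_client. The localhost URL and the positional argument order are
# assumptions based on the ChatInterface layout above, not something the app
# itself documents.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")  # assumed local launch address
#   result = client.predict(
#       "ADD R0, R1, R2",                      # message
#       "You are an ARM Assembly language decoder. You receive a line of Arm "
#       "assembly and respond with a description of what the instruction does.",
#       0.7,                                   # temperature
#       512,                                   # max new tokens
#       0.95,                                  # top-p
#       api_name="/chat",
#   )
#   print(result)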