import os

import transformers
from peft import PeftConfig, PeftModel
from torch import cuda

# Hugging Face access token for the gated Llama-2 weights.
token = os.environ.get("HF_API_TOKEN")

base_model_id = 'meta-llama/Llama-2-7b-chat-hf'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# BitsAndBytes settings: allow fp32 CPU offload for modules that do not fit on the GPU.
bnb_config = transformers.BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=True
)

model_config = transformers.AutoConfig.from_pretrained(
    base_model_id,
    use_auth_token=token
)

# Load the base Llama-2 chat model.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    # device_map='auto',
    use_auth_token=token
)

# Attach the medical-consultation LoRA adapter on top of the base model.
config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)
model.eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model_id,
    use_auth_token=token
)


# Earlier prompt builders, kept for reference.
# def get_prompt(message: str, chat_history: list[tuple[str, str]],
#                system_prompt: str) -> str:
#     texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
#     # The first user input is _not_ stripped
#     do_strip = False
#     for user_input, response in chat_history:
#         user_input = user_input.strip() if do_strip else user_input
#         do_strip = True
#         texts.append(f'{user_input} [/INST] {response.strip()} [INST] ')
#     message = message.strip() if do_strip else message
#     texts.append(f'{message} [/INST]')
#     return ''.join(texts)

# def get_prompt(message: str, chat_history: list[tuple[str, str]],
#                system_prompt: str) -> str:
#     texts = [f'{system_prompt}\n']
#     for user_input, response in chat_history[:-1]:
#         texts.append(f'{user_input} {response}\n')
#     # Getting the user input and response from the last tuple in the chat history
#     last_user_input, last_response = chat_history[-1]
#     texts.append(f' input: {last_user_input} {last_response} {message} response: ')
#     return ''.join(texts)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'{system_prompt}\n']
    # If chat_history is not empty, process all but the last entry
    if chat_history:
        for user_input, response in chat_history[:-1]:
            texts.append(f'{user_input} {response}\n')
        # Getting the user input and response from the last tuple in the chat history
        last_user_input, last_response = chat_history[-1]
        texts.append(f' input: {last_user_input} {last_response} {message} Response: ')
    else:
        # If chat_history is empty, just add the message with 'Response:' at the end
        texts.append(f' input: {message} Response: ')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
                           system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> str:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

    # Generate tokens using the model
    output = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=max_new_tokens + inputs['input_ids'].shape[-1],
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Decode the output tokens back to a string
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Remove everything including and after "instruct: "
    output_text = output_text.split("instruct: ")[0]
    return output_text
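

# Example usage: a minimal sketch of how the helpers above fit together.
# The system prompt text, the sample message, and the __main__ guard are
# illustrative assumptions, not part of the original script.
if __name__ == "__main__":
    system_prompt = (
        "If you are a doctor, please answer the medical questions "
        "based on the patient's description."
    )
    message = "I have had a persistent cough and a mild fever for three days."
    chat_history: list[tuple[str, str]] = []  # no prior turns on the first call

    # Optional sanity check on prompt size before generating.
    print("prompt tokens:", get_input_token_length(message, chat_history, system_prompt))

    reply = run(message, chat_history, system_prompt, max_new_tokens=256)
    print(reply)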