File size: 4,264 Bytes
59e3834 5d42eda 3697a24 5d42eda 0da0787 59e3834 3697a24 59e3834 3697a24 739f5f0 02cddd1 3697a24 59e3834 3697a24 59e3834 3697a24 63643d0 3697a24 59e3834 bde65d6 5457d9a bde65d6 59e3834 5457d9a bde65d6 5457d9a bde65d6 5457d9a 59e3834 534269b 59e3834 91e30ca 59e3834 1312c32 59e3834 91e30ca 59e3834 91e30ca 59e3834 91e30ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
from threading import Thread
from typing import Iterator
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
import transformers
from torch import cuda, bfloat16
from peft import PeftModel, PeftConfig
# Hugging Face access token read from the HF_API_TOKEN environment variable
# (None when unset); passed as `use_auth_token` to every from_pretrained call below.
# Presumably required because the meta-llama repo is gated — confirm.
token = os.environ.get("HF_API_TOKEN")
# Base checkpoint that the PEFT adapter (loaded further down) is applied on top of.
base_model_id = 'meta-llama/Llama-2-7b-chat-hf'
# Target device string: the current CUDA device when CUDA is available, else CPU.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# NOTE(review): only fp32 CPU offload is enabled here; neither load_in_8bit nor
# load_in_4bit is set, so it is unclear whether any quantization actually takes
# place — confirm against the installed transformers/bitsandbytes versions.
bnb_config = transformers.BitsAndBytesConfig(
llm_int8_enable_fp32_cpu_offload = True
)
# Fetch the base model's configuration so it can be passed explicitly below.
model_config = transformers.AutoConfig.from_pretrained(
base_model_id,
use_auth_token=token
)
# Load the base causal-LM weights.
# NOTE(review): trust_remote_code=True permits executing Python code shipped in the
# model repo; it looks unnecessary for a stock Llama-2 checkpoint — consider removing.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favor of
# `token` — confirm against the pinned version before upgrading.
model = transformers.AutoModelForCausalLM.from_pretrained(
base_model_id,
trust_remote_code=True,
config=model_config,
quantization_config=bnb_config,
# device_map='auto',
use_auth_token=token
)
# Adapter configuration; `config` is not referenced anywhere else in the visible code.
config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
# Wrap the base model with the PEFT adapter weights and move it to `device`.
model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)
# Put the model in evaluation mode (e.g. dropout disabled) for inference.
model.eval()
# Tokenizer of the base checkpoint; used for both prompt encoding and decoding.
tokenizer = transformers.AutoTokenizer.from_pretrained(
base_model_id,
use_auth_token=token
)
# def get_prompt(message: str, chat_history: list[tuple[str, str]],
# system_prompt: str) -> str:
# texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
# # The first user input is _not_ stripped
# do_strip = False
# for user_input, response in chat_history:
# user_input = user_input.strip() if do_strip else user_input
# do_strip = True
# texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
# message = message.strip() if do_strip else message
# texts.append(f'{message} [/INST]')
# return ''.join(texts)
# def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
# texts = [f'{system_prompt}\n']
# for user_input, response in chat_history[:-1]:
# texts.append(f'{user_input} {response}\n')
# # Getting the user input and response from the last tuple in the chat history
# last_user_input, last_response = chat_history[-1]
# texts.append(f' input: {last_user_input} {last_response} {message} response: ')
# return ''.join(texts)
def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    """Build the plain-text prompt fed to the model for the current turn.

    Layout:
      - the system prompt followed by a newline;
      - every history pair except the most recent as ``"<user> <response>\\n"``;
      - a final ``" input: ... Response: "`` segment holding the most recent
        history pair (if any) followed by the new ``message``.

    Args:
        message: The new user message.
        chat_history: Previous (user_input, response) pairs.
        system_prompt: Text placed at the very top of the prompt.

    Returns:
        The assembled prompt string.
    """
    header = f'{system_prompt}\n'

    # No history: the prompt is just the header plus the new message.
    if not chat_history:
        return header + f' input: {message} Response: '

    # The most recent exchange is folded into the final "input:" segment together
    # with the new message; older exchanges are emitted as plain lines.
    *earlier, (prev_user, prev_reply) = chat_history
    body = ''.join(f'{user} {reply}\n' for user, reply in earlier)
    return f'{header}{body} input: {prev_user} {prev_reply} {message} Response: '
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return how many tokens the prompt for this turn occupies.

    The prompt is assembled with get_prompt() and tokenized without special
    tokens, matching the add_special_tokens=False encoding used in run().

    Args:
        message: The new user message.
        chat_history: Previous (user_input, response) pairs.
        system_prompt: Text placed at the top of the prompt.

    Returns:
        The length of the tokenized prompt.
    """
    encoded = tokenizer(
        [get_prompt(message, chat_history, system_prompt)],
        return_tensors='np',
        add_special_tokens=False,
    )
    # Batch of one sequence: the last axis of the id array is the token count.
    return encoded['input_ids'].shape[-1]
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> str:
    """Generate one sampled reply to ``message`` given the conversation so far.

    Args:
        message: The new user message.
        chat_history: Previous (user_input, response) pairs.
        system_prompt: Text placed at the top of the prompt.
        max_new_tokens: Upper bound on the number of tokens generated.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.

    Returns:
        The decoded continuation only (the prompt is not echoed back). It is
        truncated before the first ``"instruct: "`` if the model emits one.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
    prompt_len = inputs['input_ids'].shape[-1]

    output = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        # Bound only the generated part. This is the same budget the old
        # max_length=prompt_len + max_new_tokens expressed, stated directly.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # For decoder-only models generate() returns prompt + continuation. Drop the
    # prompt tokens so the caller receives only the model's reply. Otherwise the
    # system prompt and history are echoed back, and the marker split below
    # would also scan the prompt.
    generated = output[0][prompt_len:]
    output_text = tokenizer.decode(generated, skip_special_tokens=True)
    # Cut at the first "instruct: " marker. Presumably this is where the model
    # starts inventing a new turn (TODO confirm against the adapter's training format).
    return output_text.split("instruct: ")[0]
|