import os

import transformers
from peft import PeftConfig, PeftModel
from torch import cuda

# Hugging Face access token, needed to download the gated Llama-2 base weights
token = os.environ.get("HF_API_TOKEN")

base_model_id = 'meta-llama/Llama-2-7b-chat-hf'

device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Note: llm_int8_enable_fp32_cpu_offload only takes effect together with
# load_in_8bit=True; with neither load_in_8bit nor load_in_4bit set, the model
# loads without quantization. If you do enable 8-bit loading, use device_map
# instead of the explicit .to(device) call further down.
bnb_config = transformers.BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=True
)

model_config = transformers.AutoConfig.from_pretrained(
    base_model_id,
    use_auth_token=token
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    # device_map='auto',
    use_auth_token=token
)

# Attach the LoRA adapter weights to the base model
config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)

model.eval()  # inference only: disables dropout

tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model_id,
    use_auth_token=token
)


# def get_prompt(message: str, chat_history: list[tuple[str, str]],
#                system_prompt: str) -> str:
#     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
#     # The first user input is _not_ stripped
#     do_strip = False
#     for user_input, response in chat_history:
#         user_input = user_input.strip() if do_strip else user_input
#         do_strip = True
#         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
#     message = message.strip() if do_strip else message
#     texts.append(f'{message} [/INST]')
#     return ''.join(texts)



# def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
#     texts = [f'{system_prompt}\n']

#     for user_input, response in chat_history[:-1]:
#         texts.append(f'{user_input} {response}\n')

#     # Getting the user input and response from the last tuple in the chat history
#     last_user_input, last_response = chat_history[-1]
#     texts.append(f' input: {last_user_input} {last_response} {message} response: ')

#     return ''.join(texts)

def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    texts = [f'{system_prompt}\n']

    # If chat_history is not empty, process all but the last entry
    if chat_history:
        for user_input, response in chat_history[:-1]:
            texts.append(f'{user_input} {response}\n')

        # Getting the user input and response from the last tuple in the chat history
        last_user_input, last_response = chat_history[-1]
        texts.append(f' input: {last_user_input} {last_response} {message} Response: ')
    else:
        # If chat_history is empty, just add the message with 'Response:' at the end
        texts.append(f' input: {message} Response: ')

    return ''.join(texts)
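
# Illustrative examples of the prompt this builds (example strings, not from the
# original script):
#   get_prompt("I have a headache.", [], "You are a medical assistant.")
#     -> "You are a medical assistant.\n input: I have a headache. Response: "
#   With history [("Hi", "Hello, how can I help?")], the last turn is folded into
#   the same " input: ... Response: " segment:
#     -> "You are a medical assistant.\n input: Hi Hello, how can I help? I have a headache. Response: "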


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]
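
# Example gate (illustrative only; 4096 is Llama-2's context length, and the
# trimming policy is an assumption, not part of the original script):
#   while chat_history and get_input_token_length(message, chat_history, system_prompt) > 4096 - 256:
#       chat_history = chat_history[1:]  # drop the oldest turn to free up room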


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> str:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

    # Sample a completion; max_length = prompt length + max_new_tokens caps the
    # completion at max_new_tokens new tokens
    output = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=max_new_tokens + inputs['input_ids'].shape[-1],
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1
    )

    # Decode the full sequence back to a string (this includes the prompt, since
    # model.generate returns the prompt and completion token ids together)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Drop everything from the first "instruct: " onwards, so any new instruction
    # the model starts making up is not returned to the caller
    output_text = output_text.split("instruct: ")[0]

    return output_text
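

# Minimal usage sketch (not part of the original script). The system prompt wording
# below is an assumption; note that run() returns the decoded prompt together with
# the completion, so the printed text starts with the prompt itself.
if __name__ == "__main__":
    system_prompt = (
        "You are a medical assistant. Answer the patient's question "
        "based on the description they give."
    )
    chat_history: list[tuple[str, str]] = []
    reply = run(
        "I have had a mild fever and a sore throat for two days. What should I do?",
        chat_history,
        system_prompt,
        max_new_tokens=256,
    )
    print(reply)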