hdeldar committed
Commit 07155e4
1 Parent(s): a8a9773

add to git

Files changed (2)
  1. app.py +220 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ from langfuse import Langfuse
+ from langfuse.decorators import observe
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import time
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+
+ DESCRIPTION = """\
+ # Dorna-Llama3-8B-Instruct Chat
+ """
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+     <img src="https://avatars.githubusercontent.com/u/39557177?v=4" style="width: 80%; max-width: 550px; height: auto; opacity: 0.80;">
+     <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Dorna-Llama3-8B-Instruct</h1>
+ </div>
+ """
+
+ custom_css = """
+ @import url('https://fonts.googleapis.com/css2?family=Vazirmatn&display=swap');
+
+ body, .gradio-container, .gr-button, .gr-input, .gr-slider, .gr-dropdown, .gr-markdown {
+     font-family: 'Vazirmatn', sans-serif !important;
+ }
+
+ ._button {
+     font-size: 20px;
+ }
+
+ pre, code {
+     direction: ltr !important;
+     unicode-bidi: plaintext !important;
+ }
+ """
+
+
+ # Default to "" so an unset SYSTEM_PROMPT stays falsy; wrapping os.getenv in
+ # str() would turn a missing variable into the truthy string "None".
+ system_prompt = os.getenv("SYSTEM_PROMPT", "")
+
+ secret_key = os.getenv("LANGFUSE_SECRET_KEY")
+ public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+ host = os.getenv("LANGFUSE_HOST")
+
+ langfuse = Langfuse(
+     secret_key=secret_key,
+     public_key=public_key,
+     host=host
+ )
+
+
+ def execution_time_calculator(start_time, log=True):
+     delta = time.time() - start_time
+     if log:
+         print(f"--- {delta} seconds ---")
+     return delta
+
+ def token_per_second_calculator(tokens_count, time_delta):
+     return tokens_count / time_delta
+
+ if not torch.cuda.is_available():
+     # Append rather than overwrite, so the title above is kept.
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+
+ if torch.cuda.is_available():
+     model_id = "PartAI/Dorna-Llama3-8B-Instruct"
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ generation_speed = 0
+
+ def get_generation_speed():
+     return generation_speed
+
+ @observe()
+ def log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, model_outputs):
+     print(f"generation_speed: {generation_speed}")
+     return "".join(model_outputs)
+
+
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+     do_sample: bool = True,
+ ) -> Iterator[str]:
+     global generation_speed
+
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=do_sample,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+
+     start_time = time.time()
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     sum_tokens = 0
+     for text in streamer:
+         num_tokens = len(tokenizer.tokenize(text))
+         sum_tokens += num_tokens
+
+         outputs.append(text)
+         yield "".join(outputs)
+
+     time_delta = execution_time_calculator(start_time, log=False)
+
+     generation_speed = token_per_second_calculator(sum_tokens, time_delta)
+
+     # The return value is unused; the call is made for its Langfuse trace side effect.
+     log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, outputs)
+
+
+ chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, show_copy_button=True, height="68%", rtl=True)  # , elem_classes=["chatbot"])
+ chat_input = gr.Textbox(show_label=False, lines=2, rtl=True, placeholder="ورودی", show_copy_button=True, scale=4)
+ submit_btn = gr.Button(variant="primary", value="ارسال", size="sm", scale=1, elem_classes=["_button"])
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs_accordion=gr.Accordion(label="ورودی‌های اضافی", open=False),
+     additional_inputs=[
+         gr.Slider(
+             label="حداکثر تعداد توکن‌ها",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.01,
+             maximum=4.0,
+             step=0.01,
+             value=0.5,
+         ),
+         gr.Slider(
+             label="Top-p",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.01,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=20,
+         ),
+         gr.Slider(
+             label="جریمه تکرار",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+         gr.Dropdown(
+             label="نمونه‌گیری",
+             choices=[False, True],
+             value=True,
+         ),
+     ],
+     stop_btn="توقف",
+     chatbot=chatbot,
+     textbox=chat_input,
+     submit_btn=submit_btn,
+     retry_btn="🔄 تلاش مجدد",
+     undo_btn="↩️ بازگشت",
+     clear_btn="🗑️ پاک کردن",
+     title="درنا، محصول مرکز تحقیقات هوش مصنوعی پارت",
+ )
+
+
+ with gr.Blocks(css=custom_css, fill_height=False) as demo:
+     gr.Markdown(DESCRIPTION)
+     chat_interface.render()
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
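
The app reads all of its configuration from environment variables, so a local run needs them set before app.py is imported. The sketch below is illustrative only and not part of the commit: the prompt text and Langfuse keys are placeholders, and a CUDA GPU is assumed, since the model is only loaded when one is available.

# run_local.py -- a minimal, hypothetical smoke test for this Space.
# The SYSTEM_PROMPT text and the Langfuse keys below are placeholders.
import os

os.environ.setdefault("SYSTEM_PROMPT", "You are a helpful assistant.")
os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-...")
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-...")
os.environ.setdefault("LANGFUSE_HOST", "https://cloud.langfuse.com")
os.environ.setdefault("MAX_INPUT_TOKEN_LENGTH", "4096")

import app  # loads the tokenizer and model at import time (requires a GPU)

# Mirror app.py's __main__ block.
app.demo.queue(max_size=20).launch()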
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ huggingface_hub==0.22.2
+ accelerate
+ bitsandbytes
+ gradio
+ spaces
+ torch
+ transformers
+ langfuse
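
Only huggingface_hub is pinned; the remaining dependencies float to their latest releases, so a quick sanity check before launch can catch an environment mismatch early. An optional sketch, not part of the commit:

# check_env.py -- optional, illustrative pre-launch sanity check.
import huggingface_hub
import torch

# requirements.txt pins huggingface_hub; everything else is unpinned.
assert huggingface_hub.__version__ == "0.22.2", huggingface_hub.__version__
print("CUDA available:", torch.cuda.is_available())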