|
import os |
|
from threading import Thread |
|
from typing import Iterator |
|
|
|
import gradio as gr |
|
from langfuse import Langfuse |
|
from langfuse.decorators import observe |
|
import spaces |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
import time |
|
|
|
# Upper bound of the "max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 2048
# Initial position of that slider.
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are cut down to their last
# MAX_INPUT_TOKEN_LENGTH tokens before generation; override via environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
|
|
|
|
# Markdown header rendered above the chat UI (passed to gr.Markdown).
# The backslash after the opening quotes suppresses the leading newline.
DESCRIPTION = """\
# Dorna-Llama3-8B-Instruct Chat
"""

# HTML shown inside the chatbot area before any message is sent
# (passed as gr.Chatbot(placeholder=...)).
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Test Dorna-Llama3-8B-Instruct</h1>
</div>
"""
|
|
|
# Stylesheet injected through gr.Blocks(css=...):
#   - loads the Vazirmatn font from Google Fonts and applies it to the app,
#   - enlarges the text of components tagged with the "_button" class,
#   - forces <pre>/<code> to left-to-right so code stays readable inside the
#     right-to-left (rtl=True) chat widgets.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Vazirmatn&display=swap');

body, .gradio-container, .gr-button, .gr-input, .gr-slider, .gr-dropdown, .gr-markdown {
    font-family: 'Vazirmatn', sans-serif !important;
}

._button {
    font-size: 20px;
}

pre, code {
    direction: ltr !important;
    unicode-bidi: plaintext !important;
}
"""
|
|
|
|
|
# Optional system prompt taken from the environment.
# Default to "" instead of wrapping os.getenv in str(): str(None) == "None",
# which is truthy and would send the literal text "None" to the model as a
# system message whenever SYSTEM_PROMPT is unset.
system_prompt = os.getenv("SYSTEM_PROMPT", "")
|
|
|
# Langfuse credentials from the environment.
# Unset variables stay None rather than being stringified: str(None) == "None"
# would hand the client the literal string "None" as a key and as the host URL
# instead of letting it apply its own defaults.
secret_key = os.getenv("LANGFUSE_SECRET_KEY")
public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
host = os.getenv("LANGFUSE_HOST")

langfuse = Langfuse(
    secret_key=secret_key,
    public_key=public_key,
    host=host,
)
|
|
|
|
|
def execution_time_calculator(start_time, log=True):
    """Return the seconds elapsed since ``start_time``.

    ``start_time`` is expected to come from ``time.time()``.  When ``log`` is
    true, the elapsed time is also printed to stdout.
    """
    elapsed = time.time() - start_time
    if log:
        print(f"--- {elapsed} seconds ---")
    return elapsed
|
|
|
def token_per_second_calculator(tokens_count, time_delta):
    """Return throughput as ``tokens_count / time_delta`` tokens per second.

    Returns 0.0 when ``time_delta`` is not positive. A zero delta can occur
    with a coarse clock or an instantly finished generation, and a negative
    one when the wall clock steps backwards. Previously these cases raised
    ZeroDivisionError or reported a negative speed.
    """
    if time_delta <= 0:
        return 0.0
    return tokens_count / time_delta
|
|
|
# Load the model only when a GPU is present. Without CUDA, `model` and
# `tokenizer` stay undefined and a warning is appended to the page header.
if torch.cuda.is_available():
    model_id = "PartAI/Dorna-Llama3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    # Append (the text starts with "\n") rather than assign, so the title
    # stays visible and the CPU warning appears beneath it.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
|
|
|
# Tokens per second measured for the most recently completed generation;
# reassigned by generate() through its `global generation_speed` declaration.
generation_speed = 0


def get_generation_speed():
    """Return the throughput recorded for the most recent generation."""
    # A plain read of a module global needs no `global` statement.
    return generation_speed
|
|
|
@observe()
def log_to_langfuse(
    message,
    chat_history,
    max_new_tokens,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    do_sample,
    generation_speed,
    model_outputs,
):
    """Send one chat turn to Langfuse and return the full reply text.

    The body only prints the measured speed and joins the streamed chunks.
    The generation settings are accepted so the ``@observe()`` decorator can
    record them. Presumably it captures the call's arguments and return value
    as the observation's input/output; confirm against the Langfuse SDK
    version in use.
    """
    print(f"generation_speed: {generation_speed}")
    full_reply = "".join(model_outputs)
    return full_reply
|
|
|
|
|
def _build_conversation(message, chat_history):
    """Convert Gradio (user, assistant) history plus the new message into chat-template messages.

    The module-level ``system_prompt`` is prepended as a system message when
    it is non-empty.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})
    return conversation


def _stop_token_ids():
    """Return token ids that end generation: the tokenizer's EOS plus ``<|eot_id|>`` if it exists.

    Llama-3 instruct models close each turn with ``<|eot_id|>``. Unknown
    tokens resolve to None or the unk id and are filtered out.
    """
    ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in ids:
        ids.append(eot_id)
    return [token_id for token_id in ids if token_id is not None]


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior ``chat_history``.

    Yields the accumulated reply text after each streamed chunk. Once the
    stream ends, it stores the measured tokens/second in the module-level
    ``generation_speed`` and reports the turn through ``log_to_langfuse``.
    """
    global generation_speed

    conversation = _build_conversation(message, chat_history)

    # add_generation_prompt=True appends the assistant header so the model
    # answers as the assistant instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        eos_token_id=_stop_token_ids(),
    )

    start_time = time.time()
    # model.generate blocks until done, so run it in a worker thread and
    # consume decoded text from the streamer on this one.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    sum_tokens = 0
    for text in streamer:
        # Re-tokenizing each decoded chunk approximates the generated token count.
        sum_tokens += len(tokenizer.tokenize(text))
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted once generation has finished, so this returns promptly.
    worker.join()

    time_delta = execution_time_calculator(start_time, log=False)
    generation_speed = token_per_second_calculator(sum_tokens, time_delta)

    log_to_langfuse(message, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample, generation_speed, outputs)
|
|
|
|
|
|
|
|
|
|
|
# Chat display, input box and send button are created up front and handed to
# gr.ChatInterface (chatbot=, textbox=, submit_btn=); rtl=True lays the
# widgets out right-to-left for Persian text.
# NOTE(review): height="5%" looks very small for a chat window — confirm intended.
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, show_copy_button=True, height="5%", rtl=True)
# Persian placeholder "ورودی" means "input".
chat_input = gr.Textbox(show_label=False, lines=2, rtl=True, placeholder="ورودی", show_copy_button=True, scale=4)
# "ارسال" means "send"; the "_button" class is enlarged by custom_css.
submit_btn = gr.Button(variant="primary", value="ارسال", size="sm", scale=1, elem_classes=["_button"])
|
|
|
|
|
# Collapsed panel that holds the generation controls ("additional inputs").
settings_accordion = gr.Accordion(label="ورودیهای اضافی", open=False)

# ChatInterface passes these values to generate() after (message, chat_history),
# in order: max_new_tokens, temperature, top_p, top_k, repetition_penalty, do_sample.
generation_controls = [
    gr.Slider(label="حداکثر تعداد توکن ها", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.01, maximum=4.0, step=0.01, value=0.5),
    gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.01, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=20),
    gr.Slider(label="جریمه تکرار", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    gr.Dropdown(label="نمونهگیری", choices=[False, True], value=True),
]

# Chat UI wired to the streaming generate() function, with Persian button labels.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs_accordion=settings_accordion,
    additional_inputs=generation_controls,
    stop_btn="توقف",
    chatbot=chatbot,
    textbox=chat_input,
    submit_btn=submit_btn,
    retry_btn="🔄 تلاش مجدد",
    undo_btn="↩️ بازگشت",
    clear_btn="🗑️ پاک کردن",
    title="تست llama3",
)
|
|
|
|
|
# Page layout: the description header, then the chat interface (built
# outside this context above) rendered into it.
with gr.Blocks(css=custom_css, fill_height=False) as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()


if __name__ == "__main__":
    # Cap the pending-request queue at 20, then start the server.
    demo.queue(max_size=20)
    demo.launch()
|
|