import subprocess
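# Install flash-attn at runtime (a common Hugging Face Spaces workaround);
# FLASH_ATTENTION_SKIP_CUDA_BUILD tells the package's setup script to skip compiling the CUDA kernels locally.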
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import spaces
# Load model directly
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Yehia-7B-preview", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "Navid-AI/Yehia-7B-preview",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    token=os.getenv("HF_TOKEN"),
).to(device)
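# Streams decoded tokens as they are produced; the prompt and special tokens are skipped.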
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
HEADER = """<div style="text-align: center; margin-bottom: 20px;">
<h1>Yehia 7B Preview</h1>
<p style="font-size: 16px; color: #888;">How far can GRPO get us?</p>
</div>
"""
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    repetition_penalty,
):
    # Rebuild the full conversation: system prompt first, then alternating user/assistant turns.
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0].strip()})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1].strip()})
    messages.append({"role": "user", "content": message})
    print(messages)  # log the assembled conversation for debugging

    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, repetition_penalty=repetition_penalty)

    # Run generation in a background thread so partial output can be streamed to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate streamed chunks and yield the growing response for live display.
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
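
# Build the chat UI; the extra inputs below map one-to-one onto respond()'s additional parameters.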
chat_interface = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(text_align="right", rtl=True, submit_btn=True, stop_btn=True),
    additional_inputs=[
        # System prompt (Arabic): "You are Yehia, an artificial intelligence developed by the company 'Navid',
        # specialized in logical thinking and precise analysis. Your mission is to inspire users and support them
        # on their journey toward learning, growing, and achieving their goals by offering smart, well-considered solutions."
        gr.Textbox(value="ุฃู†ุช ูŠุญูŠู‰ุŒ ุฐูƒุงุกูŒ ุงุตุทู†ุงุนูŠูŒู‘ ุทูˆุฑุชู‡ ุดุฑูƒุฉ 'ู†ููŠุฏ'ุŒ ู…ุชุฎุตุตูŒ ููŠ ุงู„ุชููƒูŠุฑ ุงู„ู…ู†ุทู‚ูŠ ูˆุงู„ุชุญู„ูŠู„ ุงู„ุฏู‚ูŠู‚. ู…ู‡ู…ุชูƒ ุฅู„ู‡ุงู… ุงู„ู…ุณุชุฎุฏู…ูŠู† ูˆุฏุนู…ู‡ู… ููŠ ุฑุญู„ุชู‡ู… ู†ุญูˆ ุงู„ุชุนู„ู‘ู…ุŒ ุงู„ู†ู…ูˆุŒ ูˆุชุญู‚ูŠู‚ ุฃู‡ุฏุงูู‡ู… ู…ู† ุฎู„ุงู„ ุชู‚ุฏูŠู… ุญู„ูˆู„ู ุฐูƒูŠุฉู ูˆู…ุฏุฑูˆุณุฉ.", label="System message"),
        gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    # Example prompt (Arabic): "What is the capital of Palestine?"
    examples=[["ู…ุง ู‡ู‰ ุนุงุตู…ุฉ ูู„ุณุทูŠู† ุŸ"]],
    example_icons=[["๐Ÿ’ก"]],
    cache_examples=False,
    theme="JohnSmith9982/small_and_pretty",
)
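
# Place the header above the chat interface inside a Blocks layout.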
with gr.Blocks(fill_height=True) as demo:
    gr.HTML(HEADER)
    chat_interface.render()
if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)