|
import gradio as gr |
|
import spaces |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
import torch |
|
from threading import Thread |
|
|
|
# Hugging Face Hub repository ids of the two selectable chat models.
phi4_model_path = "microsoft/phi-4"
phi4_mini_model_path = "microsoft/Phi-4-mini-instruct"

# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


def _load_model_and_tokenizer(model_path):
    """Load a causal LM (moved to ``device``) and its tokenizer from *model_path*.

    ``torch_dtype="auto"`` lets transformers choose the weight dtype instead of
    forcing float32.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer


# Both models are loaded once, at import time, and shared by every request.
phi4_model, phi4_tokenizer = _load_model_and_tokenizer(phi4_model_path)
phi4_mini_model, phi4_mini_tokenizer = _load_model_and_tokenizer(phi4_mini_model_path)
|
|
|
# Control tokens of the Phi-4 and Phi-4-mini chat templates used below. They are
# scrubbed from streamed text so template markup never appears in the chat window.
_SPECIAL_TOKENS = (
    "<|im_start|>",
    "<|im_sep|>",
    "<|im_end|>",
    "<|end|>",
    "<|system|>",
    "<|user|>",
    "<|assistant|>",
)

# System prompt prepended to every conversation, for both models.
_SYSTEM_MESSAGE = "You are a friendly and knowledgeable assistant, here to help with any questions or tasks."


def _select_model(model_name):
    """Return the ``(model, tokenizer)`` pair for a model-dropdown label.

    Raises:
        ValueError: if *model_name* is not "Phi-4" or "Phi-4-mini-instruct".
    """
    if model_name == "Phi-4":
        return phi4_model, phi4_tokenizer
    if model_name == "Phi-4-mini-instruct":
        return phi4_mini_model, phi4_mini_tokenizer
    raise ValueError(f"Unknown model: {model_name!r}")


def _build_prompt(model_name, history, user_message, system_message=_SYSTEM_MESSAGE):
    """Render system prompt, prior turns and the new user message as raw prompt text.

    "Phi-4" uses ``<|im_start|>role<|im_sep|>content<|im_end|>`` turns; any other
    name (only "Phi-4-mini-instruct" gets this far) uses ``<|role|>content<|end|>``.
    Assistant turns with empty content are skipped. The result ends with an open
    assistant turn for the model to complete.
    """
    if model_name == "Phi-4":
        def turn(role, content):
            return f"<|im_start|>{role}<|im_sep|>{content}<|im_end|>"

        assistant_opener = "<|im_start|>assistant<|im_sep|>"
    else:
        def turn(role, content):
            return f"<|{role}|>{content}<|end|>"

        assistant_opener = "<|assistant|>"

    parts = [turn("system", system_message)]
    for message in history:
        if message["role"] == "user":
            parts.append(turn("user", message["content"]))
        elif message["role"] == "assistant" and message["content"]:
            parts.append(turn("assistant", message["content"]))
    parts.append(turn("user", user_message))
    parts.append(assistant_opener)
    return "".join(parts)


def _strip_special_tokens(text):
    """Remove every chat-template control token in ``_SPECIAL_TOKENS`` from *text*."""
    for token in _SPECIAL_TOKENS:
        text = text.replace(token, "")
    return text


@spaces.GPU(duration=60)
def generate_response(user_message, model_name, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream a reply to *user_message*, yielding the growing chat history.

    Args:
        user_message: Text typed by the user; blank or missing input is a no-op.
        model_name: Dropdown label, "Phi-4" or "Phi-4-mini-instruct".
        max_tokens: Upper bound on newly generated tokens (cast to int).
        temperature, top_k, top_p: Sampling settings; only forwarded when sampling.
        repetition_penalty: Forwarded to ``model.generate`` as-is.
        history_state: List of ``{"role": ..., "content": ...}`` dicts from earlier
            turns (``None`` is treated as empty). It is not mutated.

    Yields:
        ``(history, history)`` for the Chatbot and State outputs. The list is a new
        list: *history_state* plus the user turn and an assistant turn whose content
        grows as tokens stream in.

    Raises:
        ValueError: if *model_name* is not a supported label.
        Exception: anything raised by ``model.generate`` is re-raised here after the
            partial reply has been streamed.
    """
    history_state = history_state or []
    if not user_message or not user_message.strip():
        # Nothing to send: re-emit the unchanged history and stop.
        yield history_state, history_state
        return

    model, tokenizer = _select_model(model_name)
    prompt = _build_prompt(model_name, history_state, user_message)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Greedy decoding only when temperature, top-k and top-p all sit exactly at
    # these settings; any other combination samples.
    do_sample = not (temperature == 1.0 and top_k >= 100 and top_p == 1.0)

    # skip_special_tokens is forwarded to tokenizer.decode; the manual scrub below
    # stays as a fallback for template markup that is not registered as special.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    if do_sample:
        # Sampling knobs only matter when do_sample=True; passing them to greedy
        # decoding just triggers transformers' generation-config warnings.
        generation_kwargs.update(temperature=temperature, top_k=int(top_k), top_p=top_p)

    errors = []

    def _generate():
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the generator after streaming stops
            errors.append(exc)
            # Push the stop signal; otherwise the consumer loop waits forever.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    assistant_response = ""
    for new_token in streamer:
        assistant_response += _strip_special_tokens(new_token)
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    thread.join()
    if errors:
        raise errors[0]

    # Guarantees the user turn is shown even if the model streamed no text.
    yield new_history, new_history
|
|
|
# Quick-start prompts offered as buttons under the chat box: button label -> text
# placed into the message box when that button is clicked.
example_messages = dict(
    [
        ("Learn about physics", "Explain Newton’s laws of motion."),
        ("Discover space facts", "What are some interesting facts about black holes?"),
        ("Write a factorial function", "Write a Python function to calculate the factorial of a number."),
    ]
)
|
|
|
def _example_filler(text):
    """Return a zero-argument Gradio callback that puts *text* into the message box."""
    return lambda: gr.update(value=text)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Phi-4 Models Chatbot

        Welcome to the Phi-4 Chatbot! You can chat with Microsoft's Phi-4 or Phi-4-mini-instruct models. Adjust the settings on the left to customize the model's responses.
        """
    )

    # Conversation as a list of {"role", "content"} dicts, shared across events.
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_dropdown = gr.Dropdown(
                choices=["Phi-4", "Phi-4-mini-instruct"],
                label="Select Model",
                value="Phi-4",
            )
            # The slider's step grid starts at `minimum`; step=64 puts both the
            # default (512) and the maximum (4096) on it, which step=50 did not.
            max_tokens_slider = gr.Slider(
                minimum=64,
                maximum=4096,
                step=64,
                value=512,
                label="Max Tokens",
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    label="Temperature",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top-k",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    label="Top-p",
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Repetition Penalty",
                )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    scale=3,
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                # One button per example, labelled by its dict key, so labels and
                # prompts cannot drift apart.
                example_buttons = [
                    (gr.Button(label), text) for label, text in example_messages.items()
                ]

    generation_inputs = [
        user_input,
        model_dropdown,
        max_tokens_slider,
        temperature_slider,
        top_k_slider,
        top_p_slider,
        repetition_penalty_slider,
        history_state,
    ]
    generation_outputs = [chatbot, history_state]

    # Stream the reply, then clear the message box once generation has finished.
    submit_button.click(
        fn=generate_response,
        inputs=generation_inputs,
        outputs=generation_outputs,
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state],
    )

    for example_button, example_text in example_buttons:
        example_button.click(
            fn=_example_filler(example_text),
            inputs=None,
            outputs=user_input,
        )

    # Enter in the textbox behaves like "Send". Registered last so the indices and
    # API names of the events above stay exactly as before.
    user_input.submit(
        fn=generate_response,
        inputs=generation_inputs,
        outputs=generation_outputs,
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input,
    )

demo.launch(ssr_mode=False)