atelagim's picture
fix UI scaling
7c5b12d
import os
import gradio as gr
from typing import Iterator
from dialog import get_dialog_box
from gateway import check_server_health, request_generation
# --- Constants ---------------------------------------------------------------
# Upper bound for the "Max New Tokens" slider of the chat UI.
MAX_NEW_TOKENS: int = 2048

# --- Environment -------------------------------------------------------------
# Endpoint of the cloud gateway, handed to the gateway helpers as
# ``cloud_gateway_api``. ``None`` when API_ENDPOINT is not set.
CLOUD_GATEWAY_API = os.environ.get("API_ENDPOINT")
def toggle_ui():
    """Poll the backend and choose which top-level panel should be shown.

    Returns:
        A pair of ``gr.update`` objects, ordered (main UI, dialog box). The
        main UI is visible when the server health check is truthy; otherwise
        the dialog box is shown instead.
    """
    is_healthy = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))
    # Exactly one of the two panels is visible at any time.
    return gr.update(visible=is_healthy), gr.update(visible=not is_healthy)
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming response and emit it to the UI.

    Args:
        message (str): input message from the user.
        chat_history (list): chat history of the session as supplied by
            ``gr.ChatInterface``. NOTE(review): this argument is currently NOT
            forwarded to ``request_generation``, so each reply is generated from
            ``message`` and ``system_prompt`` only -- confirm whether the gateway
            can accept prior turns.
        system_prompt (str): system prompt.
        max_new_tokens (int, optional): maximum number of tokens to generate,
            ignoring the number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next-token
            probabilities. Defaults to 0.6.
        top_p (float, optional): if set to a float < 1, only the smallest set of
            most probable tokens whose probabilities add up to ``top_p`` or higher
            is kept for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest-probability vocabulary tokens
            to keep for top-k filtering. Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition
            penalty; 1.0 means no penalty. Defaults to 1.2.

    Yields:
        str: the cumulative response text received so far (not just the newest
        chunk), so the caller can redraw the whole message on every update.
    """
    chunks = []
    for chunk in request_generation(
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
    ):
        chunks.append(chunk)
        # Emit the full text so far; the chat window replaces its content on
        # each yield rather than appending to it.
        yield "".join(chunks)
# Chat widget. ``additional_inputs`` are passed to ``generate`` positionally
# after ``message`` and ``chat_history``, so their order must mirror its
# signature: system_prompt, max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
# NOTE(review): the slider defaults for temperature (0.1) and top-p (0.95)
# differ from generate()'s keyword defaults (0.6 and 0.9); the slider values
# are the ones actually sent from the UI.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        # -> system_prompt
        gr.Textbox(lines=6, label="System prompt"),
        # -> max_new_tokens
        gr.Slider(
            value=1024,
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            label="Max New Tokens",
        ),
        # -> temperature
        gr.Slider(
            value=0.1,
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            label="Temperature",
        ),
        # -> top_p
        gr.Slider(
            value=0.95,
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # -> top_k
        gr.Slider(
            value=50,
            minimum=1,
            maximum=1000,
            step=1,
            label="Top-k",
        ),
        # -> repetition_penalty
        gr.Slider(
            value=1.2,
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            label="Repetition penalty",
        ),
    ],
    stop_btn=None,
    # Each example is a one-element row: just the user message.
    examples=[
        [prompt]
        for prompt in (
            "Hello there! How are you doing?",
            "Can you explain briefly to me what is the Python programming language?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'.",
        )
    ],
    cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Probe the server once while building the layout so the first render
    # already shows the right panel; the timer below keeps it in sync later.
    # Keyword argument used for consistency with toggle_ui().
    visibility = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API)

    # Container for the main interface, visible while the server is healthy.
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown("""
# Llama-3 8B Chat
This Space is an Alpha release that demonstrates [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model running on AMD MI210 infrastructure. The space is built with Meta Llama 3 [License](https://www.llama.com/llama3/license/). Feel free to play with it!
""")
        chat_interface.render()

    # Dialog box shown instead of the main UI while the server is unhealthy.
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message (content comes from dialog.get_dialog_box).
        get_dialog_box()

    # Timer to check server health every 10 seconds (gr.Timer's value is the
    # tick interval in seconds) and swap the visible panel accordingly.
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])
def _queue_settings(env=None):
    """Build keyword arguments for ``demo.queue`` from environment variables.

    ``QUEUE`` maps to ``max_size`` and ``CONCURRENCY_LIMIT`` maps to
    ``default_concurrency_limit``; each is parsed with ``int()``.

    Args:
        env: mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        dict: keyword arguments for ``demo.queue(**settings)``. A variable that
        is unset is omitted, so Gradio applies its own default for it.

    Raises:
        ValueError: if a variable is set but is not a valid integer.
    """
    env = os.environ if env is None else env
    settings = {}
    for var_name, kwarg in (
        ("QUEUE", "max_size"),
        ("CONCURRENCY_LIMIT", "default_concurrency_limit"),
    ):
        raw_value = env.get(var_name)
        # Omit unset variables instead of calling int(None) (TypeError) or
        # passing None, which Gradio does not treat as "use the default".
        if raw_value is not None:
            settings[kwarg] = int(raw_value)
    return settings


if __name__ == "__main__":
    demo.queue(**_queue_settings()).launch()