Spaces:
Running
Running
import os | |
import gradio as gr | |
from typing import Iterator | |
from dialog import get_dialog_box | |
from gateway import check_server_health, request_generation | |
# CONSTANTS
# Upper bound offered by the "Max New Tokens" slider in the chat UI.
MAX_NEW_TOKENS: int = 2048

# GET ENVIRONMENT VARIABLES
# Gateway endpoint handed to the health check and to generation requests;
# None when the API_ENDPOINT environment variable is not set.
CLOUD_GATEWAY_API = os.environ.get("API_ENDPOINT")
def toggle_ui():
    """
    Decide which top-level container should be visible after a health probe.

    The cloud gateway is probed with ``check_server_health``. A truthy result
    reveals the main UI and hides the dialog; a falsy result does the opposite.

    Returns:
        tuple: ``(main_ui_update, dialog_update)``, two ``gr.update`` objects in
        the same order as the timer's ``outputs=[main_ui, dialog_box]``.
    """
    server_is_up = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))
    main_ui_update = gr.update(visible=server_is_up)
    dialog_update = gr.update(visible=not server_is_up)
    return main_ui_update, dialog_update
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model reply for ``message`` from the backend to the chat UI.

    Each chunk produced by ``request_generation`` is appended to the reply so
    far. After every chunk the whole accumulated reply is yielded, so the UI
    always shows the complete text generated up to that point.

    Args:
        message (str): the user's latest input.
        chat_history (list): prior turns of the session as supplied by the chat
            interface. NOTE(review): it is accepted but not forwarded to
            ``request_generation``, so earlier turns never reach the backend.
        system_prompt (str): system prompt forwarded to the backend.
        max_new_tokens (int, optional): maximum number of tokens to generate,
            ignoring the number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): value used to modulate the next-token
            probabilities. Defaults to 0.6.
        top_p (float, optional): if set to a float < 1, only the smallest set of
            most probable tokens whose probabilities add up to top_p or higher
            are kept for generation. Defaults to 0.9.
        top_k (int, optional): number of highest-probability vocabulary tokens
            kept for top-k filtering. Defaults to 50.
        repetition_penalty (float, optional): repetition penalty parameter;
            1.0 means no penalty. Defaults to 1.2.

    Yields:
        str: the reply accumulated so far.
    """
    request_kwargs = dict(
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
    )
    received_chunks = []
    for chunk in request_generation(**request_kwargs):
        received_chunks.append(chunk)
        # Emit the full text so far rather than the delta; the chat UI replaces
        # the displayed message with each yielded value.
        yield "".join(received_chunks)
# Chat UI wired to `generate`. Gradio hands the additional inputs to `generate`
# after the message and history, in list order, so this list must follow its
# signature: system_prompt, max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(lines=6, label="System prompt"),
        gr.Slider(
            value=1024,
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            label="Max New Tokens",
        ),
        gr.Slider(
            value=0.1,
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            value=0.95,
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            value=50,
            minimum=1,
            maximum=1000,
            step=1,
            label="Top-k",
        ),
        gr.Slider(
            value=1.2,
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            label="Repetition penalty",
        ),
    ],
    stop_btn=None,
    # Each example is a one-element row holding only the message text.
    examples=[
        [prompt]
        for prompt in (
            "Hello there! How are you doing?",
            "Can you explain briefly to me what is the Python programming language?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'.",
        )
    ],
    cache_examples=False,
)
with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Probe the server once while building the page so the initial layout
    # matches its current state; the timer below keeps it in sync afterwards.
    # bool() mirrors the True/False visibility values that toggle_ui produces.
    visibility = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))

    # Main interface: visible only while the server reports healthy.
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown("""
# Llama-3 8B Chat
This Space is an Alpha release that demonstrates [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model running on AMD MI210 infrastructure. The space is built with Meta Llama 3 [License](https://www.llama.com/llama3/license/). Feel free to play with it!
""")
        chat_interface.render()

    # Dialog shown while the server is unavailable. get_dialog_box() renders
    # its content (presumably a spinner and message; defined in dialog.py).
    with gr.Row(visible=not visibility, elem_id="dialog_box") as dialog_box:
        get_dialog_box()

    # Re-check server health every 10 seconds and swap which container is shown.
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])
def _queue_kwargs_from_env() -> dict:
    """Build keyword arguments for ``demo.queue`` from environment variables.

    ``QUEUE`` maps to ``max_size`` and ``CONCURRENCY_LIMIT`` maps to
    ``default_concurrency_limit``. A variable that is unset or blank is left out,
    so ``demo.queue`` falls back to its own default for that parameter. The
    previous ``int(os.getenv(...))`` form crashed with a TypeError in that case.

    Returns:
        dict: keyword arguments to pass to ``demo.queue``.

    Raises:
        ValueError: if a variable is set to a value that is not an integer.
    """
    env_to_kwarg = (
        ("QUEUE", "max_size"),
        ("CONCURRENCY_LIMIT", "default_concurrency_limit"),
    )
    queue_kwargs = {}
    for env_name, kwarg_name in env_to_kwarg:
        raw_value = os.getenv(env_name)
        if raw_value is None or not raw_value.strip():
            continue
        queue_kwargs[kwarg_name] = int(raw_value)
    return queue_kwargs


if __name__ == "__main__":
    demo.queue(**_queue_kwargs_from_env()).launch()