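"""Gradio inference UI for the Atlas-Flash and Atlas-Pro preview models."""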
import gc
import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Hugging Face token and authenticate if one is provided
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🏆 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🏆",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}
# Initialize model storage: holds the currently loaded tokenizer, model, and
# config (only one model is kept in memory at a time)
current_model = {"tokenizer": None, "model": None, "config": None}

# Clear memory
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
# Load model
def load_model(model_key, model_size):
    try:
        clear_memory()
        # Drop references to any previously loaded model so its memory can be freed
        if current_model["model"] is not None:
            current_model["model"] = None
            current_model["tokenizer"] = None
            clear_memory()
        model_path = MODELS[model_key]["sizes"][model_size]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Some causal-LM tokenizers ship without a pad token; fall back to EOS
        # so that padding in respond() works
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        current_model.update({
            "tokenizer": tokenizer,
            "model": model,
            "config": {
                "name": f"{MODELS[model_key]['name']} {model_size}",
                "path": model_path,
                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
            },
        })
        return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Respond to input
def respond(prompt, max_tokens, temperature, top_p, top_k):
    if current_model["model"] is None or current_model["tokenizer"] is None:
        return "⚠️ Please select and load a model first"
    try:
        system_prompt = current_model["config"]["system_prompt"]
        if not system_prompt:
            return "⚠️ System prompt not found for the selected model."
        # Alpaca-style instruction template
        full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"
        inputs = current_model["tokenizer"](
            full_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        )
        # Move inputs to the model's device (device_map="auto" may place it on GPU)
        inputs = inputs.to(current_model["model"].device)
        with torch.no_grad():
            output = current_model["model"].generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=current_model["tokenizer"].pad_token_id,
                eos_token_id=current_model["tokenizer"].eos_token_id,
            )
        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = output[0][inputs.input_ids.shape[-1]:]
        response = current_model["tokenizer"].decode(new_tokens, skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        return f"⚠️ Generation Error: {str(e)}"
    finally:
        clear_memory()
# UI for Gradio
def gradio_ui():
    with gr.Blocks() as app:
        gr.Markdown("## 🦁 Atlas Inference Platform - Experimental 🧪")
        with gr.Row():
            model_key_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model Variant",
                interactive=True,
            )
            model_size_dropdown = gr.Dropdown(
                choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
                value="1.5B",
                label="Select Model Size",
                interactive=True,
            )
        load_button = gr.Button("Load Model")
        load_status = gr.Textbox(label="Model Load Status", interactive=False)
        load_button.click(
            load_model,
            inputs=[model_key_dropdown, model_size_dropdown],
            outputs=load_status,
        )
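        # Added sketch: keep the size choices in sync with the selected variant
        # (both current variants only ship "1.5B", so this is future-proofing).
        # Assumes Gradio 4.x, where a handler may return a component instance as
        # an update; on Gradio 3.x use gr.Dropdown.update(...) instead.
        def update_sizes(model_key):
            sizes = list(MODELS[model_key]["sizes"].keys())
            return gr.Dropdown(choices=sizes, value=sizes[0])

        model_key_dropdown.change(
            update_sizes,
            inputs=model_key_dropdown,
            outputs=model_size_dropdown,
        )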
        with gr.Row():
            prompt_input = gr.Textbox(label="Input Prompt", lines=4)
            max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
            temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
            top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")
        generate_button = gr.Button("Generate Response")
        response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)
        generate_button.click(
            respond,
            inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
            outputs=response_output,
        )
    return app
if __name__ == "__main__":
    gradio_ui().launch()
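# Usage note: on Hugging Face Spaces the default launch() settings are used;
# when running locally, launch(share=True) provides a temporary public link.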