import gc
import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Load the Hugging Face token from the environment and authenticate only if one is set
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🌟 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🌟",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}
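
# To expose another checkpoint, add an entry with the same shape as above.
# The entry below is purely illustrative (the key, repo id, and env var name
# are placeholders, not real resources):
# "my-model-0101": {
#     "name": "My Model 0101",
#     "sizes": {"1.5B": "your-org/your-model-1.5B"},
#     "emoji": "🤖",
#     "experimental": True,
#     "is_vision": False,
#     "system_prompt_env": "MY_MODEL_0101",
# },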
# Clear memory
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
# Load model
def load_model(model_key, model_size):
    try:
        clear_memory()
        # Unload the previous model, if any, before loading a new one
        global current_model
        if current_model["model"] is not None:
            current_model["model"] = None
            current_model["tokenizer"] = None
            clear_memory()
        model_path = MODELS[model_key]["sizes"][model_size]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Causal LMs often ship without a pad token; fall back to EOS so padding works
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        current_model.update({
            "tokenizer": tokenizer,
            "model": model,
            "config": {
                "name": f"{MODELS[model_key]['name']} {model_size}",
                "path": model_path,
                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
            }
        })
        return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Respond to input
def respond(prompt, max_tokens, temperature, top_p, top_k):
    if not current_model["model"] or not current_model["tokenizer"]:
        return "⚠️ Please select and load a model first"
    try:
        system_prompt = current_model["config"]["system_prompt"]
        if not system_prompt:
            return "⚠️ System prompt not found for the selected model."
        full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"
        inputs = current_model["tokenizer"](
            full_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        )
        # Move inputs to the model's device (device_map="auto" may place it on GPU)
        device = current_model["model"].device
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        with torch.no_grad():
            output = current_model["model"].generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
                do_sample=True,
                pad_token_id=current_model["tokenizer"].pad_token_id,
                eos_token_id=current_model["tokenizer"].eos_token_id,
            )
        response = current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
        # Strip the prompt from the decoded output so only the completion is returned
        if full_prompt in response:
            response = response.replace(full_prompt, "").strip()
        return response
    except Exception as e:
        return f"⚠️ Generation Error: {str(e)}"
    finally:
        clear_memory()
# Initialize model storage
current_model = {"tokenizer": None, "model": None, "config": None}
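
# Example of loading a model programmatically, outside the UI (illustrative
# only; downloading the checkpoint requires network access and enough RAM/VRAM):
#   print(load_model("atlas-pro-0403", "1.5B"))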
# UI for Gradio
def gradio_ui():
    def load_and_set_model(model_key, model_size):
        return load_model(model_key, model_size)

    with gr.Blocks() as app:
        gr.Markdown("## 🦁 Atlas Inference Platform - Experimental 🧪")
        with gr.Row():
            model_key_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model Variant",
                interactive=True
            )
            model_size_dropdown = gr.Dropdown(
                choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
                value="1.5B",
                label="Select Model Size",
                interactive=True
            )
            load_button = gr.Button("Load Model")
            load_status = gr.Textbox(label="Model Load Status", interactive=False)
        load_button.click(
            load_and_set_model,
            inputs=[model_key_dropdown, model_size_dropdown],
            outputs=load_status,
        )
        with gr.Row():
            prompt_input = gr.Textbox(label="Input Prompt", lines=4)
            max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
            temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
            top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")
        generate_button = gr.Button("Generate Response")
        response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)
        generate_button.click(
            respond,
            inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
            outputs=response_output,
        )
    return app
if __name__ == "__main__":
    gradio_ui().launch()
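
# Notes (assumptions, not taken from the Space's own config files):
# - device_map="auto" requires the `accelerate` package alongside `torch`,
#   `transformers`, and `gradio`.
# - The per-model system prompts are read from the ATLAS_FLASH_1215 /
#   ATLAS_PRO_0403 environment variables (e.g. Space secrets); if unset,
#   a default placeholder prompt is used.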