import gc
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using a token from the environment
HF_TOKEN = os.getenv("HF_TOKEN")
login(token=HF_TOKEN)

# Available model variants and their Hub repositories
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🏆 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🏆",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}

# Storage for the currently loaded tokenizer, model, and config.
# Defined before the functions that use it so the module reads top-down.
current_model = {"tokenizer": None, "model": None, "config": None}


def clear_memory():
    """Free GPU and Python memory between model loads and generations."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def load_model(model_key, model_size):
    """Load the selected model, unloading any previously loaded one first."""
    try:
        clear_memory()

        # Unload the previous model, if any. Set the entries to None rather
        # than deleting the keys, so later lookups cannot raise KeyError.
        if current_model["model"] is not None:
            current_model["model"] = None
            current_model["tokenizer"] = None
            clear_memory()

        model_path = MODELS[model_key]["sizes"][model_size]
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Some causal-LM tokenizers ship without a pad token; fall back to EOS
        # so that padding during tokenization does not fail.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        current_model.update({
            "tokenizer": tokenizer,
            "model": model,
            "config": {
                "name": f"{MODELS[model_key]['name']} {model_size}",
                "path": model_path,
                "system_prompt": os.getenv(
                    MODELS[model_key]["system_prompt_env"], "Default system prompt"
                ),
            },
        })
        return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def respond(prompt, max_tokens, temperature, top_p, top_k):
    """Generate a response from the currently loaded model."""
    if not current_model["model"] or not current_model["tokenizer"]:
        return "⚠️ Please select and load a model first"
    try:
        system_prompt = current_model["config"]["system_prompt"]
        if not system_prompt:
            return "⚠️ System prompt not found for the selected model."

        full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"
        # Move the tokenized inputs to the model's device; with
        # device_map="auto" the model may live on GPU while inputs default to CPU.
        inputs = current_model["tokenizer"](
            full_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        ).to(current_model["model"].device)

        with torch.no_grad():
            output = current_model["model"].generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=current_model["tokenizer"].pad_token_id,
                eos_token_id=current_model["tokenizer"].eos_token_id,
            )

        response = current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the model's reply is returned
        if full_prompt in response:
            response = response.replace(full_prompt, "").strip()
        return response
    except Exception as e:
        return f"⚠️ Generation Error: {str(e)}"
    finally:
        clear_memory()


def gradio_ui():
    """Build and return the Gradio interface."""

    def load_and_set_model(model_key, model_size):
        return load_model(model_key, model_size)

    with gr.Blocks() as app:
        gr.Markdown("## 🦁 Atlas Inference Platform - Experimental 🧪")

        with gr.Row():
            model_key_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model Variant",
                interactive=True,
            )
            model_size_dropdown = gr.Dropdown(
                choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
                value="1.5B",
                label="Select Model Size",
                interactive=True,
            )

        load_button = gr.Button("Load Model")
        load_status = gr.Textbox(label="Model Load Status", interactive=False)
        load_button.click(
            load_and_set_model,
            inputs=[model_key_dropdown, model_size_dropdown],
            outputs=load_status,
        )

        with gr.Row():
            prompt_input = gr.Textbox(label="Input Prompt", lines=4)
            max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
            temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
            top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")

        generate_button = gr.Button("Generate Response")
        response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)
        generate_button.click(
            respond,
            inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
            outputs=response_output,
        )

    return app


if __name__ == "__main__":
    gradio_ui().launch()
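# A minimal sketch of how this script would typically be launched locally.
# The filename "app.py" and the token value are assumptions, not part of the
# original; HF_TOKEN must hold a valid Hugging Face access token for login()
# to succeed:
#
#   export HF_TOKEN=<your-token>
#   python app.py
#
# By default, Gradio serves the UI at http://127.0.0.1:7860.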