import gc
import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Hugging Face token (skip login if it is not set)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🏆 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🏆",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}

# Profile pictures
USER_PFP = "user.png"
AI_PFP = "ai_pfp.png"

st.set_page_config(
    page_title="Atlas Model Inference",
    page_icon="🦁",
    layout="wide",
    menu_items={
        'Get Help': 'https://huggingface.co./collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
        'Report a bug': 'https://huggingface.co./Spestly/Athena-1-1.5B/discussions/new',
        'About': 'Athena Model Inference Platform'
    }
)

# Custom CSS placeholder (styles omitted)
st.markdown(
    """
    """,
    unsafe_allow_html=True,
)


class AtlasInferenceApp:
    def __init__(self):
        # Persist the loaded model and chat history across Streamlit reruns
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_memory(self):
        """Release GPU and Python memory between loads and generations."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def load_model(self, model_key, model_size):
        try:
            self.clear_memory()

            # Drop any previously loaded model before loading a new one
            if st.session_state.current_model["model"] is not None:
                del st.session_state.current_model["model"]
                del st.session_state.current_model["tokenizer"]
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Fall back to the EOS token so padding during generation works
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                    "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
                }
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"

    def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
        # `image` is accepted for future vision variants but unused by the current models
        if not st.session_state.current_model["model"] or not st.session_state.current_model["tokenizer"]:
            return "⚠️ Please select and load a model first"

        try:
            system_prompt = st.session_state.current_model["config"]["system_prompt"]
            if not system_prompt:
                return "⚠️ System prompt not found for the selected model."

            prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"

            model = st.session_state.current_model["model"]
            tokenizer = st.session_state.current_model["tokenizer"]

            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            ).to(model.device)  # move input tensors to the model's device

            with torch.no_grad():
                output = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            response = tokenizer.decode(output[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the model's reply is shown
            if prompt in response:
                response = response.replace(prompt, "").strip()
            return response
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()

    def main(self):
        st.title("🦁 AtlasUI - Experimental 🧪")

        with st.sidebar:
            st.header("🛠 Model Selection")
            model_key = st.selectbox(
                "Choose Atlas Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
            )
            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys())
            )

            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                st.success(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is only a preview. Responses may be unexpected. Please double-check sensitive information!*")

        # Replay the conversation so far
        for message in st.session_state.chat_history:
            with st.chat_message(
                message["role"],
                avatar=USER_PFP if message["role"] == "user" else AI_PFP
            ):
                st.markdown(message["content"])
                if "image" in message and message["image"]:
                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)

        if prompt := st.chat_input("Message Atlas..."):
            uploaded_image = None
            # The uploader is only shown for vision-capable variants
            if MODELS[model_key]["is_vision"]:
                uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

            st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})
            with st.chat_message("user", avatar=USER_PFP):
                st.markdown(prompt)
                if uploaded_image:
                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

            with st.chat_message("assistant", avatar=AI_PFP):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
                    st.markdown(response)
            st.session_state.chat_history.append({"role": "assistant", "content": response})


def run():
    try:
        app = AtlasInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")


if __name__ == "__main__":
    run()