Spestly commited on
Commit
7e978bb
Β·
verified Β·
1 Parent(s): fdc055b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import torch
3
+ import streamlit as st
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import re
6
+ import os
7
+
8
+ MODELS = {
9
+ "athena-1": {
10
+ "name": "🦁 Athena-Flash",
11
+ "sizes": {
12
+ "1.5B": "Spestly/Atlas-R1-1.5B-Preview",
13
+ },
14
+ "emoji": "🦁",
15
+ "experimental": True,
16
+ },
17
+ }
18
+
19
+ class AtlasInferenceApp:
20
+ def __init__(self):
21
+ if "current_model" not in st.session_state:
22
+ st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
23
+ if "chat_history" not in st.session_state:
24
+ st.session_state.chat_history = []
25
+
26
+ st.set_page_config(
27
+ page_title="Atlas Model Inference",
28
+ page_icon="🦁 ",
29
+ layout="wide",
30
+ menu_items={
31
+ 'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
32
+ 'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
33
+ 'About': 'Athena Model Inference Platform'
34
+ }
35
+ )
36
+
37
+ def clear_memory(self):
38
+ """Optimize memory management for CPU inference"""
39
+ if torch.cuda.is_available():
40
+ torch.cuda.empty_cache()
41
+ gc.collect()
42
+
43
+ def load_model(self, model_key, model_size):
44
+ try:
45
+ self.clear_memory()
46
+
47
+ if st.session_state.current_model["model"] is not None:
48
+ del st.session_state.current_model["model"]
49
+ del st.session_state.current_model["tokenizer"]
50
+ self.clear_memory()
51
+
52
+ model_path = MODELS[model_key]["sizes"][model_size]
53
+
54
+ # Load Qwen-compatible tokenizer and model
55
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
56
+ model = AutoModelForCausalLM.from_pretrained(
57
+ model_path,
58
+ device_map="cpu", # Force CPU usage
59
+ torch_dtype=torch.float32, # Use float32 for CPU
60
+ trust_remote_code=True,
61
+ low_cpu_mem_usage=True
62
+ )
63
+
64
+ # Update session state
65
+ st.session_state.current_model.update({
66
+ "tokenizer": tokenizer,
67
+ "model": model,
68
+ "config": {
69
+ "name": f"{MODELS[model_key]['name']} {model_size}",
70
+ "path": model_path,
71
+ }
72
+ })
73
+ return f"βœ… {MODELS[model_key]['name']} {model_size} loaded successfully!"
74
+ except Exception as e:
75
+ return f"❌ Error: {str(e)}"
76
+
77
+ def respond(self, message, max_tokens, temperature, top_p, top_k):
78
+ if not st.session_state.current_model["model"]:
79
+ return "⚠️ Please select and load a model first"
80
+
81
+ try:
82
+ # Add a system instruction to guide the model's behavior
83
+ system_instruction = "You are Atlas, a helpful AI assistant trained by Spestly. You are a Deepseek R1 fine-tune."
84
+ prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
85
+
86
+ inputs = st.session_state.current_model["tokenizer"](
87
+ prompt,
88
+ return_tensors="pt",
89
+ max_length=512,
90
+ truncation=True,
91
+ padding=True
92
+ )
93
+
94
+ with torch.no_grad():
95
+ output = st.session_state.current_model["model"].generate(
96
+ input_ids=inputs.input_ids,
97
+ attention_mask=inputs.attention_mask,
98
+ max_new_tokens=max_tokens,
99
+ temperature=temperature,
100
+ top_p=top_p,
101
+ top_k=top_k,
102
+ do_sample=True,
103
+ pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
104
+ eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
105
+ )
106
+ response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
107
+ return response.split("### Response:")[-1].strip() # Extract the response
108
+ except Exception as e:
109
+ return f"⚠️ Generation Error: {str(e)}"
110
+ finally:
111
+ self.clear_memory()
112
+
113
+ def main(self):
114
+ st.title("🦁 AtlasUI - Experimental πŸ§ͺ")
115
+
116
+ with st.sidebar:
117
+ st.header("πŸ›  Model Selection")
118
+
119
+ model_key = st.selectbox(
120
+ "Choose Atlas Variant",
121
+ list(MODELS.keys()),
122
+ format_func=lambda x: f"{MODELS[x]['name']} {'πŸ§ͺ' if MODELS[x]['experimental'] else ''}"
123
+ )
124
+
125
+ model_size = st.selectbox(
126
+ "Choose Model Size",
127
+ list(MODELS[model_key]["sizes"].keys())
128
+ )
129
+
130
+ if st.button("Load Model"):
131
+ with st.spinner("Loading model... This may take a few minutes."):
132
+ status = self.load_model(model_key, model_size)
133
+ st.success(status)
134
+
135
+ st.header("πŸ”§ Generation Parameters")
136
+ max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
137
+ temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
138
+ top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
139
+ top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)
140
+
141
+ if st.button("Clear Chat History"):
142
+ st.session_state.chat_history = []
143
+ st.rerun()
144
+
145
+ st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
146
+
147
+ for message in st.session_state.chat_history:
148
+ with st.chat_message(message["role"]):
149
+ st.markdown(message["content"])
150
+
151
+ if prompt := st.chat_input("Message Atlas..."):
152
+ st.session_state.chat_history.append({"role": "user", "content": prompt})
153
+ with st.chat_message("user"):
154
+ st.markdown(prompt)
155
+
156
+ with st.chat_message("assistant"):
157
+ with st.spinner("Generating response..."):
158
+ response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
159
+ st.markdown(response)
160
+
161
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
162
+
163
+ def run():
164
+ try:
165
+ app = AtlasInferenceApp()
166
+ app.main()
167
+ except Exception as e:
168
+ st.error(f"⚠️ Application Error: {str(e)}")
169
+
170
+ if __name__ == "__main__":
171
+ run()