Spestly committed on
Commit 1b2a38a · verified · 1 Parent(s): a1a1a93

Update app.py

Files changed (1)
  1. app.py +133 -185
app.py CHANGED
@@ -1,6 +1,6 @@
 import gc
 import torch
-import streamlit as st
+import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login
 import os
@@ -21,7 +21,7 @@ MODELS = {
         "is_vision": False,
         "system_prompt_env": "ATLAS_FLASH_1215",
     },
-    "atlas-pro-0403": {
+    "atlas-pro-0403": {
         "name": "🏆 Atlas-Pro 0403",
         "sizes": {
             "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
@@ -33,193 +33,141 @@ MODELS = {
     },
 }
 
-# Profile pictures
-USER_PFP = "user.png"
-AI_PFP = "ai_pfp.png"
-
-st.set_page_config(
-    page_title="Atlas Model Inference",
-    page_icon="🦁",
-    layout="wide",
-    menu_items={
-        'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
-        'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
-        'About': 'Athena Model Inference Platform'
-    }
-)
-
-st.markdown(
-    """
-    <style>
-    .stSlider > div > div > div > div {
-        background-color: #1f78b4 !important;
-    }
-    .stButton > button {
-        background-color: #1f78b4 !important;
-        color: white !important;
-        border: none !important;
-    }
-    .stButton > button:hover {
-        background-color: #16609a !important;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True,
-)
-
-class AtlasInferenceApp:
-    def __init__(self):
-        if "current_model" not in st.session_state:
-            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
-        if "chat_history" not in st.session_state:
-            st.session_state.chat_history = []
-
-    def clear_memory(self):
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
-
-    def load_model(self, model_key, model_size):
-        try:
-            self.clear_memory()
-
-            if st.session_state.current_model["model"] is not None:
-                del st.session_state.current_model["model"]
-                del st.session_state.current_model["tokenizer"]
-                self.clear_memory()
-
-            model_path = MODELS[model_key]["sizes"][model_size]
-
-            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                device_map="auto",
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                trust_remote_code=True,
-                low_cpu_mem_usage=True
-            )
+# Clear memory
+def clear_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
 
-            st.session_state.current_model.update({
-                "tokenizer": tokenizer,
-                "model": model,
-                "config": {
-                    "name": f"{MODELS[model_key]['name']} {model_size}",
-                    "path": model_path,
-                    "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
-                }
-            })
-            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
-        except Exception as e:
-            return f"❌ Error: {str(e)}"
-
-    def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
-        if not st.session_state.current_model["model"] or not st.session_state.current_model["tokenizer"]:
-            return "⚠️ Please select and load a model first"
-
-        try:
-            system_prompt = st.session_state.current_model["config"]["system_prompt"]
-            if not system_prompt:
-                return "⚠️ System prompt not found for the selected model."
-
-            prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
-
-            inputs = st.session_state.current_model["tokenizer"](
-                prompt,
-                return_tensors="pt",
-                max_length=512,
-                truncation=True,
-                padding=True
-            )
-            with torch.no_grad():
-                output = st.session_state.current_model["model"].generate(
-                    input_ids=inputs.input_ids,
-                    attention_mask=inputs.attention_mask,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    do_sample=True,
-                    pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
-                    eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                )
-            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
-
-            if prompt in response:
-                response = response.replace(prompt, "").strip()
-
-            return response
-        except Exception as e:
-            return f"⚠️ Generation Error: {str(e)}"
-        finally:
-            self.clear_memory()
-
-    def main(self):
-        st.title("🦁 AtlasUI - Experimental 🧪")
-
-        with st.sidebar:
-            st.header("🛠️ Model Selection")
-
-            model_key = st.selectbox(
-                "Choose Atlas Variant",
-                list(MODELS.keys()),
-                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
-            )
+# Load model
+def load_model(model_key, model_size):
+    try:
+        clear_memory()
+
+        # Unload previous model if any
+        global current_model
+        if current_model is not None:
+            del current_model["model"]
+            del current_model["tokenizer"]
+            clear_memory()
+
+        model_path = MODELS[model_key]["sizes"][model_size]
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+
+        current_model.update({
+            "tokenizer": tokenizer,
+            "model": model,
+            "config": {
+                "name": f"{MODELS[model_key]['name']} {model_size}",
+                "path": model_path,
+                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
+            }
+        })
+        return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
+    except Exception as e:
+        return f"❌ Error: {str(e)}"
 
-            model_size = st.selectbox(
-                "Choose Model Size",
-                list(MODELS[model_key]["sizes"].keys())
-            )
+# Respond to input
+def respond(prompt, max_tokens, temperature, top_p, top_k):
+    if not current_model["model"] or not current_model["tokenizer"]:
+        return "⚠️ Please select and load a model first"
 
-            if st.button("Load Model"):
-                with st.spinner("Loading model... This may take a few minutes."):
-                    status = self.load_model(model_key, model_size)
-                    st.success(status)
-
-            st.header("🔧 Generation Parameters")
-            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
-            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
-            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
-            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)
-
-            if st.button("Clear Chat History"):
-                st.session_state.chat_history = []
-                st.rerun()
-
-        st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
-
-        for message in st.session_state.chat_history:
-            with st.chat_message(
-                message["role"],
-                avatar=USER_PFP if message["role"] == "user" else AI_PFP
-            ):
-                st.markdown(message["content"])
-                if "image" in message and message["image"]:
-                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)
-
-        if prompt := st.chat_input("Message Atlas..."):
-            uploaded_image = None
-            if MODELS[model_key]["is_vision"]:
-                uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
-
-            st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})
-            with st.chat_message("user", avatar=USER_PFP):
-                st.markdown(prompt)
-                if uploaded_image:
-                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
-
-            with st.chat_message("assistant", avatar=AI_PFP):
-                with st.spinner("Generating response..."):
-                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
-                    st.markdown(response)
-
-            st.session_state.chat_history.append({"role": "assistant", "content": response})
-
-def run():
     try:
-        app = AtlasInferenceApp()
-        app.main()
+        system_prompt = current_model["config"]["system_prompt"]
+        if not system_prompt:
+            return "⚠️ System prompt not found for the selected model."
+
+        full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"
+
+        inputs = current_model["tokenizer"](
+            full_prompt,
+            return_tensors="pt",
+            max_length=512,
+            truncation=True,
+            padding=True
+        )
+        with torch.no_grad():
+            output = current_model["model"].generate(
+                input_ids=inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                do_sample=True,
+                pad_token_id=current_model["tokenizer"].pad_token_id,
+                eos_token_id=current_model["tokenizer"].eos_token_id,
+            )
+        response = current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
+
+        if full_prompt in response:
+            response = response.replace(full_prompt, "").strip()
+
+        return response
     except Exception as e:
-        st.error(f"⚠️ Application Error: {str(e)}")
+        return f"⚠️ Generation Error: {str(e)}"
+    finally:
+        clear_memory()
+
+# Initialize model storage
+current_model = {"tokenizer": None, "model": None, "config": None}
+
+# UI for Gradio
+def gradio_ui():
+    def load_and_set_model(model_key, model_size):
+        return load_model(model_key, model_size)
+
+    with gr.Blocks() as app:
+        gr.Markdown("## 🦁 Atlas Inference Platform - Experimental 🧪")
+
+        with gr.Row():
+            model_key_dropdown = gr.Dropdown(
+                choices=list(MODELS.keys()),
+                value=list(MODELS.keys())[0],
+                label="Select Model Variant",
+                interactive=True
+            )
+            model_size_dropdown = gr.Dropdown(
+                choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
+                value="1.5B",
+                label="Select Model Size",
+                interactive=True
+            )
+            load_button = gr.Button("Load Model")
+
+        load_status = gr.Textbox(label="Model Load Status", interactive=False)
+
+        load_button.click(
+            load_and_set_model,
+            inputs=[model_key_dropdown, model_size_dropdown],
+            outputs=load_status,
+        )
+
+        with gr.Row():
+            prompt_input = gr.Textbox(label="Input Prompt", lines=4)
+            max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
+            temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
+            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
+            top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")
+
+        generate_button = gr.Button("Generate Response")
+        response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)
+
+        generate_button.click(
+            respond,
+            inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
+            outputs=response_output,
+        )
+
+    return app
 
 if __name__ == "__main__":
-    run()
+    gradio_ui().launch()
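
Because the rewrite turns the Streamlit class into plain module-level functions (load_model, respond, gradio_ui), they can also be exercised outside the web UI. The snippet below is a minimal smoke-test sketch, not part of the commit: it assumes the file is importable as app, that the lines elided from the diff (7-20) only finish the MODELS dictionary and any huggingface_hub login, and that enough memory is available for the chosen checkpoint.

# Hypothetical smoke test for the refactored app.py.
# Model key/size names are taken from the diff; "atlas-pro-0403" / "1.5B"
# resolves to Spestly/Atlas-Pro-1.5B-Preview.
from app import load_model, respond, gradio_ui

print(load_model("atlas-pro-0403", "1.5B"))                         # downloads weights on first use
print(respond("List three facts about lions.", 256, 0.4, 0.9, 50))  # max_tokens, temperature, top_p, top_k

# Or serve the Gradio UI on all interfaces instead of the default local-only launch.
gradio_ui().launch(server_name="0.0.0.0", server_port=7860)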