Spestly committed (verified)
Commit 8653d1f · 1 parent: 2b609ab

Update app.py

Files changed (1):
  app.py (+34 -16)
app.py CHANGED
@@ -6,9 +6,11 @@ from huggingface_hub import login
  import re
  import os
 
+ # Load Hugging Face token
  HF_TOKEN = os.getenv("HF_TOKEN")
  login(token=HF_TOKEN)
 
+ # Define models
  MODELS = {
      "athena-1": {
          "name": "🦁 Atlas-Flash",
@@ -20,9 +22,9 @@ MODELS = {
      },
  }
 
-
- USER_PFP = "user.png"
- AI_PFP = "ai_pfp.png"
+ # Profile pictures
+ USER_PFP = "user.png"  # Hugging Face user avatar
+ AI_PFP = "ai_pfp.png"  # Replace with the path to your AI's image or a URL
 
  class AtlasInferenceApp:
      def __init__(self):
@@ -59,17 +61,17 @@ class AtlasInferenceApp:
 
          model_path = MODELS[model_key]["sizes"][model_size]
 
-
+         # Load Qwen-compatible tokenizer and model
          tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
          model = AutoModelForCausalLM.from_pretrained(
              model_path,
-             device_map="cpu",
-             torch_dtype=torch.float32,
+             device_map="auto",  # Use GPU if available
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
              trust_remote_code=True,
              low_cpu_mem_usage=True
          )
 
-
+         # Update session state
          st.session_state.current_model.update({
              "tokenizer": tokenizer,
              "model": model,
@@ -87,7 +89,7 @@ class AtlasInferenceApp:
              return "⚠️ Please select and load a model first"
 
          try:
-
+             # Add a system instruction to guide the model's behavior
              system_instruction = "You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune."
              prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
 
@@ -99,8 +101,8 @@
                  padding=True
              )
 
-
-             response_container = st.empty()
+             # Generate response with streaming
+             response_container = st.empty()  # Placeholder for streaming text
              full_response = ""
              with torch.no_grad():
                  for chunk in st.session_state.current_model["model"].generate(
@@ -113,13 +115,27 @@
                      do_sample=True,
                      pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                      eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                     streamer=None,  # Use a custom streamer for real-time updates
                  ):
-                     chunk_text = st.session_state.current_model["tokenizer"].decode(chunk, skip_special_tokens=True)
-                     full_response += chunk_text
-                     response_container.markdown(full_response)
-
-             return full_response.split("### Response:")[-1].strip()
+                     # Decode the chunk and update the response
+                     try:
+                         chunk_text = st.session_state.current_model["tokenizer"].decode(chunk, skip_special_tokens=True)
+
+                         # Remove the prompt from the response
+                         if prompt in chunk_text:
+                             chunk_text = chunk_text.replace(prompt, "").strip()
+
+                         full_response += chunk_text
+                         response_container.markdown(full_response)
+                     except Exception as decode_error:
+                         st.error(f"⚠️ Token Decoding Error: {str(decode_error)}")
+                         break
+
+                     # Stop if the response is too long or incomplete
+                     if len(full_response) >= max_tokens * 4:  # Approximate token-to-character ratio
+                         st.warning("⚠️ Response truncated due to length limit.")
+                         break
+
+             return full_response.strip()  # Return the cleaned response
          except Exception as e:
              return f"⚠️ Generation Error: {str(e)}"
          finally:
@@ -159,6 +175,7 @@ class AtlasInferenceApp:
 
          st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
 
+         # Display chat history
          for message in st.session_state.chat_history:
              with st.chat_message(
                  message["role"],
@@ -166,6 +183,7 @@ class AtlasInferenceApp:
              ):
                  st.markdown(message["content"])
 
+         # Input box for user messages
          if prompt := st.chat_input("Message Atlas..."):
              st.session_state.chat_history.append({"role": "user", "content": prompt})
              with st.chat_message("user", avatar=USER_PFP):
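
For reference: the comment removed in this commit ("Use a custom streamer for real-time updates") points at the streaming helper that transformers ships, TextIteratorStreamer, which yields decoded text as tokens are generated. The sketch below is illustrative only and is not part of this commit; the model id, prompt, and generation settings are assumptions.

# Illustrative sketch only: real-time token streaming with TextIteratorStreamer.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model id, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

prompt = "### Instruction:\nSay hello.\n\n### Response:"  # assumed example prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# The streamer decodes tokens as generate() emits them and yields text chunks.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs in a worker thread while the
# main thread consumes the streamer (in the app, into a st.empty() placeholder).
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=128, do_sample=True, streamer=streamer),
)
thread.start()

full_response = ""
for chunk_text in streamer:
    full_response += chunk_text
    print(chunk_text, end="", flush=True)  # app equivalent: response_container.markdown(full_response)
thread.join()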