DrishtiSharma committed
Commit 77cb690 · verified · 1 Parent(s): 6ad8e4c

Create lab/interim1.py

Files changed (1)
  1. lab/interim1.py +279 -0
lab/interim1.py ADDED
@@ -0,0 +1,279 @@
+ import copy
+ import json
+ from typing import Any, Iterable
+
+ import streamlit as st
+ from streamlit_ace import st_ace
+
+ from moa.agent import MOAgent
+ from moa.agent.moa import ResponseChunk
+
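+ # Streamlit front end for a Mixture-of-Agents (MoA) chat app on Groq:
+ # layer agents answer each query in parallel per cycle, and a main model
+ # synthesizes their outputs into the final response.
+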
+ # Default configuration
+ default_config = {
+     "main_model": "llama-3.3-70b-versatile",
+     "cycles": 3,
+     "layer_agent_config": {}
+ }
+
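+ # Every layer agent's system prompt includes the {helper_response}
+ # placeholder, which MOAgent is expected to fill with the previous
+ # layer's outputs (the MoA aggregate-and-synthesize step).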
+ layer_agent_config_def = {
+     "layer_agent_1": {
+         "system_prompt": "Think through your response step by step. {helper_response}",
+         "model_name": "llama-3.1-8b-instant"
+     },
+     "layer_agent_2": {
+         "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
+         "model_name": "gemma2-9b-it",
+         "temperature": 0.7
+     },
+     "layer_agent_3": {
+         "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
+         "model_name": "llama-3.1-8b-instant"
+     },
+ }
+
+ # Recommended configuration
+ rec_config = {
+     "main_model": "llama-3.3-70b-versatile",
+     "cycles": 2,
+     "layer_agent_config": {}
+ }
+
+ layer_agent_config_rec = {
+     "layer_agent_1": {
+         "system_prompt": "Think through your response step by step. {helper_response}",
+         "model_name": "llama-3.1-8b-instant",
+         "temperature": 0.1
+     },
+     "layer_agent_2": {
+         "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
+         "model_name": "llama-3.1-8b-instant",
+         "temperature": 0.2
+     },
+     "layer_agent_3": {
+         "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
+         "model_name": "llama-3.1-8b-instant",
+         "temperature": 0.4
+     },
+     "layer_agent_4": {
+         "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
+         "model_name": "mixtral-8x7b-32768",
+         "temperature": 0.5
+     },
+ }
+
+
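+ # Route MOAgent response chunks to the UI: intermediate (layer-agent) chunks
+ # are buffered per layer and rendered in per-agent expanders once the layer
+ # completes; final-layer deltas are yielded so st.write_stream can show them live.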
+ def stream_response(messages: Iterable[ResponseChunk]):
+     layer_outputs = {}
+     for message in messages:
+         if message['response_type'] == 'intermediate':
+             layer = message['metadata']['layer']
+             if layer not in layer_outputs:
+                 layer_outputs[layer] = []
+             layer_outputs[layer].append(message['delta'])
+         else:
+             # Display accumulated layer outputs
+             for layer, outputs in layer_outputs.items():
+                 st.write(f"Layer {layer}")
+                 cols = st.columns(len(outputs))
+                 for i, output in enumerate(outputs):
+                     with cols[i]:
+                         st.expander(label=f"Agent {i+1}", expanded=False).write(output)
+
+             # Clear layer outputs for the next iteration
+             layer_outputs = {}
+
+             # Yield the main agent's output
+             yield message['delta']
+
+
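+ # Seed Streamlit session state with the MoA settings and build the MOAgent.
+ # With override=False, existing session values are kept; override=True
+ # replaces them (used when the user submits a new configuration).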
+ def set_moa_agent(
+     main_model: str = default_config['main_model'],
+     cycles: int = default_config['cycles'],
+     layer_agent_config: dict[str, dict[str, Any]] | None = None,
+     main_model_temperature: float = 0.1,
+     override: bool = False
+ ):
+     # Copy the default inside the function so the module-level dict is never
+     # shared (and mutated) across calls via a mutable default argument.
+     if layer_agent_config is None:
+         layer_agent_config = copy.deepcopy(layer_agent_config_def)
+
+     if override or ("main_model" not in st.session_state):
+         st.session_state.main_model = main_model
+
+     if override or ("cycles" not in st.session_state):
+         st.session_state.cycles = cycles
+
+     if override or ("layer_agent_config" not in st.session_state):
+         st.session_state.layer_agent_config = layer_agent_config
+
+     if override or ("main_temp" not in st.session_state):
+         st.session_state.main_temp = main_model_temperature
+
+     if override or ("moa_agent" not in st.session_state):
+         st.session_state.moa_agent = MOAgent.from_config(
+             main_model=st.session_state.main_model,
+             cycles=st.session_state.cycles,
+             layer_agent_config=copy.deepcopy(st.session_state.layer_agent_config),
+             temperature=st.session_state.main_temp
+         )
+
+ st.set_page_config(
+     page_title="Mixture of Agents",
+     menu_items={
+         'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
+     },
+     layout="wide"
+ )
+
+ valid_model_names = [
+     'llama-3.1-8b-instant',
+     'llama-3.3-70b-versatile',
+     'gemma2-9b-it',
+     'mixtral-8x7b-32768'
+ ]
+
+ # Initialize session state
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ set_moa_agent()
+
+ # Sidebar for configuration
+ with st.sidebar:
+     st.title("MOA Configuration")
+     with st.form("Agent Configuration", border=False):
+         if st.form_submit_button("Use Recommended Config"):
+             try:
+                 set_moa_agent(
+                     main_model=rec_config['main_model'],
+                     cycles=rec_config['cycles'],
+                     layer_agent_config=layer_agent_config_rec,
+                     override=True
+                 )
+                 st.session_state.messages = []
+                 st.success("Configuration updated successfully!")
+             except Exception as e:
+                 st.error(f"Error updating configuration: {str(e)}")
+
+         # Main model selection
+         new_main_model = st.selectbox(
+             "Select Main Model",
+             options=valid_model_names,
+             index=valid_model_names.index(st.session_state.main_model)
+         )
+
+         # Cycles input
+         new_cycles = st.number_input(
+             "Number of Layers",
+             min_value=1,
+             max_value=10,
+             value=st.session_state.cycles
+         )
+
+         # Main model temperature
+         main_temperature = st.slider(
+             "Main Model Temperature",
+             min_value=0.0,
+             max_value=1.0,
+             value=st.session_state.main_temp,
+             step=0.05
+         )
+
+         # Layer agent configuration
+         tooltip = "Agents in the layer agent configuration run in parallel _per cycle_. Each layer agent supports all initialization parameters of [Langchain's ChatGroq](https://api.python.langchain.com/en/latest/chat_models/langchain_groq.chat_models.ChatGroq.html) class as valid dictionary fields."
+         st.markdown("Layer Agent Config", help=tooltip)
+         new_layer_agent_config = st_ace(
+             value=json.dumps(st.session_state.layer_agent_config, indent=2),
+             language='json',
+             placeholder="Layer Agent Configuration (JSON)",
+             show_gutter=False,
+             wrap=True,
+             auto_update=True
+         )
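+         # With auto_update=True the editor returns its current text on every
+         # rerun, so no separate apply step is needed before submitting the form.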
+
+         if st.form_submit_button("Update Configuration"):
+             try:
+                 new_layer_config = json.loads(new_layer_agent_config)
+                 set_moa_agent(
+                     main_model=new_main_model,
+                     cycles=new_cycles,
+                     layer_agent_config=new_layer_config,
+                     main_model_temperature=main_temperature,
+                     override=True
+                 )
+                 st.session_state.messages = []
+                 st.success("Configuration updated successfully!")
+             except json.JSONDecodeError:
+                 st.error("Invalid JSON in Layer Agent Configuration. Please check your input.")
+             except Exception as e:
+                 st.error(f"Error updating configuration: {str(e)}")
+
+     st.markdown("---")
+     st.markdown("""
+     ### Credits
+     - MOA: [Together AI](https://www.together.ai/blog/together-moa)
+     - LLMs: [Groq](https://groq.com/)
+     - Paper: [arXiv:2406.04692](https://arxiv.org/abs/2406.04692)
+     """)
+
+ # Main app layout
+ st.header("Mixture of Agents", anchor=False)
+ st.write("This app implements the Mixture of Agents architecture, powered by Groq LLMs.")
+ # st.image("./static/moa_groq.svg", caption="Mixture of Agents Workflow", width=800)
+
+ # Display current configuration
+ with st.expander("Current MOA Configuration", expanded=False):
+     st.markdown(f"**Main Model**: ``{st.session_state.main_model}``")
+     st.markdown(f"**Main Model Temperature**: ``{st.session_state.main_temp:.2f}``")
+     st.markdown(f"**Layers**: ``{st.session_state.cycles}``")
+     st.markdown("**Layer Agents Config**:")
+     # Read-only view; use a distinct name so the sidebar editor's value
+     # is not shadowed.
+     layer_agent_config_view = st_ace(
+         value=json.dumps(st.session_state.layer_agent_config, indent=2),
+         language='json',
+         placeholder="Layer Agent Configuration (JSON)",
+         show_gutter=False,
+         wrap=True,
+         readonly=True,
+         auto_update=True
+     )
+
+ # Chat interface
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ if query := st.chat_input("Ask a question"):
+     st.session_state.messages.append({"role": "user", "content": query})
+     with st.chat_message("user"):
+         st.markdown(query)
+
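+     # moa_agent.chat yields ResponseChunk dicts; st.write_stream consumes the
+     # generator and returns the concatenated final text for the chat history.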
+     moa_agent: MOAgent = st.session_state.moa_agent
+     with st.chat_message("assistant"):
+         ast_mess = stream_response(moa_agent.chat(query, output_format="json"))
+         response = st.write_stream(ast_mess)
+
+     # Save the final response to session state
+     st.session_state.messages.append({"role": "assistant", "content": response})
+
+ # Add acknowledgment at the bottom
+ st.markdown("---")
+ st.markdown("""
+ This app is based on [Emmanuel M. Ndaliro's work](https://github.com/kram254/Mixture-of-Agents-running-on-Groq/tree/main).
+ """)