huntingcarlisle commited on
Commit
d3e0e2f
1 Parent(s): 4bab302

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -19
app.py CHANGED
@@ -7,6 +7,7 @@ from io import BytesIO
7
  # from IPython.display import display
8
  import base64
9
  import time
 
10
 
11
 
12
 
@@ -21,6 +22,10 @@ def display_image(image=None,width=500,height=500):
21
  img = image.resize((width, height))
22
  return img
23
 
 
 
 
 
24
  # API Gateway endpoint URL
25
  api_url = 'https://a02q342s5b.execute-api.us-east-2.amazonaws.com/reinvent-demo-inf2-sm-20231114'
26
 
@@ -45,18 +50,50 @@ api_url = 'https://a02q342s5b.execute-api.us-east-2.amazonaws.com/reinvent-demo-
45
 
46
 
47
  # Creating Tabs
48
- tab1, tab2, tab3 = st.tabs(["Image Generation", "Architecture", "Code"])
49
 
50
  with tab1:
51
  # Create two columns for layout
52
  left_column, right_column = st.columns(2)
 
 
 
 
53
  # ===========
54
  with left_column:
55
  # Define Streamlit UI elements
56
- st.title('Stable Diffusion XL Image Generation with AWS Inferentia')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  prompt_one = st.text_area("Enter your prompt:",
59
- f"Raccoon astronaut in space, sci-fi, future, cold color palette, muted colors, detailed, 8k")
 
 
 
60
 
61
  # Number of inference steps
62
  num_inference_steps_one = st.slider("Number of Inference Steps",
@@ -76,15 +113,8 @@ with tab1:
76
  negative_prompt_one = st.text_area("Enter your negative prompt:",
77
  "cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry")
78
 
79
-
80
-
81
-
82
-
83
-
84
-
85
- if st.button('Generate Image'):
86
- with st.spinner(f'Generating Image with {num_inference_steps_one} iterations'):
87
- with right_column:
88
  start_time = time.time()
89
  # ===============
90
  # Example input data
@@ -94,7 +124,8 @@ with tab1:
94
  "num_inference_steps": num_inference_steps_one,
95
  "seed": seed_one,
96
  "negative_prompt": negative_prompt_one
97
- }
 
98
  }
99
 
100
  # Make API request
@@ -105,20 +136,93 @@ with tab1:
105
  result_one = response_one.json()
106
  # st.success(f"Prediction result: {result}")
107
  image_one = display_image(decode_base64_image(result_one["generated_images"][0]))
108
- st.image(image_one,
109
  caption=f"{prompt_one}")
110
  end_time = time.time()
111
  total_time = round(end_time - start_time, 2)
112
- st.text(f"Prompt: {prompt_one}")
113
- st.text(f"Number of Iterations: {num_inference_steps_one}")
114
- st.text(f"Random Seed: {seed_one}")
115
- st.text(f'Total time taken: {total_time} seconds')
116
  # Calculate and display the time per iteration in milliseconds
117
  time_per_iteration_ms = (total_time / num_inference_steps_one)
118
- st.text(f'Time per iteration: {time_per_iteration_ms:.2f} seconds')
119
  else:
120
  st.error(f"Error: {response_one.text}")
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  with tab2:
124
  # ===========
 
7
  # from IPython.display import display
8
  import base64
9
  import time
10
+ import random
11
 
12
 
13
 
 
22
  img = image.resize((width, height))
23
  return img
24
 
25
+ def pretty_print(messages):
26
+ for message in messages:
27
+ return f"{message['role']}: {message['content']}"
28
+
29
  # API Gateway endpoint URL
30
  api_url = 'https://a02q342s5b.execute-api.us-east-2.amazonaws.com/reinvent-demo-inf2-sm-20231114'
31
 
 
50
 
51
 
52
  # Creating Tabs
53
+ tab1, tab2, tab3 = st.tabs(["Image Generation", "Architecture", "Code"])
54
 
55
  with tab1:
56
  # Create two columns for layout
57
  left_column, right_column = st.columns(2)
58
+
59
+ with right_column:
60
+ cont = st.container()
61
+
62
  # ===========
63
  with left_column:
64
  # Define Streamlit UI elements
65
+ st.title('Stable Diffusion XL Image Generation with AWS Inferentia 2')
66
+
67
+ sample_prompts = [
68
+ "A futuristic cityscape at sunset, cyberpunk",
69
+ "A serene landscape with mountains and a river, photorealistic style",
70
+ "An astronaut riding a horse, artistic and surreal",
71
+ "A robot playing chess in a medieval setting, high detail",
72
+ "An underwater scene with colorful coral reefs and fish, vibrant colors",
73
+ "Raccoon astronaut in space, sci-fi, future, cold color palette, muted colors, detailed, 8k",
74
+ "A lost city rediscovered in the Amazon jungle, overgrown with plants, in the style of a vintage travel poster",
75
+ "A steampunk train emitting clouds of steam as it races through a mountain pass, digital art",
76
+ "An enchanted forest with bioluminescent trees and fairies dancing, in a Studio Ghibli style",
77
+ "A portrait of an elegant alien empress with a detailed headdress, reminiscent of Art Nouveau",
78
+ "A post-apocalyptic Tokyo with nature reclaiming skyscrapers, in the style of a concept art",
79
+ "A mythical phoenix rising from ashes, vibrant colors, with a nebula in the background",
80
+ "A cybernetic wolf in a neon-lit city, cyberpunk theme, rain-drenched streets",
81
+ "A high fantasy battle scene with dragons in the sky and knights on the ground, epic scale",
82
+ "An ice castle on a lonely mountain peak, under the northern lights, fantasy illustration",
83
+ "A surreal landscape where giant flowers bloom in the desert, with a distant thunderstorm, hyperrealism"
84
+ ]
85
+
86
+ def set_random_prompt():
87
+ # This function will be called when the button is clicked
88
+ random_prompt = random.choice(sample_prompts)
89
+ # Update the session state for the input field
90
+ st.session_state.prompt_one = random_prompt
91
 
92
  prompt_one = st.text_area("Enter your prompt:",
93
+
94
+ key="prompt_one")
95
+
96
+ st.button('Random Prompt', on_click=set_random_prompt)
97
 
98
  # Number of inference steps
99
  num_inference_steps_one = st.slider("Number of Inference Steps",
 
113
  negative_prompt_one = st.text_area("Enter your negative prompt:",
114
  "cartoon, graphic, text, painting, crayon, graphite, abstract glitch, blurry")
115
 
116
+ if st.button('Generate Image'):
117
+ with st.spinner(f'Generating Image with {num_inference_steps_one} iterations'):
 
 
 
 
 
 
 
118
  start_time = time.time()
119
  # ===============
120
  # Example input data
 
124
  "num_inference_steps": num_inference_steps_one,
125
  "seed": seed_one,
126
  "negative_prompt": negative_prompt_one
127
+ },
128
+ "endpoint": "huggingface-pytorch-inference-neuronx-2023-11-14-21-22-10-388"
129
  }
130
 
131
  # Make API request
 
136
  result_one = response_one.json()
137
  # st.success(f"Prediction result: {result}")
138
  image_one = display_image(decode_base64_image(result_one["generated_images"][0]))
139
+ cont.image(image_one,
140
  caption=f"{prompt_one}")
141
  end_time = time.time()
142
  total_time = round(end_time - start_time, 2)
143
+ cont.text(f"Prompt: {prompt_one}")
144
+ cont.text(f"Number of Iterations: {num_inference_steps_one}")
145
+ cont.text(f"Random Seed: {seed_one}")
146
+ cont.text(f'Total time taken: {total_time} seconds')
147
  # Calculate and display the time per iteration in milliseconds
148
  time_per_iteration_ms = (total_time / num_inference_steps_one)
149
+ cont.text(f'Time per iteration: {time_per_iteration_ms:.2f} seconds')
150
  else:
151
  st.error(f"Error: {response_one.text}")
152
 
153
+ # with tab2:
154
+
155
+ # st.title('Llama 2 7B Text Generation with AWS Inferentia 2')
156
+
157
+ # params = {
158
+ # "do_sample" : True,
159
+ # "top_p": 0.6,
160
+ # "temperature": 0.9,
161
+ # "top_k": 50,
162
+ # "max_new_tokens": 512,
163
+ # "repetition_penalty": 1.03,
164
+ # }
165
+
166
+ # if "messages" not in st.session_state:
167
+ # st.session_state.messages = [
168
+ # {"role": "system", "content": "You are a helpful Travel Planning Assistant. You respond with only 1-2 sentences."},
169
+ # {'role': 'user', 'content': 'Where can I travel in the fall for cloudy, rainy, and beautiful views?'},
170
+ # ]
171
+
172
+ # for message in st.session_state.messages:
173
+ # with st.chat_message(message["role"]):
174
+ # st.markdown(message["content"])
175
+
176
+ # with st.chat_message("assistant"):
177
+ # message_placeholder = st.empty()
178
+ # full_response = ""
179
+ # prompt_input_one = {
180
+ # "prompt": st.session_state.messages,
181
+ # "parameters": params,
182
+ # "endpoint": "huggingface-pytorch-inference-neuronx-2023-11-28-16-09-51-708"
183
+ # }
184
+
185
+ # response_one = requests.post(api_url, json=prompt_input_one)
186
+
187
+ # if response_one.status_code == 200:
188
+ # result_one = response_one.json()
189
+ # # st.success(f"Prediction result: {result}")
190
+ # full_response += result_one["generation"]
191
+ # else:
192
+ # st.error(f"Error: {response_one.text}")
193
+
194
+ # message_placeholder.markdown(full_response)
195
+ # st.session_state.messages.append({"role": "assistant", "content": full_response})
196
+
197
+ # if prompt := st.chat_input("What is up?"):
198
+ # st.session_state.messages.append({"role": "user", "content": prompt})
199
+ # print(st.session_state.messages)
200
+ # with st.chat_message("user"):
201
+ # st.markdown(prompt)
202
+
203
+ # with st.chat_message("assistant"):
204
+ # message_placeholder = st.empty()
205
+ # new_response = ""
206
+ # prompt_input_one = {
207
+ # "prompt": st.session_state.messages,
208
+ # "parameters": params,
209
+ # "endpoint": "huggingface-pytorch-inference-neuronx-2023-11-28-16-09-51-708"
210
+ # }
211
+
212
+ # response_one = requests.post(api_url, json=prompt_input_one)
213
+
214
+ # if response_one.status_code == 200:
215
+ # result_one = response_one.json()
216
+ # # st.success(f"Prediction result: {result}")
217
+ # new_response += result_one["generation"]
218
+ # else:
219
+ # st.error(f"Error: {response_one.text}")
220
+
221
+ # message_placeholder.markdown(new_response)
222
+ # st.session_state.messages.append({"role": "assistant", "content": new_response})
223
+
224
+
225
+
226
 
227
  with tab2:
228
  # ===========