Michael Yang commited on
Commit
e88510d
·
1 Parent(s): 7134722

add b64 to baseline

Browse files
Files changed (2) hide show
  1. app.py +8 -13
  2. baseline.py +7 -1
app.py CHANGED
@@ -10,9 +10,6 @@ from baseline import run as run_baseline
10
  import torch
11
  from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
12
  from examples import stage1_examples, stage2_examples
13
- import pickle
14
- import codecs
15
- import subprocess
16
  import base64
17
  import io
18
 
@@ -126,11 +123,7 @@ def get_ours_image(response, overall_prompt_override="", seed=0, num_inference_s
126
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
127
  so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt, so_batch_size=2
128
  )
129
- print(type(image_np))
130
- pic_IObytes = io.BytesIO()
131
- plt.savefig(pic_IObytes, format='png')
132
- pic_IObytes.seek(0)
133
- canvasbase64 = base64.b64encode(pic_IObytes.read()).decode()
134
  images = [image_np, b64]
135
  # if show_so_imgs:
136
  # images.extend([np.asarray(so_img) for so_img in so_img_list])
@@ -143,8 +136,9 @@ def get_baseline_image(prompt, seed=0):
143
  scheduler_key = "dpm_scheduler"
144
  num_inference_steps = 20
145
 
146
- image_np = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
147
- return [image_np]
 
148
 
149
  def parse_input(text=None):
150
  try:
@@ -298,10 +292,11 @@ with gr.Blocks(
298
  # with gr.Column(scale=1):
299
  # output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
300
  with gr.Column(scale=1):
301
- gallery = gr.Gallery(
302
- label="Generated image", show_label=False, elem_id="gallery2", columns=[1], rows=[1], object_fit="contain", preview=True
303
  )
304
- generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")
 
305
 
306
  gr.Examples(
307
  examples=stage1_examples,
 
10
  import torch
11
  from shared import DEFAULT_SO_NEGATIVE_PROMPT, DEFAULT_OVERALL_NEGATIVE_PROMPT
12
  from examples import stage1_examples, stage2_examples
 
 
 
13
  import base64
14
  import io
15
 
 
123
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
124
  so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt, so_batch_size=2
125
  )
126
+
 
 
 
 
127
  images = [image_np, b64]
128
  # if show_so_imgs:
129
  # images.extend([np.asarray(so_img) for so_img in so_img_list])
 
136
  scheduler_key = "dpm_scheduler"
137
  num_inference_steps = 20
138
 
139
+ image_np, b64 = run_baseline(prompt, bg_seed=seed, scheduler_key=scheduler_key, num_inference_steps=num_inference_steps)
140
+ images = [image_np, b64]
141
+ return images
142
 
143
  def parse_input(text=None):
144
  try:
 
292
  # with gr.Column(scale=1):
293
  # output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
294
  with gr.Column(scale=1):
295
+ gallery = gr.Image(
296
+ label="Generated image", show_label=False, elem_id="gallery", columns=[1], rows=[1], object_fit="contain"
297
  )
298
+ b64 = gr.Textbox(label="base64", placeholder="base64", lines = 2)
299
+ generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=[gallery,b64], api_name="baseline")
300
 
301
  gr.Examples(
302
  examples=stage1_examples,
baseline.py CHANGED
@@ -5,6 +5,8 @@ import models
5
  from models import pipelines
6
  from shared import model_dict, DEFAULT_OVERALL_NEGATIVE_PROMPT
7
  import gc
 
 
8
 
9
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
10
 
@@ -41,4 +43,8 @@ def run(prompt, scheduler_key='dpm_scheduler', bg_seed=1, num_inference_steps=20
41
  gc.collect()
42
  torch.cuda.empty_cache()
43
 
44
- return images[0]
 
 
 
 
 
5
  from models import pipelines
6
  from shared import model_dict, DEFAULT_OVERALL_NEGATIVE_PROMPT
7
  import gc
8
+ from io import BytesIO
9
+ import base64
10
 
11
  vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
12
 
 
43
  gc.collect()
44
  torch.cuda.empty_cache()
45
 
46
+ with BytesIO() as buffer:
47
+ np.save(buffer, images[0])
48
+ img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
49
+
50
+ return images[0], img_str