ohayonguy committed on
Commit b51eadf
1 Parent(s): d18dfca

changed output format to png

Files changed (1)
app.py +65 -50
app.py CHANGED
@@ -4,9 +4,9 @@ import os
 
 import matplotlib.pyplot as plt
 
-if os.getenv('SPACES_ZERO_GPU') == "true":
-    os.environ['SPACES_ZERO_GPU'] = "1"
-    os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
+if os.getenv("SPACES_ZERO_GPU") == "true":
+    os.environ["SPACES_ZERO_GPU"] = "1"
+    os.environ["K_DIFFUSION_USE_COMPILE"] = "0"
 
 import spaces
 import cv2
@@ -24,11 +24,13 @@ from lightning_models.mmse_rectified_flow import MMSERectifiedFlow
 MAX_SEED = 10000
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-os.makedirs('pretrained_models', exist_ok=True)
-realesr_model_path = 'pretrained_models/RealESRGAN_x4plus.pth'
+os.makedirs("pretrained_models", exist_ok=True)
+realesr_model_path = "pretrained_models/RealESRGAN_x4plus.pth"
 if not os.path.exists(realesr_model_path):
     os.system(
-        "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O pretrained_models/RealESRGAN_x4plus.pth")
+        "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O pretrained_models/RealESRGAN_x4plus.pth"
+    )
+
 
 # # background enhancer with RealESRGAN
 # model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
@@ -39,18 +41,15 @@ if not os.path.exists(realesr_model_path):
 
 def set_realesrgan():
     use_half = False
-    if torch.cuda.is_available(): # set False in CPU/MPS mode
-        no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
-        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
+    if torch.cuda.is_available():  # set False in CPU/MPS mode
+        no_half_gpu_list = ["1650", "1660"]  # set False for GPUs that don't support f16
+        if not True in [
+            gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list
+        ]:
             use_half = True
 
     model = RRDBNet(
-        num_in_ch=3,
-        num_out_ch=3,
-        num_feat=64,
-        num_block=23,
-        num_grow_ch=32,
-        scale=2,
+        num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2,
     )
     upsampler = RealESRGANer(
         scale=2,
@@ -59,20 +58,28 @@ def set_realesrgan():
         tile=400,
         tile_pad=40,
         pre_pad=0,
-        half=use_half
+        half=use_half,
     )
     return upsampler
 
+
 upsampler = set_realesrgan()
-pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device=device)
+pmrf = MMSERectifiedFlow.from_pretrained(
+    "ohayonguy/PMRF_blind_face_image_restoration"
+).to(device=device)
+
 
 def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, device):
-    source_dist_samples = pmrf_model.create_source_distribution_samples(x, y, non_noisy_z0)
+    source_dist_samples = pmrf_model.create_source_distribution_samples(
+        x, y, non_noisy_z0
+    )
     dt = (1.0 / num_flow_steps) * (1.0 - pmrf_model.hparams.eps)
     x_t_next = source_dist_samples.clone()
     t_one = torch.ones(x.shape[0], device=device)
     for i in tqdm(range(num_flow_steps)):
-        num_t = (i / num_flow_steps) * (1.0 - pmrf_model.hparams.eps) + pmrf_model.hparams.eps
+        num_t = (i / num_flow_steps) * (
+            1.0 - pmrf_model.hparams.eps
+        ) + pmrf_model.hparams.eps
         v_t_next = pmrf_model(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
         x_t_next = x_t_next.clone() + v_t_next * dt
 
@@ -87,6 +94,7 @@ def resize(img, size):
     interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
     return cv2.resize(img, (w, h), interpolation=interp)
 
+
 @torch.inference_mode()
 @spaces.GPU()
 def enhance_face(img, face_helper, has_aligned, num_flow_steps, scale=2):
@@ -102,20 +110,26 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, scale=2):
     if len(face_helper.cropped_faces) == 0:
         raise gr.Error("Could not identify any face in the image.")
     if has_aligned and len(face_helper.cropped_faces) > 1:
-        raise gr.Error("You marked that the input image is aligned, but multiple faces were detected.")
+        raise gr.Error(
+            "You marked that the input image is aligned, but multiple faces were detected."
+        )
 
     # face restoration
     for i, cropped_face in tqdm(enumerate(face_helper.cropped_faces)):
-        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
-        output = generate_reconstructions(pmrf,
-                                          torch.zeros_like(cropped_face_t),
-                                          cropped_face_t,
-                                          None,
-                                          num_flow_steps,
-                                          device)
-        restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
+        output = generate_reconstructions(
+            pmrf,
+            torch.zeros_like(cropped_face_t),
+            cropped_face_t,
+            None,
+            num_flow_steps,
+            device,
+        )
+        restored_face = tensor2img(
+            output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1)
+        )
         restored_face = restored_face.astype("uint8")
         face_helper.add_restored_face(restored_face)
 
@@ -126,8 +140,6 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, scale=2):
         face_helper.get_inverse_affine(None)
         # paste each restored face to the input image
         restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img)
-        print(bg_img.shape, img.shape,restored_img.shape)
-
         return face_helper.cropped_faces, face_helper.restored_faces, restored_img
     else:
         return face_helper.cropped_faces, face_helper.restored_faces, None
@@ -135,8 +147,15 @@ def enhance_face(img, face_helper, has_aligned, num_flow_steps, scale=2):
 
 @torch.inference_mode()
 @spaces.GPU()
-def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps,
-              progress=gr.Progress(track_tqdm=True)):
+def inference(
+    seed,
+    randomize_seed,
+    img,
+    aligned,
+    scale,
+    num_flow_steps,
+    progress=gr.Progress(track_tqdm=True),
+):
     if img is None:
         raise gr.Error("Please upload an image before submitting.")
     if randomize_seed:
@@ -145,24 +164,23 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps,
     img = cv2.imread(img, cv2.IMREAD_COLOR)
     h, w = img.shape[0:2]
     if h > 4500 or w > 4500:
-        raise gr.Error('Image size too large.')
+        raise gr.Error("Image size too large.")
 
     face_helper = FaceRestoreHelper(
         scale,
         face_size=512,
         crop_ratio=(1, 1),
-        det_model='retinaface_resnet50',
-        save_ext='png',
+        det_model="retinaface_resnet50",
+        save_ext="png",
         use_parse=True,
         device=device,
-        model_rootpath=None)
+        model_rootpath=None,
+    )
 
     has_aligned = aligned
-    cropped_face, restored_faces, restored_img = enhance_face(img,
-                                                              face_helper,
-                                                              has_aligned,
-                                                              num_flow_steps=num_flow_steps,
-                                                              scale=scale)
+    cropped_face, restored_faces, restored_img = enhance_face(
+        img, face_helper, has_aligned, num_flow_steps=num_flow_steps, scale=scale
+    )
     if has_aligned:
         output = restored_faces[0]
     else:
@@ -231,7 +249,6 @@ css = """
 }
 """
 
-
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     gr.HTML(intro)
     gr.Markdown(markdown_top)
@@ -255,15 +272,13 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 value=1,
            )
            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=42,
+                label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            aligned = gr.Checkbox(label="The input is an aligned face image.", value=False)
+            aligned = gr.Checkbox(
+                label="The input is an aligned face image.", value=False
+            )
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -272,9 +287,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
            clear_button = gr.ClearButton(value="Clear")

    with gr.Row():
-        result = gr.Image(label="Output", type="numpy", show_label=True)
+        result = gr.Image(label="Output", type="numpy", show_label=True, format="png")
    with gr.Row():
-        gallery = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True)
+        gallery = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True, format="png")
 
    clear_button.add(input_im)
    clear_button.add(result)
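
Note on the change: the functional edit in this commit is the format="png" argument added to the output gr.Image and gr.Gallery components in the final hunk; the other hunks are formatting cleanup (double quotes, wrapped call sites) and removal of a debug print. A minimal sketch of the same idea, assuming a Gradio 4.x environment where both components accept a format argument; the restore function is a hypothetical stand-in for the app's PMRF pipeline:

    import gradio as gr
    import numpy as np

    def restore(img: np.ndarray) -> np.ndarray:
        # Hypothetical stand-in for the real restoration step.
        return img

    with gr.Blocks() as demo:
        inp = gr.Image(label="Input", type="numpy")
        # format="png" asks Gradio to encode the numpy output as lossless PNG
        # rather than its default lossy format, matching this commit.
        out = gr.Image(label="Output", type="numpy", format="png")
        gr.Button("Run").click(restore, inputs=inp, outputs=out)

    demo.launch()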
 
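For readers tracing the generate_reconstructions hunk: the loop performs plain Euler integration of the learned flow from t = eps to t = 1, starting at the model's source-distribution samples. A self-contained sketch under that reading, with a toy velocity field standing in for the PMRF network (all names below are illustrative, not part of this repo):

    import torch

    def euler_flow_sampler(velocity, x_start, num_flow_steps, eps=0.0):
        # Uniform Euler steps for dx/dt = v(x, t) on t in [eps, 1],
        # mirroring the loop structure in generate_reconstructions.
        dt = (1.0 / num_flow_steps) * (1.0 - eps)
        x_t = x_start.clone()
        t_one = torch.ones(x_start.shape[0])
        for i in range(num_flow_steps):
            num_t = (i / num_flow_steps) * (1.0 - eps) + eps
            x_t = x_t + velocity(x_t, t_one * num_t) * dt
        return x_t

    # Toy check: with velocity v(x, t) = -x, the state contracts toward zero.
    x0 = torch.randn(4, 3, 8, 8)
    x1 = euler_flow_sampler(lambda x, t: -x, x0, num_flow_steps=25)
    assert x1.norm() < x0.norm()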