ohayonguy committed on
Commit a00800e
1 Parent(s): 94bce76

improved interface and added examples

app.py CHANGED
@@ -1,25 +1,27 @@
+# Some of the implementations below are adopted from
+# https://huggingface.co/spaces/sczhou/CodeFormer and https://huggingface.co/spaces/wzhouxiff/RestoreFormerPlusPlus
 import os
 
+import matplotlib.pyplot as plt
+
 if os.getenv('SPACES_ZERO_GPU') == "true":
     os.environ['SPACES_ZERO_GPU'] = "1"
     os.environ['K_DIFFUSION_USE_COMPILE'] = "0"
+
 import spaces
 import cv2
 from tqdm import tqdm
 import gradio as gr
 import random
 import torch
-from basicsr.archs.srvgg_arch import SRVGGNetCompact
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils import img2tensor, tensor2img
-from gradio_imageslider import ImageSlider
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from realesrgan.utils import RealESRGANer
 
 from lightning_models.mmse_rectified_flow import MMSERectifiedFlow
 
-torch.set_grad_enabled(False)
-
-MAX_SEED = 1000000
+MAX_SEED = 10000
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 os.makedirs('pretrained_models', exist_ok=True)
@@ -28,25 +30,42 @@ if not os.path.exists(realesr_model_path):
     os.system(
         "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O pretrained_models/RealESRGAN_x4plus.pth")
 
-# background enhancer with RealESRGAN
-model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
-half = True if torch.cuda.is_available() else False
-upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=400, tile_pad=10, pre_pad=0,
-                         half=half)
+# # background enhancer with RealESRGAN
+# model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+# half = True if torch.cuda.is_available() else False
+# upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=400, tile_pad=10, pre_pad=0,
+#                          half=half)
+
+
+def set_realesrgan():
+    use_half = False
+    if torch.cuda.is_available():  # set False in CPU/MPS mode
+        no_half_gpu_list = ['1650', '1660']  # set False for GPUs that don't support f16
+        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
+            use_half = True
+
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=400,
+        tile_pad=40,
+        pre_pad=0,
+        half=use_half
+    )
+    return upsampler
 
+upsampler = set_realesrgan()
 pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device=device)
 
-face_helper_dummy = FaceRestoreHelper(
-    1,
-    face_size=512,
-    crop_ratio=(1, 1),
-    det_model='retinaface_resnet50',
-    save_ext='png',
-    use_parse=True,
-    device=device,
-    model_rootpath=None)
-
-
 def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, device):
     source_dist_samples = pmrf_model.create_source_distribution_samples(x, y, non_noisy_z0)
     dt = (1.0 / num_flow_steps) * (1.0 - pmrf_model.hparams.eps)
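A note on the fp16 guard in the new `set_realesrgan()`: half precision is only enabled on CUDA devices, and GTX 1650/1660 cards are excluded because their fp16 throughput is poor. A minimal sketch of the same check, with a hypothetical helper name (`supports_half`) and assuming device 0 is the one used:

```python
import torch

def supports_half() -> bool:
    # fp16 is only worthwhile on CUDA, and not on GTX 16xx-series cards,
    # which lack fast fp16 support (mirrors the check in set_realesrgan).
    if not torch.cuda.is_available():  # CPU/MPS: stay in fp32
        return False
    name = torch.cuda.get_device_name(0)
    return not any(gpu in name for gpu in ('1650', '1660'))
```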
@@ -57,58 +76,61 @@ def generate_reconstructions(pmrf_model, x, y, non_noisy_z0, num_flow_steps, dev
         v_t_next = pmrf_model(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
         x_t_next = x_t_next.clone() + v_t_next * dt
 
-    return x_t_next.clip(0, 1).to(torch.float32)
+    return x_t_next.clip(0, 1)
 
 
+def resize(img, size):
+    # From https://github.com/sczhou/CodeFormer/blob/master/facelib/utils/face_restoration_helper.py
+    h, w = img.shape[0:2]
+    scale = size / min(h, w)
+    h, w = int(h * scale), int(w * scale)
+    interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+    return cv2.resize(img, (w, h), interpolation=interp)
+
 @torch.inference_mode()
 @spaces.GPU()
-def enhance_face(img, face_helper, has_aligned, num_flow_steps, only_center_face=False, paste_back=True, scale=2):
+def enhance_face(img, face_helper, has_aligned, num_flow_steps, scale=2):
     face_helper.clean_all()
-    if has_aligned: # the inputs are already aligned
+    if has_aligned: # The inputs are already aligned
         img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
         face_helper.cropped_faces = [img]
     else:
         face_helper.read_image(img)
-        face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
-        # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
-        # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
-        # align and warp each face
+        face_helper.input_img = resize(face_helper.input_img, 640)
+        face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
         face_helper.align_warp_face()
     if len(face_helper.cropped_faces) == 0:
         raise gr.Error("Could not identify any face in the image.")
     if len(face_helper.cropped_faces) > 1:
-        gr.Info(f"Identified {len(face_helper.cropped_faces)} faces in the image. The algorithm will enhance the quality of each face.")
+        gr.Info(f"Identified {len(face_helper.cropped_faces)} "
+                f"faces in the image. The algorithm will enhance the quality of each face.")
     else:
         gr.Info(f"Identified one face in the image.")
 
     # face restoration
     for i, cropped_face in tqdm(enumerate(face_helper.cropped_faces)):
-        # prepare data
-        h, w = cropped_face.shape[0], cropped_face.shape[1]
-        cropped_face = cv2.resize(cropped_face, (512, 512), interpolation=cv2.INTER_LINEAR)
-        # face_helper.cropped_faces[i] = cropped_face
         cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
         cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
-        dummy_x = torch.zeros_like(cropped_face_t)
-        output = generate_reconstructions(pmrf, dummy_x, cropped_face_t, None, num_flow_steps, device)
+        output = generate_reconstructions(pmrf,
+                                          torch.zeros_like(cropped_face_t),
+                                          cropped_face_t,
+                                          None,
+                                          num_flow_steps,
+                                          device)
         restored_face = tensor2img(output.to(torch.float32).squeeze(0), rgb2bgr=True, min_max=(0, 1))
-        restored_face = cv2.resize(restored_face, (h, w), interpolation=cv2.INTER_LINEAR)
-
-        restored_face = restored_face.astype('uint8')
+        restored_face = restored_face.astype("uint8")
         face_helper.add_restored_face(restored_face)
 
-    if not has_aligned and paste_back:
+    if not has_aligned:
         # upsample the background
-        if upsampler is not None:
-            # Now only support RealESRGAN for upsampling background
-            bg_img = upsampler.enhance(img, outscale=scale)[0]
-        else:
-            bg_img = None
-
+        # Now only support RealESRGAN for upsampling background
+        bg_img = upsampler.enhance(img, outscale=scale)[0]
         face_helper.get_inverse_affine(None)
         # paste each restored face to the input image
         restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img)
+        print(bg_img.shape, img.shape,restored_img.shape)
+
         return face_helper.cropped_faces, face_helper.restored_faces, restored_img
     else:
         return face_helper.cropped_faces, face_helper.restored_faces, None
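For context, `generate_reconstructions` integrates a rectified-flow ODE with fixed-step Euler updates from t = eps to t = 1, then clips the result to the valid image range. A self-contained sketch of that sampling loop, with a stand-in velocity field in place of the real `MMSERectifiedFlow` model (shapes and the exact time handling are assumptions):

```python
import torch

def euler_flow_sample(velocity_model, y, x_start, num_flow_steps, eps=0.0):
    # Fixed-step Euler integration of dx/dt = v(x_t, t, y) from t=eps to t=1.
    dt = (1.0 / num_flow_steps) * (1.0 - eps)
    x_t = x_start
    t = eps
    for _ in range(num_flow_steps):
        v_t = velocity_model(x_t=x_t, t=torch.full((x_t.shape[0],), t), y=y)
        x_t = x_t + v_t * dt
        t += dt
    return x_t.clip(0, 1)  # final sample, clipped to [0, 1]

# Toy usage with a dummy velocity field (the real app passes the PMRF model):
v = lambda x_t, t, y: y - x_t
y = torch.rand(1, 3, 512, 512)
out = euler_flow_sample(v, y, torch.zeros_like(y), num_flow_steps=25)
```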
@@ -123,12 +145,7 @@ def inference(seed, randomize_seed, img, aligned, scale, num_flow_steps,
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     torch.manual_seed(seed)
-    if scale > 4:
-        scale = 4  # avoid too large scale value
-    img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
-    if len(img.shape) == 2:  # for gray inputs
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
+    img = cv2.imread(img, cv2.IMREAD_COLOR)
     h, w = img.shape[0:2]
     if h > 4500 or w > 4500:
         raise gr.Error('Image size too large.')
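The switch from `cv2.IMREAD_UNCHANGED` to `cv2.IMREAD_COLOR` is what makes the old grayscale branch unnecessary: `IMREAD_COLOR` always decodes to a 3-channel BGR array, converting grayscale and dropping any alpha channel. A quick check (the path is illustrative):

```python
import cv2

img = cv2.imread("examples/01.png", cv2.IMREAD_COLOR)  # always 3-channel BGR
if img is not None:  # imread returns None when the file is missing
    assert img.ndim == 3 and img.shape[2] == 3
```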
@@ -143,22 +160,22 @@
         device=device,
         model_rootpath=None)
 
-    has_aligned = True if aligned == 'Yes' else False
-    cropped_face, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
-                                                                paste_back=True, num_flow_steps=num_flow_steps,
+    has_aligned = aligned
+    cropped_face, restored_faces, restored_img = enhance_face(img,
+                                                              face_helper,
+                                                              has_aligned,
+                                                              num_flow_steps=num_flow_steps,
                                                               scale=scale)
     if has_aligned:
-        output = restored_aligned[0]
-        # input = cropped_face[0].astype('uint8')
+        output = restored_faces[0]
     else:
         output = restored_img
-        # input = img
 
     output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
-    # h, w = output.shape[0:2]
-    # input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
-    # input = cv2.resize(input, (h, w), interpolation=cv2.INTER_LINEAR)
-    return output
+    for i, restored_face in enumerate(restored_faces):
+        restored_faces[i] = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
+    torch.cuda.empty_cache()
+    return output, restored_faces
 
 
 intro = """
@@ -177,8 +194,9 @@ Please refer to our project's page for more details: https://pmrf-ml.github.io/.
 
 *Notes* :
 
-1. Our model is designed to restore aligned face images, where there is *only one* face in the image, and the face is centered. Here, however, we incorporate mechanisms that allow restoring the quality of *any* image that contains *any* number of faces. Thus, the resulting quality of such general images is not guaranteed.
-2. Images that are too large won't work due to memory constraints.
+1. Our model is designed to restore aligned face images, where there is *only one* face in the image, and the face is centered and aligned. Here, however, we incorporate mechanisms that allow restoring the quality of *any* image that contains *any* number of faces. Thus, the resulting quality of such general images is not guaranteed.
+2. If the faces in your image are not aligned, make sure that the checkbox "The input is an aligned face image" is *not* marked.
+3. Too large images may result in an out-of-memory error.
 
 ---
 """
@@ -216,6 +234,7 @@
 }
 """
 
+
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     gr.HTML(intro)
     gr.Markdown(markdown_top)
@@ -232,7 +251,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
         value=25,
     )
     upscale_factor = gr.Slider(
-        label="Scale factor for the background upsampler. Applicable only to non-aligned face images.",
+        label="Scale factor. Applicable only to non-aligned face images. This will upscale the entire image.",
         minimum=1,
         maximum=4,
         step=0.1,
@@ -247,13 +266,37 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     )
 
     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-    aligned = gr.Checkbox(label="The input is an aligned face image", value=False)
+    aligned = gr.Checkbox(label="The input is an aligned face image.", value=False)
 
     with gr.Row():
         run_button = gr.Button(value="Submit", variant="primary")
 
     with gr.Row():
         result = gr.Image(label="Output", type="numpy", show_label=True)
+    with gr.Row():
+        gallery = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True)
+
+    examples = gr.Examples(
+        examples=[
+            [42, False, "examples/01.png", False, 1, 25],
+            [42, False, "examples/03.jpg", False, 2, 25],
+            [42, False, "examples/00000055.png", True, 1, 25],
+            [42, False, "examples/00000085.png", True, 1, 25],
+            [42, False, "examples/00000113.png", True, 1, 25],
+            [42, False, "examples/00000137.png", True, 1, 25],
+        ],
+        fn=inference,
+        inputs=[
+            seed,
+            randomize_seed,
+            input_im,
+            aligned,
+            upscale_factor,
+            num_inference_steps,
+        ],
+        outputs=[result, gallery],
+        cache_examples="lazy",
+    )
 
     gr.Markdown(article)
     gr.on(
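On the new `gr.Examples` block: it binds the canned input rows to the same callback as the Submit button, and `cache_examples="lazy"` defers computing each row's output until the first time it is clicked rather than at startup. A stripped-down sketch with placeholder names (`fake_restore`), assuming Gradio 4.x:

```python
import gradio as gr

def fake_restore(seed, image_path):
    return image_path  # placeholder; the real app runs PMRF inference here

with gr.Blocks() as demo:
    seed = gr.Number(value=42, label="Seed")
    input_im = gr.Image(type="filepath", label="Input")
    result = gr.Image(label="Output")
    gr.Examples(
        examples=[[42, "examples/01.png"]],  # rows must match `inputs` order
        fn=fake_restore,
        inputs=[seed, input_im],
        outputs=result,
        cache_examples="lazy",  # output computed on first click, then cached
    )

demo.launch()
```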
@@ -267,8 +310,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             upscale_factor,
             num_inference_steps,
         ],
-        outputs=result,
-        show_api=False,
+        outputs=[result, gallery],
+        # show_api=False,
         # show_progress="minimal",
     )
 
examples/00000055.png ADDED
examples/00000085.png ADDED
examples/00000113.png ADDED
examples/00000137.png ADDED
examples/01.png ADDED
examples/03.jpg ADDED