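"""Differentiable face detection and alignment utilities.

Faces are detected with MTCNN and aligned to a canonical crop via a similarity
transform estimated from five facial landmarks.
"""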
import torch
import numpy as np
from einops import rearrange
from kornia.geometry.transform import warp_affine

from utils.matlab_cp2tform import get_similarity_transform_for_cv2
from torchvision.transforms import Pad

# The canonical 5-point template is defined on a 96x112 crop; 8 was added to
# each x coordinate to center it in a 112x112 crop, then the points were
# divided by 112 to get relative coordinates.
REFERENCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051],
                                             [72.53179932, 51.50139999],
                                             [56.02519989, 71.73660278],
                                             [41.54930115, 92.3655014],
                                             [70.72990036, 92.20410156]
                                             ]) / 112


def verify_load(missing_keys, unexpected_keys):
    """Validate the result of a (possibly partial) state-dict load.

    Unexpected keys are always an error; missing keys are allowed only for the
    "extract_kv" projections, which are expected to be newly initialized.
    """
    if len(unexpected_keys) > 0:
        raise RuntimeError(f"Found unexpected keys in state dict while loading the encoder:\n{unexpected_keys}")

    filtered_missing = [key for key in missing_keys if "extract_kv" not in key]
    if len(filtered_missing) > 0:
        raise RuntimeError(f"Missing keys in state dict while loading the encoder:\n{filtered_missing}")


@torch.no_grad()
def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> np.ndarray:
    """Detect facial landmarks in the images using MTCNN.

    Returns one landmark array per image, or None for images in which no face
    was detected (the caller falls back to using the whole image).
    """
    images = rearrange(images, "b c h w -> b h w c")  # MTCNN expects channels-last images
    if images.dtype != torch.uint8:
        images = ((images * 0.5 + 0.5) * 255).type(torch.uint8)  # Unnormalize from [-1, 1] to [0, 255]

    _, _, landmarks = mtcnn(images, landmarks=True)  # Returns (boxes, probs, landmarks)

    return landmarks


def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, reference_points=REFERENCE_FACIAL_POINTS_RELATIVE):
    """Detect faces in the images and crop them (in a differentiable way) to
    output_size x output_size (112x112 by default) using MTCNN.

    Returns the aligned face crops and the indices of images in which no face
    was detected (those are warped with an identity transform instead).
    """
    images = Pad(200)(images)  # Pad so that faces near the border can still be detected
    landmarks_batched = detect_face(images, mtcnn=mtcnn)
    affine_transformations = []
    invalid_indices = []
    for i, landmarks in enumerate(landmarks_batched):
        if landmarks is None:
            # No face detected: fall back to an identity transform over the whole image.
            invalid_indices.append(i)
            affine_transformations.append(np.eye(2, 3).astype(np.float32))
        else:
            # Estimate the similarity transform mapping the landmarks of the first
            # detected face onto the canonical reference points.
            affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32),
                                                                           reference_points.astype(np.float32) * output_size))
    affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32)

    invalid_indices = torch.tensor(invalid_indices).to(device=images.device)

    # warp_affine requires floating-point inputs; cast back to the original dtype afterwards.
    fp_images = images.to(torch.float32)
    return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices
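

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # MTCNN detector from facenet_pytorch; since `detect_face` calls the module
    # directly with `landmarks=True`, a thin wrapper exposes `MTCNN.detect`
    # through `forward`. Adjust to whatever detector your project actually uses.
    from facenet_pytorch import MTCNN

    class MTCNNWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mtcnn = MTCNN()

        def forward(self, images, landmarks=False):
            # Delegate to `detect`, which returns (boxes, probs, landmarks).
            return self.mtcnn.detect(images, landmarks=landmarks)

    batch = torch.rand(2, 3, 512, 512) * 2 - 1  # Dummy images, normalized to [-1, 1]
    faces, invalid_indices = extract_faces_and_landmarks(batch, mtcnn=MTCNNWrapper())
    print(faces.shape)      # torch.Size([2, 3, 112, 112])
    print(invalid_indices)  # Indices of images in which no face was found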