|
from __future__ import annotations |
|
|
|
import argparse |
|
import os |
|
import pathlib |
|
import subprocess |
|
import sys |
|
from typing import Callable, Union |
|
|
|
import dlib |
|
import huggingface_hub |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as T |
|
|
|
app_dir = pathlib.Path(__file__).parent

# When SYSTEM=spaces and CUDA is unavailable, apply the local patch files to the
# vendored encoder4editing / HairCLIP checkouts before anything is imported from
# them. Both the patch files and the target directories are resolved against
# this module's directory, so startup does not depend on the current working
# directory.
if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
    for patch_name, target_dir in (
        ("patch.e4e", "encoder4editing"),
        ("patch.hairclip", "HairCLIP"),
    ):
        with open(app_dir / patch_name) as f:
            result = subprocess.run("patch -p1".split(), cwd=app_dir / target_dir, stdin=f)
        # Best effort: a failing patch (e.g. one already applied on a previous
        # start) is reported but does not abort startup.
        if result.returncode != 0:
            print(
                f"warning: applying {patch_name} to {target_dir} exited with code {result.returncode}",
                file=sys.stderr,
            )

# encoder4editing is imported as top-level packages (`models`, `utils`), so its
# root must be on sys.path before these imports run.
e4e_dir = app_dir / "encoder4editing"
sys.path.insert(0, e4e_dir.as_posix())

from models.psp import pSp  # noqa: E402
from utils.alignment import align_face  # noqa: E402

# HairCLIP is imported both as the `mapper` package (from its root) and via
# modules inside mapper/, so both directories are added.
hairclip_dir = app_dir / "HairCLIP"
mapper_dir = hairclip_dir / "mapper"
sys.path.insert(0, hairclip_dir.as_posix())
sys.path.insert(0, mapper_dir.as_posix())

from mapper.datasets.latents_dataset_inference import LatentsDatasetInference  # noqa: E402
from mapper.hairclip_mapper import HairCLIPMapper  # noqa: E402
|
|
|
|
|
class Model:
    """Bundle of the models used for text-driven hair editing.

    Holds a dlib landmark predictor (used by ``align_face``), an e4e network
    (returns reconstructed images and latents, see ``reconstruct_face``) and a
    HairCLIP mapper whose decoder renders edited latents (see ``generate``).
    The networks are moved to ``self.device``.
    """

    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.landmark_model = self._create_dlib_landmark_model()
        self.e4e = self._load_e4e()
        self.hairclip = self._load_hairclip()
        self.transform = self._create_transform()

    @staticmethod
    def _create_dlib_landmark_model():
        """Fetch the dlib 68-landmark predictor weights from the Hub and load them."""
        path = huggingface_hub.hf_hub_download(
            "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
        )
        return dlib.shape_predictor(path)

    def _load_checkpoint_opts(self, repo_id: str, filename: str) -> dict:
        """Download a checkpoint from the Hub and return its ``opts`` dict.

        ``device`` and ``checkpoint_path`` are overwritten to match this
        runtime. Only ``opts`` is kept from the loaded checkpoint; the model
        constructors presumably reload weights from ``checkpoint_path``.
        """
        ckpt_path = huggingface_hub.hf_hub_download(repo_id, filename)
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        return opts

    def _load_e4e(self) -> nn.Module:
        """Build the e4e FFHQ encoder and return it in eval mode on ``self.device``."""
        opts = self._load_checkpoint_opts("public-data/e4e", "e4e_ffhq_encode.pt")
        model = pSp(argparse.Namespace(**opts))
        model.to(self.device)
        model.eval()
        return model

    def _load_hairclip(self) -> nn.Module:
        """Build the HairCLIP mapper and return it in eval mode on ``self.device``."""
        opts = self._load_checkpoint_opts("public-data/HairCLIP", "hairclip.pt")
        # Initial text-editing settings; generate() overwrites editing_type and
        # color_description on the mapper's opts for every call.
        opts["editing_type"] = "both"
        opts["input_type"] = "text"
        # Resolved against the module directory, not the current working directory.
        opts["hairstyle_description"] = (mapper_dir / "hairstyle_list.txt").as_posix()
        opts["color_description"] = "red"
        model = HairCLIPMapper(argparse.Namespace(**opts))
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        """Return the e4e input transform.

        Resize the shorter side to 256, center-crop to 256x256, convert to a
        tensor and normalize each channel with mean 0.5 / std 0.5.
        """
        return T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

    def detect_and_align_face(self, image: str) -> PIL.Image.Image:
        """Align the face in the image file at path ``image`` using the landmark model."""
        return align_face(filepath=image, predictor=self.landmark_model)

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        """Map values from [-1, 1] to [0, 255], clamp, and cast to uint8 (truncating)."""
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a (C, H, W) tensor in [-1, 1] into an (H, W, C) uint8 NumPy array."""
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)

    @torch.inference_mode()
    def reconstruct_face(self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
        """Run ``image`` through e4e.

        Returns:
            The reconstructed image as an (H, W, C) uint8 array, and the latent
            for the single batch element (``latents[0]``), suitable as the
            ``latent`` argument of ``generate``.
        """
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        reconstructed_images, latents = self.e4e(input_data, randomize_noise=False, return_latents=True)
        reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
        reconstructed = self.postprocess(reconstructed)
        return reconstructed, latents[0]

    @torch.inference_mode()
    def generate(
        self, editing_type: str, hairstyle_index: int, color_description: str, latent: torch.Tensor
    ) -> np.ndarray:
        """Apply a HairCLIP text edit to ``latent`` and decode the result.

        Args:
            editing_type: Written to the mapper's opts (e.g. "color" or "both").
                With "color", ``hairstyle_index`` is ignored and 0 is used.
            hairstyle_index: Index into the hairstyle text inputs produced by
                ``LatentsDatasetInference``.
            color_description: Color text written to the mapper's opts.
            latent: A single latent as returned by ``reconstruct_face``; a
                batch dimension is added here.

        Returns:
            The edited image as an (H, W, C) uint8 array.
        """
        # NOTE: this mutates the mapper's shared opts; LatentsDatasetInference
        # reads editing_type / color_description from them.
        opts = self.hairclip.opts
        opts.editing_type = editing_type
        opts.color_description = color_description

        if editing_type == "color":
            hairstyle_index = 0

        # Use the same device the networks were moved to in _load_hairclip.
        device = self.device

        dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(), opts=opts)
        w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]

        w = w.unsqueeze(0).to(device)
        hairstyle_text_inputs = hairstyle_text_inputs_list[hairstyle_index].unsqueeze(0).to(device)
        color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)

        # Zero-valued (1, 1) placeholders for the mapper's "hairmasked" inputs.
        hairstyle_tensor_hairmasked = torch.zeros((1, 1), device=device)
        color_tensor_hairmasked = torch.zeros((1, 1), device=device)

        # The mapper output is added to w as a residual scaled by 0.1.
        w_hat = w + 0.1 * self.hairclip.mapper(
            w,
            hairstyle_text_inputs,
            color_text_inputs,
            hairstyle_tensor_hairmasked,
            color_tensor_hairmasked,
        )
        x_hat, _ = self.hairclip.decoder(
            [w_hat],
            input_is_latent=True,
            return_latents=True,
            randomize_noise=False,
            truncation=1,
        )
        res = torch.clamp(x_hat[0].detach(), -1, 1)
        return self.postprocess(res)
|
|