# JustClothify / helpers/processor.py
# Brasd99 - "Refactoring" - commit 10076ea (6.54 kB)
import io
import cv2
import imageio
import numpy as np
import torch
from typing import Dict, List
from fvcore.common.config import CfgNode
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from densepose import add_densepose_config
from densepose.vis.base import CompoundVisualizer
from densepose.vis.densepose_outputs_vertex import get_texture_atlases
from densepose.vis.densepose_results_textures import DensePoseResultsVisualizerWithTexture as dp_iuv_texture
from densepose.vis.extractor import CompoundExtractor, create_extractor, DensePoseResultExtractor
class TextureProcessor:
    """Build a texture atlas from DensePose predictions and render it on an image.

    ``extract`` is the entry point: it derives a texture atlas from one image
    (``process_texture``) and draws it over another (``overlay_texture``).
    """

    def __init__(self, config: str, weights: str) -> None:
        """Create the predictor and the DensePose result extractor.

        Args:
            config: Path of the config file merged into the detectron2 config.
            weights: Value stored in ``cfg.MODEL.WEIGHTS``.
        """
        self.config = self.get_config(config, weights)
        self.predictor = DefaultPredictor(self.config)
        self.extractor = DensePoseResultExtractor()

    def process_texture(self, image: np.ndarray) -> np.ndarray:
        """Run the predictor on ``image`` and return the resulting texture atlas.

        The first and third channels of ``image`` are swapped
        (``cv2.COLOR_BGR2RGB``) before prediction. The atlas is round-tripped
        through an in-memory PNG, which is what turns the floating-point
        texture into a decoded 3-channel image.

        Args:
            image: Input image as a 3-channel array.

        Returns:
            The decoded atlas after a final ``cv2.COLOR_BGR2RGB`` conversion.

        Raises:
            Exception: ``'Clothes not found'`` when the prediction holds no
                DensePose result.
        """
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        output = self.execute(image)

        # A missing key, a None value and an empty result list all mean there
        # is nothing to unwrap; report them uniformly instead of letting
        # parse_iuv fail later with an IndexError/TypeError.
        densepose = output.get('pred_densepose')
        if densepose is None or len(densepose) == 0:
            raise Exception('Clothes not found')

        texture = self.create_iuv(output, image)
        png_buffer = io.BytesIO()
        imageio.imwrite(png_buffer, texture, format='PNG')
        encoded = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)
        decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    def extract(self, person_img: np.ndarray, model_img: np.ndarray) -> np.ndarray:
        """Take the texture from ``model_img`` and render it onto ``person_img``.

        Args:
            person_img: Image the texture is drawn over.
            model_img: Image the texture atlas is extracted from.

        Returns:
            The visualisation produced by ``overlay_texture``.
        """
        texture_atlas = self.process_texture(model_img)
        return self.overlay_texture(texture_atlas, person_img)

    def overlay_texture(self, texture_atlas: np.ndarray, original_image: np.ndarray) -> np.ndarray:
        """Render ``texture_atlas`` over a grayscale copy of ``original_image``.

        Args:
            texture_atlas: Atlas image; its first three channels are reversed
                before use. The caller's array is not modified.
            original_image: Image passed to the predictor and used (converted
                to 3-channel grayscale) as the drawing background.

        Returns:
            The image returned by the DensePose texture visualizer.
        """
        # Reverse the first three channels on a copy so the argument stays intact.
        swapped_atlas = texture_atlas.copy()
        swapped_atlas[:, :, :3] = texture_atlas[:, :, 2::-1]

        texture_atlases_dict = get_texture_atlases(None)
        vis = dp_iuv_texture(
            cfg=self.config,
            texture_atlas=swapped_atlas,
            texture_atlases_dict=texture_atlases_dict
        )
        visualizer = CompoundVisualizer([vis])
        compound_extractor = CompoundExtractor([create_extractor(vis)])

        with torch.no_grad():
            outputs = self.predictor(original_image)['instances']

        # Grayscale background replicated to three channels.
        gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
        background = np.tile(gray[:, :, np.newaxis], [1, 1, 3])

        data = compound_extractor(outputs)
        return visualizer.visualize(background, data)

    def parse_iuv(self, result: Dict) -> np.ndarray:
        """Build an (H, W, 3) array from the first DensePose result.

        Channel 0 is ``uv[1] * 255``, channel 1 is ``uv[0] * 255`` and channel 2
        is ``labels`` (all as floats).

        Args:
            result: Dict holding a ``'pred_densepose'`` sequence whose first
                item exposes ``labels`` and ``uv`` tensors.
        """
        first = result['pred_densepose'][0]
        labels = first.labels.cpu().numpy().astype(float)
        uv = (first.uv.cpu().numpy() * 255.0).astype(float)
        stacked = np.stack((uv[1, :, :], uv[0, :, :], labels))
        return np.transpose(stacked, (1, 2, 0))

    def parse_bbox(self, result: Dict) -> np.ndarray:
        """Return the first row of ``result['pred_boxes_XYXY']`` as a NumPy array.

        The returned array may share memory with the stored tensor.
        """
        return result['pred_boxes_XYXY'][0].cpu().numpy()

    def interpolate_tex(self, tex: np.ndarray) -> np.ndarray:
        """Fill empty pixels of a part texture by inpainting.

        Pixels whose channel sum is zero are treated as empty. The texture is
        normalised to ``uint8``, inpainted with ``cv2.INPAINT_TELEA``, scaled
        back to the original value range and finally masked with the valid
        region dilated by a 10x10 kernel.

        Args:
            tex: Channel-first texture (channels, height, width).

        Returns:
            The inpainted texture, same shape as ``tex``.
        """
        valid_mask = np.array((tex.sum(0) != 0) * 1, dtype='uint8')
        radius_increase = 10
        kernel = np.ones((radius_increase, radius_increase), np.uint8)
        dilated_mask = cv2.dilate(valid_mask, kernel, iterations=1)
        invalid_region = 1 - valid_mask

        tex_max = tex.max()
        tex_min = tex.min()
        value_span = tex_max - tex_min
        if value_span == 0:
            # Constant texture: normalising would divide by zero and there is
            # nothing meaningful to inpaint.
            return tex * dilated_mask

        normalised = np.array((tex - tex_min) / value_span * 255, dtype='uint8')
        # cv2.inpaint is given the texture channel-last; convert back afterwards.
        inpainted = cv2.inpaint(
            normalised.transpose((1, 2, 0)), invalid_region, 1, cv2.INPAINT_TELEA
        ).transpose((2, 0, 1))
        restored = (inpainted / 255.0) * value_span + tex_min
        return restored * dilated_mask

    def concat_textures(self, array: List[np.ndarray]) -> np.ndarray:
        """Tile 24 equally sized textures into a grid of 4 rows by 6 columns.

        Args:
            array: Sequence of 24 textures (a list or an array whose first axis
                has length 24).

        Returns:
            A single array with the tiles laid out row by row.
        """
        texture_rows = [np.concatenate(array[start:start + 6], axis=1) for start in range(0, 24, 6)]
        return np.concatenate(texture_rows, axis=0)

    def get_texture(
            self,
            im: np.ndarray,
            iuv: np.ndarray,
            bbox: List[int],
            tex_part_size: int = 200) -> np.ndarray:
        """Unwrap image pixels into a per-part texture atlas.

        Args:
            im: Image of shape (H, W, 3); values are divided by 255.
            iuv: Array of shape (h, w, 3) as produced by ``parse_iuv``
                (channel 2 holds the part label), covering the box region.
            bbox: Box as ``[x0, y0, x1, y1]``. It is read only, never modified.
            tex_part_size: Side length in pixels of each square part tile.

        Returns:
            Atlas of shape (4 * tex_part_size, 6 * tex_part_size, 3).
        """
        im = im.transpose(2, 1, 0) / 255
        image_w, image_h = im.shape[1], im.shape[2]

        # Derive x/y/width/height locally: writing them back into ``bbox``
        # would corrupt the caller's box (and any tensor sharing its memory).
        left, top = int(bbox[0]), int(bbox[1])
        box_w, box_h = int(bbox[2] - bbox[0]), int(bbox[3] - bbox[1])

        # Paste the box-sized IUV map into a full-image canvas.
        canvas = np.zeros((image_h, image_w, 3))
        canvas[top:top + box_h, left:left + box_w, :] = iuv
        planes = canvas.transpose((2, 1, 0))
        part_ids, u, v = planes[2], planes[1], planes[0]

        # NOTE(review): only parts 1..22 are filled although the atlas has 24
        # tiles, so the last two tiles stay black - presumably deliberate; confirm.
        n_parts = 22
        texture = np.zeros((n_parts, 3, tex_part_size, tex_part_size))
        for part_id in range(1, n_parts + 1):
            part_mask = part_ids == part_id
            generated = np.zeros((3, tex_part_size, tex_part_size))

            # Scale the 0..255 coordinates to tile pixels; clip stray values.
            tex_u_coo = (u[part_mask] * (tex_part_size - 1) / 255).astype(int)
            tex_v_coo = (v[part_mask] * (tex_part_size - 1) / 255).astype(int)
            tex_u_coo = np.clip(tex_u_coo, 0, tex_part_size - 1)
            tex_v_coo = np.clip(tex_v_coo, 0, tex_part_size - 1)

            for channel in range(3):
                generated[channel][tex_v_coo, tex_u_coo] = im[channel][part_mask]

            # Inpaint only tiles that received at least some non-black pixels.
            if np.sum(generated) > 0:
                generated = self.interpolate_tex(generated)

            texture[part_id - 1] = generated[:, ::-1, :]

        # Channel-first tiles -> channel-last with the two spatial axes swapped.
        tex_concat = np.zeros((24, tex_part_size, tex_part_size, 3))
        tex_concat[:n_parts] = texture.transpose(0, 3, 2, 1)
        return self.concat_textures(tex_concat)

    def create_iuv(self, results: Dict, image: np.ndarray) -> np.ndarray:
        """Build the texture atlas for ``image`` from prediction ``results``.

        Args:
            results: Dict with ``'pred_densepose'`` and ``'pred_boxes_XYXY'``.
            image: Image the predictions were computed on.

        Returns:
            The atlas from ``get_texture`` with its first two axes swapped.
        """
        iuv = self.parse_iuv(results)
        bbox = self.parse_bbox(results)
        uv_texture = self.get_texture(image, iuv, bbox)
        return uv_texture.transpose([1, 0, 2])

    def get_config(self, config_fpath: str, model_fpath: str) -> "CfgNode":
        """Create a frozen DensePose config that runs on the CPU.

        Args:
            config_fpath: Config file merged into the default config.
            model_fpath: Value assigned to ``cfg.MODEL.WEIGHTS``.
        """
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.MODEL.DEVICE = 'cpu'
        cfg.freeze()
        return cfg

    def execute(self, image: np.ndarray) -> Dict:
        """Run the predictor on ``image`` and return the parsed outputs."""
        with torch.no_grad():
            outputs = self.predictor(image)['instances']
        return self.execute_on_outputs(outputs)

    def execute_on_outputs(self, outputs: "Instances") -> Dict:
        """Copy the fields that are present in ``outputs`` into a plain dict.

        Possible keys: ``'scores'``, ``'pred_boxes_XYXY'`` (both moved to the
        CPU) and ``'pred_densepose'`` (first element returned by the result
        extractor).
        """
        result = {}
        if outputs.has('scores'):
            result['scores'] = outputs.get('scores').cpu()
        if outputs.has('pred_boxes'):
            result['pred_boxes_XYXY'] = outputs.get('pred_boxes').tensor.cpu()
        if outputs.has('pred_densepose'):
            result['pred_densepose'] = self.extractor(outputs)[0]
        return result