# JustClothify — helpers/processor.py
# Commit f04a56f ("Switching from double quotes to single quotes", Brasd99).
# Note: Hugging Face page chrome (raw/history/blame links, file size) removed
# so this file is valid Python.
import cv2
import imageio
import numpy as np
import torch
from typing import Any, Dict
import io
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from densepose import add_densepose_config
from densepose.structures import (
DensePoseChartPredictorOutput,
DensePoseEmbeddingPredictorOutput
)
from densepose.vis.base import CompoundVisualizer
from densepose.vis.densepose_outputs_vertex import get_texture_atlases
from densepose.vis.densepose_results_textures import (
DensePoseResultsVisualizerWithTexture as dp_iuv_texture
)
from densepose.vis.extractor import (
CompoundExtractor,
create_extractor,
DensePoseOutputsExtractor,
DensePoseResultExtractor
)
class TextureProcessor:
def __init__(self, config, weights):
self.config = self.get_config(config, weights)
self.predictor = DefaultPredictor(self.config)
def process_texture(self, image):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
output = self.execute(image)[0]
if 'pred_densepose' in output:
texture = self.create_iuv(output, image)
atlas_texture_bytes = io.BytesIO()
imageio.imwrite(atlas_texture_bytes, texture, format='PNG')
texture_atlas_array = np.frombuffer(atlas_texture_bytes.getvalue(), dtype=np.uint8)
texture_atlas = cv2.imdecode(texture_atlas_array, cv2.IMREAD_COLOR)
texture_atlas = cv2.cvtColor(texture_atlas, cv2.COLOR_BGR2RGB)
return texture_atlas
else:
raise Exception('Clothes not found')
def extract(self, person_img, model_img):
texture_atlas = self.process_texture(model_img)
return self.overlay_texture(texture_atlas, person_img)
def overlay_texture(self, texture_atlas, original_image):
texture_atlas[:, :, :3] = texture_atlas[:, :, 2::-1]
texture_atlases_dict = get_texture_atlases(None)
vis = dp_iuv_texture(
cfg=self.config,
texture_atlas=texture_atlas,
texture_atlases_dict=texture_atlases_dict
)
extractor = create_extractor(vis)
visualizer = CompoundVisualizer([vis])
extractor = CompoundExtractor([extractor])
with torch.no_grad():
outputs = self.predictor(original_image)['instances']
image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
data = extractor(outputs)
image_vis = visualizer.visualize(image, data)
return image_vis
def parse_iuv(self, result):
i = result['pred_densepose'][0].labels.cpu().numpy().astype(float)
uv = (result['pred_densepose'][0].uv.cpu().numpy() * 255.0).astype(float)
iuv = np.stack((uv[1, :, :], uv[0, :, :], i))
iuv = np.transpose(iuv, (1, 2, 0))
return iuv
def parse_bbox(self, result):
return result['pred_boxes_XYXY'][0].cpu().numpy()
def interpolate_tex(self, tex):
valid_mask = np.array((tex.sum(0) != 0) * 1, dtype='uint8')
radius_increase = 10
kernel = np.ones((radius_increase, radius_increase), np.uint8)
dilated_mask = cv2.dilate(valid_mask, kernel, iterations=1)
invalid_region = 1 - valid_mask
actual_part_max = tex.max()
actual_part_min = tex.min()
actual_part_uint = np.array(
(tex - actual_part_min) / (actual_part_max - actual_part_min) * 255, dtype='uint8')
actual_part_uint = cv2.inpaint(actual_part_uint.transpose((1, 2, 0)), invalid_region, 1,
cv2.INPAINT_TELEA).transpose((2, 0, 1))
actual_part = (actual_part_uint / 255.0) * \
(actual_part_max - actual_part_min) + actual_part_min
actual_part = actual_part * dilated_mask
return actual_part
def concat_textures(self, array):
texture = []
for i in range(4):
tmp = array[6 * i]
for j in range(6 * i + 1, 6 * i + 6):
tmp = np.concatenate((tmp, array[j]), axis=1)
texture = tmp if len(texture) == 0 else np.concatenate(
(texture, tmp), axis=0)
return texture
def get_texture(self, im, iuv, bbox, tex_part_size=200):
im = im.transpose(2, 1, 0) / 255
image_w, image_h = im.shape[1], im.shape[2]
bbox[2] = bbox[2] - bbox[0]
bbox[3] = bbox[3] - bbox[1]
x, y, w, h = [int(v) for v in bbox]
bg = np.zeros((image_h, image_w, 3))
bg[y:y + h, x:x + w, :] = iuv
iuv = bg
iuv = iuv.transpose((2, 1, 0))
i, u, v = iuv[2], iuv[1], iuv[0]
n_parts = 22
texture = np.zeros((n_parts, 3, tex_part_size, tex_part_size))
for part_id in range(1, n_parts + 1):
generated = np.zeros((3, tex_part_size, tex_part_size))
x, y = u[i == part_id], v[i == part_id]
tex_u_coo = (x * (tex_part_size - 1) / 255).astype(int)
tex_v_coo = (y * (tex_part_size - 1) / 255).astype(int)
tex_u_coo = np.clip(tex_u_coo, 0, tex_part_size - 1)
tex_v_coo = np.clip(tex_v_coo, 0, tex_part_size - 1)
for channel in range(3):
generated[channel][tex_v_coo,
tex_u_coo] = im[channel][i == part_id]
if np.sum(generated) > 0:
generated = self.interpolate_tex(generated)
texture[part_id - 1] = generated[:, ::-1, :]
tex_concat = np.zeros((24, tex_part_size, tex_part_size, 3))
for i in range(texture.shape[0]):
tex_concat[i] = texture[i].transpose(2, 1, 0)
tex = self.concat_textures(tex_concat)
return tex
def create_iuv(self, results, image):
iuv = self.parse_iuv(results)
bbox = self.parse_bbox(results)
uv_texture = self.get_texture(image, iuv, bbox)
uv_texture = uv_texture.transpose([1, 0, 2])
return uv_texture
def get_config(self, config_fpath, model_fpath):
cfg = get_cfg()
add_densepose_config(cfg)
cfg.merge_from_file(config_fpath)
cfg.MODEL.WEIGHTS = model_fpath
cfg.MODEL.DEVICE = 'cpu'
cfg.freeze()
return cfg
def execute(self, image):
context = {'results': []}
with torch.no_grad():
outputs = self.predictor(image)['instances']
self.execute_on_outputs(context, outputs)
return context['results']
def execute_on_outputs(self, context: Dict[str, Any], outputs: Instances):
result = {}
if outputs.has('scores'):
result['scores'] = outputs.get('scores').cpu()
if outputs.has('pred_boxes'):
result['pred_boxes_XYXY'] = outputs.get('pred_boxes').tensor.cpu()
if outputs.has('pred_densepose'):
if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
extractor = DensePoseResultExtractor()
elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
extractor = DensePoseOutputsExtractor()
result['pred_densepose'] = extractor(outputs)[0]
context['results'].append(result)