# JustClothify / helpers/processor.py
# Author: Brasd99
# Commit 26b0060: "Оптимизация" ("Optimization")
# (File-viewer metadata captured with this file: raw · history · blame · 6.99 kB)
import cv2
import imageio
import numpy as np
import torch
from typing import Any, Dict
import io
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from densepose import add_densepose_config
from densepose.structures import (
DensePoseChartPredictorOutput,
DensePoseEmbeddingPredictorOutput
)
from densepose.vis.base import CompoundVisualizer
from densepose.vis.densepose_outputs_vertex import get_texture_atlases
from densepose.vis.densepose_results_textures import (
DensePoseResultsVisualizerWithTexture as dp_iuv_texture
)
from densepose.vis.extractor import (
CompoundExtractor,
create_extractor,
DensePoseOutputsExtractor,
DensePoseResultExtractor
)
class TextureProcessor:
    """Transfer a DensePose UV texture from one image onto the person in another.

    Wraps a detectron2 ``DefaultPredictor`` configured for DensePose. The
    texture of the first detected instance in a "model" image is unwrapped
    into a grid of per-part tiles (the texture atlas) and then rendered onto
    the instance detected in a second image.
    """

    def __init__(self, config, weights, device="cpu"):
        """Build the DensePose predictor.

        Args:
            config: Path of the detectron2/DensePose YAML config file.
            weights: Model weights location, assigned to ``cfg.MODEL.WEIGHTS``.
            device: Value for ``cfg.MODEL.DEVICE``; defaults to ``"cpu"``.
        """
        self.config = self.get_config(config, weights, device)
        self.predictor = DefaultPredictor(self.config)

    @staticmethod
    def _texture_to_uint8(texture):
        """Convert a float texture in [0, 1] to uint8 in [0, 255].

        Values outside [0, 1] are clipped and results are rounded to the
        nearest integer. This replaces an in-memory PNG encode/decode whose
        only net effect was this dtype conversion (its BGR/RGB swaps cancelled
        out).
        """
        scaled = np.clip(np.asarray(texture, dtype=np.float64), 0.0, 1.0) * 255.0
        return np.rint(scaled).astype(np.uint8)

    def process_texture(self, image):
        """Extract the texture atlas of the first detected instance in *image*.

        Args:
            image: BGR image, as read by OpenCV.

        Returns:
            uint8 atlas with channels in RGB order, shape ``(6*S, 4*S, 3)``
            where ``S`` is the per-part tile size used by ``get_texture``
            (200 by default, giving 1200 x 800).

        Raises:
            Exception: ``'Clothes not found'`` when no DensePose result is
                produced (missing key, ``None``, or zero detections).
        """
        # DefaultPredictor expects BGR input (overlay_texture also feeds it
        # BGR); only the texture sampling below needs RGB pixels.
        output = self.execute(image)[0]
        densepose = output.get('pred_densepose')
        if densepose is None or len(densepose) == 0:
            raise Exception('Clothes not found')
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        texture = self.create_iuv(output, rgb_image)
        return self._texture_to_uint8(texture)

    def extract(self, person_img, model_img):
        """Render the texture found in *model_img* onto *person_img*."""
        texture_atlas = self.process_texture(model_img)
        return self.overlay_texture(texture_atlas, person_img)

    def overlay_texture(self, texture_atlas, original_image):
        """Render *texture_atlas* onto the instances detected in *original_image*.

        Args:
            texture_atlas: Atlas as returned by ``process_texture`` (RGB).
                It is copied, so the caller's array is not modified.
            original_image: BGR image to detect on and draw over.

        Returns:
            The visualizer output, drawn over a 3-channel grayscale copy of
            *original_image*.
        """
        # Work on a copy so the caller's atlas is not modified, then reverse
        # the first three channels (the atlas from process_texture is RGB).
        texture_atlas = np.array(texture_atlas)
        texture_atlas[:, :, :3] = texture_atlas[:, :, 2::-1]
        texture_atlases_dict = get_texture_atlases(None)
        vis = dp_iuv_texture(
            cfg=self.config,
            texture_atlas=texture_atlas,
            texture_atlases_dict=texture_atlases_dict
        )
        visualizer = CompoundVisualizer([vis])
        extractor = CompoundExtractor([create_extractor(vis)])
        with torch.no_grad():
            outputs = self.predictor(original_image)["instances"]
        # Gray background replicated to 3 channels so the texture stands out.
        gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
        background = np.tile(gray[:, :, np.newaxis], [1, 1, 3])
        data = extractor(outputs)
        return visualizer.visualize(background, data)

    def parse_iuv(self, result):
        """Return an ``(h, w, 3)`` float array of ``[V*255, U*255, I]``.

        Built from the first DensePose result in ``result['pred_densepose']``:
        the UV channels are scaled by 255 and the part labels are kept as
        floats.
        """
        densepose = result['pred_densepose'][0]
        labels = densepose.labels.cpu().numpy().astype(float)
        uv = (densepose.uv.cpu().numpy() * 255.0).astype(float)
        iuv = np.stack((uv[1, :, :], uv[0, :, :], labels))
        return np.transpose(iuv, (1, 2, 0))

    def parse_bbox(self, result):
        """Return the first XYXY box from ``result`` as a numpy array.

        NOTE: ``.cpu().numpy()`` may share memory with the tensor; callers
        must not modify the returned array in place.
        """
        return result["pred_boxes_XYXY"][0].cpu().numpy()

    def interpolate_tex(self, tex):
        """Fill unsampled texels of one ``(3, S, S)`` part tile by inpainting.

        Texels whose channel sum is 0 are treated as unsampled (note that this
        includes genuinely black samples). They are filled with OpenCV's
        Telea inpainting. The result is then masked to the sampled region
        grown by a 10x10 dilation, and everything farther away is zeroed.
        """
        valid_mask = np.array((tex.sum(0) != 0) * 1, dtype='uint8')
        actual_part_max = tex.max()
        actual_part_min = tex.min()
        value_range = actual_part_max - actual_part_min
        if value_range == 0:
            # Constant tile: the normalisation below would divide by zero
            # (NaNs). Here the mask is all ones (constant non-zero) or all
            # zeros (constant zero), and dilation would leave it unchanged.
            return tex * valid_mask
        radius_increase = 10
        kernel = np.ones((radius_increase, radius_increase), np.uint8)
        dilated_mask = cv2.dilate(valid_mask, kernel, iterations=1)
        invalid_region = 1 - valid_mask
        # cv2.inpaint needs an 8-bit image: normalise to [0, 255] first and
        # map the result back to the original value range afterwards.
        actual_part_uint = np.array(
            (tex - actual_part_min) / value_range * 255, dtype='uint8')
        actual_part_uint = cv2.inpaint(
            actual_part_uint.transpose((1, 2, 0)), invalid_region, 1,
            cv2.INPAINT_TELEA).transpose((2, 0, 1))
        actual_part = (actual_part_uint / 255.0) * value_range + actual_part_min
        return actual_part * dilated_mask

    def concat_textures(self, array):
        """Tile 24 ``(S, S, 3)`` parts into a ``(4*S, 6*S, 3)`` atlas.

        ``array[6*r + c]`` lands in row ``r`` and column ``c``.
        """
        rows = [np.concatenate(list(array[6 * r:6 * r + 6]), axis=1)
                for r in range(4)]
        return np.concatenate(rows, axis=0)

    def get_texture(self, im, iuv, bbox, tex_part_size=200):
        """Unwrap the pixels of *im* into a per-part UV texture atlas.

        Args:
            im: ``(H, W, 3)`` image with 0-255 values (RGB when called from
                ``process_texture``).
            iuv: ``(h, w, 3)`` array from ``parse_iuv`` covering *bbox*.
            bbox: XYXY box locating *iuv* in *im*. It is not modified.
            tex_part_size: Side length ``S`` of each part tile.

        Returns:
            Float atlas of shape ``(4*S, 6*S, 3)``. For uint8 input the values
            lie in [0, 1].
        """
        im = im.transpose(2, 1, 0) / 255
        image_w, image_h = im.shape[1], im.shape[2]
        # Derive integer x, y, w, h without writing back into ``bbox``
        # (parse_bbox's array may alias the model output tensor).
        x, y = int(bbox[0]), int(bbox[1])
        w, h = int(bbox[2] - bbox[0]), int(bbox[3] - bbox[1])
        canvas = np.zeros((image_h, image_w, 3))
        canvas[y:y + h, x:x + w, :] = iuv
        canvas = canvas.transpose((2, 1, 0))
        labels, u, v = canvas[2], canvas[1], canvas[0]
        # Part labels 1..22 are sampled; label 0 is never written.
        n_parts = 22
        texture = np.zeros((n_parts, 3, tex_part_size, tex_part_size))
        for part_id in range(1, n_parts + 1):
            mask = labels == part_id
            generated = np.zeros((3, tex_part_size, tex_part_size))
            tex_u_coo = (u[mask] * (tex_part_size - 1) / 255).astype(int)
            tex_v_coo = (v[mask] * (tex_part_size - 1) / 255).astype(int)
            tex_u_coo = np.clip(tex_u_coo, 0, tex_part_size - 1)
            tex_v_coo = np.clip(tex_v_coo, 0, tex_part_size - 1)
            for channel in range(3):
                generated[channel][tex_v_coo, tex_u_coo] = im[channel][mask]
            if np.sum(generated) > 0:
                generated = self.interpolate_tex(generated)
            texture[part_id - 1] = generated[:, ::-1, :]
        # 24 slots so the 22 parts fill a 4 x 6 grid; the last two stay empty.
        tiles = np.zeros((24, tex_part_size, tex_part_size, 3))
        for idx, part in enumerate(texture):
            tiles[idx] = part.transpose(2, 1, 0)
        return self.concat_textures(tiles)

    def create_iuv(self, results, image):
        """Build the ``(6*S, 4*S, 3)`` texture atlas for one result dict."""
        iuv = self.parse_iuv(results)
        bbox = self.parse_bbox(results)
        uv_texture = self.get_texture(image, iuv, bbox)
        return uv_texture.transpose([1, 0, 2])

    def get_config(self, config_fpath, model_fpath, device="cpu"):
        """Return a frozen detectron2 config with DensePose settings merged in.

        Args:
            config_fpath: YAML config file to merge.
            model_fpath: Value for ``cfg.MODEL.WEIGHTS``.
            device: Value for ``cfg.MODEL.DEVICE`` (default ``"cpu"``).
        """
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(config_fpath)
        cfg.MODEL.WEIGHTS = model_fpath
        cfg.MODEL.DEVICE = device
        cfg.freeze()
        return cfg

    def execute(self, image):
        """Run the predictor on *image* and return a one-element result list.

        The single dict is built by ``execute_on_outputs``.
        """
        context = {'results': []}
        with torch.no_grad():
            outputs = self.predictor(image)["instances"]
            self.execute_on_outputs(context, outputs)
        return context["results"]

    def execute_on_outputs(self, context: Dict[str, Any], outputs: "Instances"):
        """Append a dict of CPU-side predictions to ``context["results"]``.

        Keys, each present only when the instances carry the field:
        ``"scores"``, ``"pred_boxes_XYXY"`` and ``"pred_densepose"`` (the
        first element returned by the matching DensePose extractor).

        Raises:
            TypeError: If ``pred_densepose`` is neither a chart nor an
                embedding predictor output.
        """
        result = {}
        if outputs.has("scores"):
            result["scores"] = outputs.get("scores").cpu()
        if outputs.has("pred_boxes"):
            result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
        if outputs.has("pred_densepose"):
            densepose = outputs.pred_densepose
            if isinstance(densepose, DensePoseChartPredictorOutput):
                extractor = DensePoseResultExtractor()
            elif isinstance(densepose, DensePoseEmbeddingPredictorOutput):
                extractor = DensePoseOutputsExtractor()
            else:
                raise TypeError(
                    f"Unsupported DensePose output type: {type(densepose).__name__}")
            result["pred_densepose"] = extractor(outputs)[0]
        context["results"].append(result)