import sys

import cv2
import numpy as np
import torch
from PIL import Image
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *

from models.sample_model import SampleFromPoseModel
from ui.mouse_event import GraphicsScene
from ui.ui import Ui_Form
from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)
from utils.options import dict_to_nonedict, parse

# RGB color of every parsing label. The list index is the label id:
# generate_human paints ref_scene.mask_points[i] with color_list[i] and then
# maps color_list[i] back to label i.
color_list = [(0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),
              (255, 250, 205), (211, 211, 211), (70, 130, 180),
              (127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0),
              (245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139),
              (144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88),
              (213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180),
              (225, 141, 151)]


class Ex(QWidget, Ui_Form):
    """Main window of the sample-from-pose demo.

    Workflow:
      1. open_densepose loads a pose image into the first view and feeds it
         to the model.
      2. generate_parsing builds a parsing map from the shape text
         (message_box_1) and shows it, editable, in the second view.
      3. generate_human applies the user's strokes and renders a human from
         the texture text (message_box_2) into the third view.
      4. save_img writes the rendered result to disk.
    """

    def __init__(self, opt):
        """Build the UI and load the model.

        Args:
            opt: parsed options passed straight to SampleFromPoseModel.
        """
        super(Ex, self).__init__()
        self.setupUi(self)
        self.show()
        self.output_img = None
        self.mat_img = None
        self.mode = 0
        self.size = 6
        self.mask = None
        self.mask_m = None
        self.img = None
        # Filled in by open_densepose / generate_parsing. They are initialised
        # here so the attributes always exist.
        self.pose_img = None
        self.colored_segm = None

        # about UI
        self.mouse_clicked = False
        self.scene = QGraphicsScene()
        self._setup_view(self.graphicsView, self.scene)

        # Editable scene: the user's strokes are recorded here.
        self.ref_scene = GraphicsScene(self.mode, self.size)
        self._setup_view(self.graphicsView_2, self.ref_scene)

        self.result_scene = QGraphicsScene()
        self._setup_view(self.graphicsView_3, self.result_scene)

        self.dlg = QColorDialog(self.graphicsView)
        self.color = None

        self.sample_model = SampleFromPoseModel(opt)

    @staticmethod
    def _setup_view(view, scene):
        """Attach scene to view, pinned top-left with both scroll bars hidden."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    @staticmethod
    def _replace_scene_pixmap(scene, pixmap):
        """Remove the last item of scene (if any), then add pixmap to it."""
        items = scene.items()
        if len(items) > 0:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    def _array_to_pixmap(self, rgb):
        """Convert an RGB image array to a QPixmap scaled to graphicsView.

        rgb must be an (H, W, 3) uint8 array, as required by
        QImage.Format_RGB888.
        """
        rgb = np.ascontiguousarray(rgb)
        height, width = rgb.shape[:2]
        # QImage does not copy the buffer. Keep it in a local variable until
        # QPixmap.fromImage (which does copy) has run. Pass bytesPerLine
        # explicitly so QImage does not assume 32-bit aligned scanlines.
        buffer = rgb.tobytes()
        qim = QImage(buffer, width, height, rgb.strides[0],
                     QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qim)
        return pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)

    def open_densepose(self):
        """Ask for a DensePose image, display it and feed it to the model."""
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return  # dialog cancelled

        image = QPixmap(fileName)
        # Validate with Qt before opening the file with PIL. An unreadable
        # file then shows this message instead of raising from Image.open.
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        self._replace_scene_pixmap(self.scene, image)
        self.ref_scene.clear()
        self.result_scene.clear()

        # load pose to model
        # The context manager releases PIL's file handle. resize() forces the
        # pixel data to load while the file is still open.
        with Image.open(fileName) as pose_file:
            resized = pose_file.resize(size=(256, 512),
                                       resample=Image.LANCZOS)
        # PIL sizes are (width, height), so the array is (512, 256, C).
        # Processing steps:
        #   - keep channels from index 2 onward (for an RGB file that is only
        #     the last channel, presumably the DensePose part index; confirm);
        #   - move channels first;
        #   - rescale with /12 - 1, which maps 0..24 onto -1..1.
        pose = np.array(resized)[:, :, 2:].transpose(2, 0, 1)
        pose = pose.astype(np.float32) / 12. - 1
        self.pose_img = torch.from_numpy(pose).unsqueeze(1)
        self.sample_model.feed_pose_data(self.pose_img)

    def generate_parsing(self):
        """Generate a parsing map from the shape text and show it for editing."""
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        shape_texts = self.message_box_1.text()
        shape_attributes = generate_shape_attributes(shape_texts)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.sample_model.feed_shape_attributes(shape_attributes)

        self.sample_model.generate_parsing_map()
        self.sample_model.generate_quantized_segm()

        self.colored_segm = self.sample_model.palette_result(
            self.sample_model.segm[0].cpu())

        # RGB->BGR->RGB yields an independent copy with identical pixel
        # values. make_mask later draws the user's strokes into this copy.
        self.mask_m = cv2.cvtColor(
            cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR),
            cv2.COLOR_BGR2RGB)

        self._replace_scene_pixmap(self.ref_scene,
                                   self._array_to_pixmap(self.colored_segm))
        self.result_scene.clear()

    def generate_human(self):
        """Render a human from the (edited) parsing map and the texture text.

        Shows a message and does nothing in two cases: no parsing map has been
        generated yet, or the edited map contains a color that is not in
        color_list.
        """
        if self.mask_m is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Please generate a parsing map first.")
            return

        # Burn the strokes recorded for every label into the parsing image.
        for label, color in enumerate(color_list):
            self.mask_m = self.make_mask(self.mask_m,
                                         self.ref_scene.mask_points[label],
                                         self.ref_scene.size_points[label],
                                         color)

        # convert rgb to num
        # Use an explicit int64 dtype. np.full's default integer type is
        # platform dependent (int32 on Windows with older NumPy).
        seg_map = np.full(self.mask_m.shape[:-1], -1, dtype=np.int64)
        for index, color in enumerate(color_list):
            seg_map[np.all(self.mask_m == color, axis=2)] = index
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # and an exception escaping a Qt slot aborts the application.
        if (seg_map == -1).any():
            QMessageBox.information(
                self, "Image Viewer",
                "The parsing map contains colors that match no label.")
            return

        self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
            0).unsqueeze(0).to(self.sample_model.device)
        self.sample_model.generate_quantized_segm()

        texture_texts = self.message_box_2.text()
        texture_attributes = generate_texture_attributes(texture_texts)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.sample_model.feed_texture_attributes(texture_attributes)
        self.sample_model.generate_texture_map()

        result = self.sample_model.sample_and_refine()
        # Move axis 1 to the end. The first item is then treated as an
        # (H, W, 3) RGB image.
        result = result.permute(0, 2, 3, 1)
        result = result.detach().cpu().numpy()
        # Clip before the uint8 cast, because out-of-range floats do not
        # convert to uint8 reliably.
        result = np.clip(result * 255, 0, 255)
        result = np.asarray(result[0, :, :, :], dtype=np.uint8)
        self.output_img = result

        self._replace_scene_pixmap(self.result_scene,
                                   self._array_to_pixmap(result))

    # Label-selection slots. Each one sets ref_scene.mode to a label index
    # (an index into color_list). They are presumably connected to buttons
    # in Ui_Form, so their names must stay unchanged.
    def top_mode(self):
        self.ref_scene.mode = 1

    def skin_mode(self):
        self.ref_scene.mode = 15

    def outer_mode(self):
        self.ref_scene.mode = 2

    def face_mode(self):
        self.ref_scene.mode = 14

    def skirt_mode(self):
        self.ref_scene.mode = 3

    def hair_mode(self):
        self.ref_scene.mode = 13

    def dress_mode(self):
        self.ref_scene.mode = 4

    def headwear_mode(self):
        self.ref_scene.mode = 7

    def pants_mode(self):
        self.ref_scene.mode = 5

    def eyeglass_mode(self):
        self.ref_scene.mode = 8

    def rompers_mode(self):
        self.ref_scene.mode = 21

    def footwear_mode(self):
        self.ref_scene.mode = 11

    def leggings_mode(self):
        self.ref_scene.mode = 6

    def ring_mode(self):
        self.ref_scene.mode = 16

    def belt_mode(self):
        self.ref_scene.mode = 10

    def neckwear_mode(self):
        self.ref_scene.mode = 9

    def wrist_mode(self):
        self.ref_scene.mode = 17

    def socks_mode(self):
        self.ref_scene.mode = 18

    def tie_mode(self):
        self.ref_scene.mode = 23

    def earstuds_mode(self):
        self.ref_scene.mode = 22

    def necklace_mode(self):
        self.ref_scene.mode = 20

    def bag_mode(self):
        self.ref_scene.mode = 12

    def glove_mode(self):
        self.ref_scene.mode = 19

    def background_mode(self):
        self.ref_scene.mode = 0

    def make_mask(self, mask, pts, sizes, color):
        """Draw stroke segments onto mask in place and return mask.

        Args:
            mask: image array drawn into by cv2.line.
            pts: list of dicts with 'prev' and 'curr' segment endpoints.
            sizes: line thickness for each entry of pts.
            color: color tuple used for every segment.
        """
        for idx, pt in enumerate(pts):
            cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
        return mask

    def save_img(self):
        """Ask for a file name and save the last generated image as PNG."""
        # The old check, `if type(self.output_img):`, was always true, so
        # saving before generating crashed on None.
        if self.output_img is None:
            QMessageBox.information(self, "Save File",
                                    "There is no generated image to save.")
            return
        fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                  QDir.currentPath())
        if not fileName:
            return  # dialog cancelled; do not write a file named ".png"
        if not fileName.lower().endswith('.png'):
            fileName += '.png'
        # output_img is RGB, but OpenCV expects BGR, so reverse the channels.
        cv2.imwrite(fileName, self.output_img[:, :, ::-1])

    def undo(self):
        """Undo the last edit in the editable parsing-map scene."""
        # self.scene is a plain QGraphicsScene, which has no undo(), so the
        # previous `self.scene.undo()` always raised. The user's strokes are
        # recorded in ref_scene.
        # NOTE(review): assumes GraphicsScene implements undo(); confirm in
        # ui/mouse_event.py.
        self.ref_scene.undo()

    def clear(self):
        """Discard all edits and clear the parsing and result views."""
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        self.ref_scene.clear()
        self.result_scene.clear()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    opt = './configs/sample_from_pose.yml'
    opt = parse(opt, is_train=False)
    opt = dict_to_nonedict(opt)
    ex = Ex(opt)
    sys.exit(app.exec_())