Spaces:
Runtime error
Runtime error
import sys | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
from PyQt5.QtCore import * | |
from PyQt5.QtGui import * | |
from PyQt5.QtWidgets import * | |
from models.sample_model import SampleFromPoseModel | |
from ui.mouse_event import GraphicsScene | |
from ui.ui import Ui_Form | |
from utils.language_utils import (generate_shape_attributes, | |
generate_texture_attributes) | |
from utils.options import dict_to_nonedict, parse | |
# RGB colour for each parsing label: color_list[k] is the colour of label k.
# Ex.generate_human paints stroke list k with color_list[k] and maps pixels of
# that colour back to label k. The names below come from the Ex.*_mode methods,
# which store these same integers in ref_scene.mode; presumably they are the
# same label indices (confirm against the model's parsing palette).
color_list = [
    (0, 0, 0),          # 0: background
    (255, 250, 250),    # 1: top
    (220, 220, 220),    # 2: outer
    (250, 235, 215),    # 3: skirt
    (255, 250, 205),    # 4: dress
    (211, 211, 211),    # 5: pants
    (70, 130, 180),     # 6: leggings
    (127, 255, 212),    # 7: headwear
    (0, 100, 0),        # 8: eyeglass
    (50, 205, 50),      # 9: neckwear
    (255, 255, 0),      # 10: belt
    (245, 222, 179),    # 11: footwear
    (255, 140, 0),      # 12: bag
    (255, 0, 0),        # 13: hair
    (16, 78, 139),      # 14: face
    (144, 238, 144),    # 15: skin
    (50, 205, 174),     # 16: ring
    (50, 155, 250),     # 17: wrist
    (160, 140, 88),     # 18: socks
    (213, 140, 88),     # 19: glove
    (90, 140, 90),      # 20: necklace
    (185, 210, 205),    # 21: rompers
    (130, 165, 180),    # 22: earstuds
    (225, 141, 151),    # 23: tie
]
class Ex(QWidget, Ui_Form):
    """Main window of the pose-to-human sampling demo.

    Wires the widgets created by ``Ui_Form.setupUi`` to a
    ``SampleFromPoseModel``. The workflow is:

    1. Load a DensePose image (``open_densepose``).
    2. Sample a parsing map from the shape text (``generate_parsing``).
    3. Optionally paint over the map in the reference scene.
    4. Render an image from the texture text (``generate_human``).
    5. Save the result (``save_img``).
    """

    def __init__(self, opt):
        """Build the UI, its three graphics scenes and the sampling model.

        Args:
            opt: parsed option dict handed unchanged to
                ``SampleFromPoseModel``.
        """
        super(Ex, self).__init__()
        self.setupUi(self)
        self.show()
        self.output_img = None
        self.mat_img = None
        self.mode = 0
        # NOTE(review): this instance attribute shadows QWidget.size(); any
        # Python code calling self.size() on this widget gets an int instead.
        self.size = 6
        self.mask = None
        self.mask_m = None
        self.img = None
        # Filled in by open_densepose / generate_parsing; initialised here so
        # later steps can tell whether the earlier ones have run.
        self.pose_img = None
        self.colored_segm = None

        # about UI
        self.mouse_clicked = False
        self.scene = QGraphicsScene()
        self._setup_view(self.graphicsView, self.scene)
        self.ref_scene = GraphicsScene(self.mode, self.size)
        self._setup_view(self.graphicsView_2, self.ref_scene)
        self.result_scene = QGraphicsScene()
        self._setup_view(self.graphicsView_3, self.result_scene)

        self.dlg = QColorDialog(self.graphicsView)
        self.color = None
        self.sample_model = SampleFromPoseModel(opt)

    @staticmethod
    def _setup_view(view, scene):
        """Attach *scene* to *view*, aligned top-left with scroll bars off."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    def _replace_pixmap(self, scene, pixmap):
        """Show *pixmap* in *scene*, scaled to ``graphicsView``'s size.

        The item at the end of ``scene.items()`` (if any) is removed first.
        """
        # NOTE(review): every scene is scaled to graphicsView's size rather
        # than its own view's, exactly as before; presumably the three views
        # share one size and stroke coordinates rely on it -- confirm.
        pixmap = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        items = scene.items()
        if items:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    @staticmethod
    def _rgb_to_pixmap(array):
        """Convert an (H, W, 3) uint8 RGB numpy array into a QPixmap."""
        height, width = array.shape[:2]
        # Keep the bytes in a local so they outlive the QImage until
        # QPixmap.fromImage has copied the pixels.
        buffer = array.tobytes()
        # Pass bytesPerLine explicitly: without it QImage assumes 32-bit
        # aligned rows, which skews images whose width * 3 % 4 != 0.
        qim = QImage(buffer, width, height, 3 * width, QImage.Format_RGB888)
        return QPixmap.fromImage(qim)

    def open_densepose(self):
        """Let the user pick a DensePose image, show it and feed it to the model.

        The image is displayed in ``graphicsView``. The reference and result
        scenes are cleared, since they no longer match the new pose. On
        success ``self.pose_img`` holds the tensor passed to
        ``sample_model.feed_pose_data``.
        """
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:
            return
        image = QPixmap(fileName)
        # Validate with Qt first, so a non-image file produces the message box
        # instead of an exception from PIL.
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        # The context manager closes the file handle PIL keeps open; copy()
        # loads the pixel data before that happens.
        with Image.open(fileName) as mat_img:
            pose = mat_img.copy()

        self._replace_pixmap(self.scene, image)
        self.ref_scene.clear()
        self.result_scene.clear()

        # Load the pose into the model:
        # - resize to 256x512 (PIL takes (width, height));
        # - keep channels from index 2 on and move them to the front;
        # - rescale with / 12 - 1.
        # NOTE(review): presumably channel 2 stores the DensePose part index
        # and /12 - 1 matches the training-time normalisation -- confirm.
        pose = np.array(pose.resize(size=(256, 512), resample=Image.LANCZOS))
        pose = pose[:, :, 2:].transpose(2, 0, 1).astype(np.float32)
        pose = pose / 12. - 1
        self.pose_img = torch.from_numpy(pose).unsqueeze(1)
        self.sample_model.feed_pose_data(self.pose_img)

    def generate_parsing(self):
        """Sample a parsing map for the loaded pose from the shape text.

        The colour-coded map is shown in the reference scene. A private copy
        is kept in ``self.mask_m`` as the canvas onto which
        ``generate_human`` rasterises the user's strokes.
        """
        if self.pose_img is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Please open a DensePose image first.")
            return
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        shape_texts = self.message_box_1.text()
        shape_attributes = generate_shape_attributes(shape_texts)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.sample_model.feed_shape_attributes(shape_attributes)
        self.sample_model.generate_parsing_map()
        self.sample_model.generate_quantized_segm()
        self.colored_segm = self.sample_model.palette_result(
            self.sample_model.segm[0].cpu())
        # Independent C-contiguous copy: strokes are drawn into mask_m in
        # place and must not touch colored_segm. This replaces an RGB->BGR->RGB
        # cv2.cvtColor round trip whose only effect was to copy.
        self.mask_m = self.colored_segm.copy()
        self._replace_pixmap(self.ref_scene,
                             self._rgb_to_pixmap(self.colored_segm))
        self.result_scene.clear()

    def generate_human(self):
        """Render an image from the (edited) parsing map and the texture text.

        Steps:

        1. Rasterise the strokes recorded by the reference scene onto
           ``self.mask_m``.
        2. Map each pixel colour back to its index in ``color_list``.
        3. Run the model.
        4. Show the result in the result scene and keep it in
           ``self.output_img`` for ``save_img``.
        """
        if self.mask_m is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Please generate a parsing map first.")
            return
        # Stroke list k is painted with color_list[k].
        for label, color in enumerate(color_list):
            self.mask_m = self.make_mask(self.mask_m,
                                         self.ref_scene.mask_points[label],
                                         self.ref_scene.size_points[label],
                                         color)
        # convert rgb to num: a pixel of colour color_list[k] becomes label k;
        # pixels still at -1 matched no palette colour.
        seg_map = np.full(self.mask_m.shape[:-1], -1)
        for index, color in enumerate(color_list):
            seg_map[np.all(self.mask_m == color, axis=2)] = index
        # An explicit check instead of `assert`, which vanishes under -O and
        # would otherwise abort the application from inside a Qt slot.
        if (seg_map == -1).any():
            QMessageBox.warning(
                self, "Image Viewer",
                "The parsing map contains colours that match no label.")
            return
        self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
            0).unsqueeze(0).to(self.sample_model.device)
        self.sample_model.generate_quantized_segm()
        texture_texts = self.message_box_2.text()
        texture_attributes = generate_texture_attributes(texture_texts)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.sample_model.feed_texture_attributes(texture_attributes)
        self.sample_model.generate_texture_map()
        result = self.sample_model.sample_and_refine()
        # NOTE(review): assumes result is (N, C, H, W) with values in [0, 1]
        # -- confirm against sample_and_refine.
        result = result.permute(0, 2, 3, 1).detach().cpu().numpy()
        # Clip before the uint8 cast: out-of-range floats are not saturated
        # by the cast itself. In-range values truncate exactly as before.
        result = np.clip(result[0] * 255, 0, 255).astype(np.uint8)
        self.output_img = result
        self._replace_pixmap(self.result_scene, self._rgb_to_pixmap(result))

    # Toolbar slots: each stores a label index in ref_scene.mode. Presumably
    # these line up with color_list (e.g. background_mode -> 0 -> black).
    def top_mode(self):
        self.ref_scene.mode = 1

    def skin_mode(self):
        self.ref_scene.mode = 15

    def outer_mode(self):
        self.ref_scene.mode = 2

    def face_mode(self):
        self.ref_scene.mode = 14

    def skirt_mode(self):
        self.ref_scene.mode = 3

    def hair_mode(self):
        self.ref_scene.mode = 13

    def dress_mode(self):
        self.ref_scene.mode = 4

    def headwear_mode(self):
        self.ref_scene.mode = 7

    def pants_mode(self):
        self.ref_scene.mode = 5

    def eyeglass_mode(self):
        self.ref_scene.mode = 8

    def rompers_mode(self):
        self.ref_scene.mode = 21

    def footwear_mode(self):
        self.ref_scene.mode = 11

    def leggings_mode(self):
        self.ref_scene.mode = 6

    def ring_mode(self):
        self.ref_scene.mode = 16

    def belt_mode(self):
        self.ref_scene.mode = 10

    def neckwear_mode(self):
        self.ref_scene.mode = 9

    def wrist_mode(self):
        self.ref_scene.mode = 17

    def socks_mode(self):
        self.ref_scene.mode = 18

    def tie_mode(self):
        self.ref_scene.mode = 23

    def earstuds_mode(self):
        self.ref_scene.mode = 22

    def necklace_mode(self):
        self.ref_scene.mode = 20

    def bag_mode(self):
        self.ref_scene.mode = 12

    def glove_mode(self):
        self.ref_scene.mode = 19

    def background_mode(self):
        self.ref_scene.mode = 0

    def make_mask(self, mask, pts, sizes, color):
        """Draw recorded stroke segments onto *mask* in place and return it.

        Args:
            mask: image array passed to ``cv2.line``; modified in place.
            pts: segments, each a mapping with ``'prev'`` and ``'curr'``
                endpoints.
            sizes: line thickness for the segment at the same index in *pts*.
            color: colour used for every segment.
        """
        for idx, pt in enumerate(pts):
            cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
        return mask

    def save_img(self):
        """Save the last generated image as a PNG at a user-chosen path.

        Does nothing if no image has been generated yet or if the dialog is
        cancelled. ``.png`` is appended unless the name already ends with it.
        """
        # output_img stays None until generate_human succeeds. (The former
        # `if type(self.output_img):` was always true, even for None.)
        if self.output_img is None:
            return
        fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                  QDir.currentPath())
        if not fileName:
            return
        if not fileName.lower().endswith('.png'):
            fileName += '.png'
        # output_img is RGB; cv2.imwrite expects BGR, so reverse the channels.
        if not cv2.imwrite(fileName, self.output_img[:, :, ::-1]):
            QMessageBox.warning(self, "Save File",
                                "Cannot save %s." % fileName)

    def undo(self):
        """Forward an undo request to the editable reference scene."""
        # self.scene is a plain QGraphicsScene, which has no undo() method;
        # the user-drawn strokes live in self.ref_scene.
        # NOTE(review): assumes ui.mouse_event.GraphicsScene provides undo().
        self.ref_scene.undo()

    def clear(self):
        """Reset the reference scene and empty the reference and result views."""
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        self.ref_scene.clear()
        self.result_scene.clear()
def main():
    """Launch the demo window and return the Qt event loop's exit code."""
    app = QApplication(sys.argv)
    # Options come from the fixed YAML config, parsed in inference mode and
    # wrapped by dict_to_nonedict before being handed to the window.
    config = dict_to_nonedict(
        parse('./configs/sample_from_pose.yml', is_train=False))
    # Keep a reference so the window stays alive while the event loop runs.
    window = Ex(config)
    return app.exec_()


if __name__ == '__main__':
    sys.exit(main())