# Hugging Face Spaces status shown alongside this file when captured: Runtime error
from __future__ import annotations

import os
import pathlib
import shutil
import sys
import zipfile

import huggingface_hub
import numpy as np
import PIL.Image
import torch

# The Text2Human sources live in ./Text2Human (its configs are loaded from
# there too); put that directory on sys.path before importing its packages.
sys.path.insert(0, 'Text2Human')
from models.sample_model import SampleFromPoseModel
from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)
from utils.options import dict_to_nonedict, parse
from utils.util import set_random_seed
# Label palette: the RGB colour stored at position i encodes segmentation
# label i. ``Model.process_mask`` inverts this mapping, turning a colour-coded
# mask back into integer label indices.
# NOTE(review): presumably this must match the palette used by the model's
# ``palette_result`` — confirm before editing any entry.
COLOR_LIST = [
    (0, 0, 0),        # 0
    (255, 250, 250),  # 1
    (220, 220, 220),  # 2
    (250, 235, 215),  # 3
    (255, 250, 205),  # 4
    (211, 211, 211),  # 5
    (70, 130, 180),   # 6
    (127, 255, 212),  # 7
    (0, 100, 0),      # 8
    (50, 205, 50),    # 9
    (255, 255, 0),    # 10
    (245, 222, 179),  # 11
    (255, 140, 0),    # 12
    (255, 0, 0),      # 13
    (16, 78, 139),    # 14
    (144, 238, 144),  # 15
    (50, 205, 174),   # 16
    (50, 155, 250),   # 17
    (160, 140, 88),   # 18
    (213, 140, 88),   # 19
    (90, 140, 90),    # 20
    (185, 210, 205),  # 21
    (130, 165, 180),  # 22
    (225, 141, 151),  # 23
]
class Model:
    """Inference wrapper around Text2Human's ``SampleFromPoseModel``.

    Typical use: ``process_pose_image`` turns a pose image into a tensor,
    ``generate_label_image`` renders the parsing map predicted from a shape
    description, and ``generate_human`` runs the full pose + shape + texture
    pipeline and returns a uint8 image array.
    """

    def __init__(self, device: str):
        """Load the config, fetch the weights if needed and build the model.

        Args:
            device: Device string written into the config under ``'device'``
                before ``SampleFromPoseModel`` is constructed.
        """
        self.config = self._load_config()
        self.config['device'] = device
        self._download_models()
        self.model = SampleFromPoseModel(self.config)
        self.model.batch_size = 1

    def _load_config(self) -> dict:
        """Parse the sample-from-pose YAML config for inference.

        The parsed dict is wrapped with ``dict_to_nonedict`` (presumably so
        missing keys read as None — see ``utils.options``).
        """
        path = 'Text2Human/configs/sample_from_pose.yml'
        config = parse(path, is_train=False)
        config = dict_to_nonedict(config)
        return config

    def _download_models(self) -> None:
        """Download and unpack the pretrained weights into ./pretrained_models.

        Does nothing when that directory already exists. The ``HF_TOKEN``
        environment variable, if set, is forwarded to the Hub download.
        """
        model_dir = pathlib.Path('pretrained_models')
        if model_dir.exists():
            return
        token = os.getenv('HF_TOKEN')
        # NOTE(review): recent huggingface_hub releases prefer ``token=``;
        # ``use_auth_token`` is kept for compatibility with older pins.
        path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
                                               'pretrained_models.zip',
                                               use_auth_token=token)
        # Extract into a staging directory and rename it into place only once
        # extraction has finished. Otherwise an interrupted run leaves a
        # partial `pretrained_models/` that the exists() check above would
        # accept as complete on every later start.
        staging_dir = model_dir.with_name(model_dir.name + '.partial')
        shutil.rmtree(staging_dir, ignore_errors=True)  # crash leftovers
        staging_dir.mkdir()
        with zipfile.ZipFile(path) as f:
            f.extractall(staging_dir)
        staging_dir.rename(model_dir)

    @staticmethod
    def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
        """Convert a pose image into the tensor layout fed to the model.

        The image is resized to 256x512 (width x height) with Lanczos
        resampling. Only channels from index 2 onward are kept; they are
        moved channel-first and scaled with ``x / 12 - 1``. A singleton
        axis is then inserted at dim 1.

        Args:
            image: Pose image. NOTE(review): assumes a 3-channel image whose
                third channel carries the pose information — confirm; an
                RGBA input would keep two channels here.

        Returns:
            float32 tensor of shape ``(C, 1, 512, 256)``, where ``C`` is the
            number of channels kept (1 for an RGB input).
        """
        resized = image.resize(size=(256, 512),
                               resample=PIL.Image.Resampling.LANCZOS)
        # (H, W, channels) -> keep channels [2:] -> (C, H, W)
        pixels = np.array(resized)[:, :, 2:].transpose(2, 0, 1)
        pixels = pixels.astype(np.float32) / 12. - 1
        return torch.from_numpy(pixels).unsqueeze(1)

    @staticmethod
    def process_mask(mask: np.ndarray) -> np.ndarray | None:
        """Map a colour-coded label mask back to integer label indices.

        Args:
            mask: Array expected to have shape ``(512, 256, 3)``. Each pixel's
                colour should be one of ``COLOR_LIST``.

        Returns:
            ``(512, 256)`` integer array holding each pixel's index into
            ``COLOR_LIST``. Returns None if the shape is wrong or if any
            pixel's colour is not in the palette.
        """
        if mask.shape != (512, 256, 3):
            return None
        seg_map = np.full(mask.shape[:-1], -1)
        for index, color in enumerate(COLOR_LIST):
            # A pixel matches only if all three channels equal the colour.
            seg_map[np.all(mask == color, axis=2)] = index
        if (seg_map == -1).any():
            return None
        return seg_map

    @staticmethod
    def postprocess(result: torch.Tensor) -> np.ndarray:
        """Convert the first image of a ``(N, C, H, W)`` batch to uint8 HWC.

        Values are scaled by 255 and clipped to [0, 255] before the cast;
        float values outside the uint8 range would otherwise wrap or be
        undefined when cast. Any further items in the batch are ignored.

        Args:
            result: Rank-4 tensor in ``(N, C, H, W)`` order. NOTE(review):
                values are assumed to lie in [0, 1] — confirm.

        Returns:
            ``(H, W, C)`` uint8 array.
        """
        image = result[0].detach().permute(1, 2, 0).cpu().numpy()
        return np.clip(image * 255, 0, 255).astype(np.uint8)

    def process_pose_image(
            self, pose_image: PIL.Image.Image | None) -> torch.Tensor | None:
        """Preprocess a pose image and feed it to the model.

        Args:
            pose_image: Pose image, or None.

        Returns:
            The pose tensor from ``preprocess_pose_image``, or None when
            ``pose_image`` is None.
        """
        if pose_image is None:
            return None
        data = self.preprocess_pose_image(pose_image)
        self.model.feed_pose_data(data)
        return data

    def _generate_segmentation(self, pose_data: torch.Tensor,
                               shape_text: str) -> None:
        """Feed pose + shape attributes and build the quantized segmentation.

        This is the shared first stage of ``generate_label_image`` and
        ``generate_human``. The results are stored on ``self.model``.
        """
        self.model.feed_pose_data(pose_data)
        shape_attributes = generate_shape_attributes(shape_text)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.model.feed_shape_attributes(shape_attributes)
        self.model.generate_parsing_map()
        self.model.generate_quantized_segm()

    def generate_label_image(self, pose_data: torch.Tensor | None,
                             shape_text: str) -> np.ndarray | None:
        """Predict a parsing map from pose + shape text and colour it.

        Args:
            pose_data: Tensor returned by ``process_pose_image``, or None.
            shape_text: Free-text description passed to
                ``generate_shape_attributes``.

        Returns:
            The output of ``self.model.palette_result`` for the first
            segmentation map, or None when ``pose_data`` is None.
        """
        if pose_data is None:
            return None
        self._generate_segmentation(pose_data, shape_text)
        colored_segm = self.model.palette_result(self.model.segm[0].cpu())
        return colored_segm

    def generate_human(self, pose_data: torch.Tensor | None, shape_text: str,
                       texture_text: str, sample_steps: int,
                       seed: int) -> np.ndarray | None:
        """Run the full pose + shape + texture pipeline.

        Args:
            pose_data: Tensor returned by ``process_pose_image``, or None.
            shape_text: Description passed to ``generate_shape_attributes``.
            texture_text: Description passed to
                ``generate_texture_attributes``.
            sample_steps: Assigned to ``self.model.sample_steps`` before
                sampling.
            seed: Passed to ``set_random_seed``. This happens only after the
                parsing map is generated, so any randomness in that stage is
                not controlled by ``seed``.

        Returns:
            The ``(H, W, C)`` uint8 image from ``postprocess``, or None when
            ``pose_data`` is None.
        """
        if pose_data is None:
            return None
        self._generate_segmentation(pose_data, shape_text)
        set_random_seed(seed)
        texture_attributes = generate_texture_attributes(texture_text)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.model.feed_texture_attributes(texture_attributes)
        self.model.generate_texture_map()
        self.model.sample_steps = sample_steps
        out = self.model.sample_and_refine()
        return self.postprocess(out)
if __name__ == "__main__":
    # Demo: render a human from ./001.png plus fixed shape/texture prompts
    # and write the result to final_image.jpg.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(device)
    # Use a context manager so the file handle PIL keeps open is closed; the
    # pixels are read inside process_pose_image (via resize) before exit.
    with PIL.Image.open("./001.png") as pose_image:
        input_image = model.process_pose_image(pose_image)
    shape_text = "A lady with a T-shirt and a skirt"
    all_res = model.generate_human(
        pose_data=input_image,
        shape_text=shape_text,
        texture_text="A lady with a T-shirt and a skirt",
        sample_steps=10,
        seed=0)
    final_im = PIL.Image.fromarray(all_res)
    final_im.save("final_image.jpg")
    print(all_res.shape)