# Text-human / model.py
# Uploaded by yitianlian; commit 24be7a2 ("update demo"), 7.14 kB.
from __future__ import annotations
import os
import pathlib
import sys
import zipfile
import huggingface_hub
import numpy as np
import PIL.Image
import torch
sys.path.insert(0, 'Text2Human')
from models.sample_model import SampleFromPoseModel
from utils.language_utils import (generate_shape_attributes,
generate_texture_attributes)
from utils.options import dict_to_nonedict, parse
from utils.util import set_random_seed
# RGB palette used for color-coded label images. `Model.process_mask` maps a
# pixel whose color equals COLOR_LIST[i] to class index i, so entry order is
# significant and must not be changed.
COLOR_LIST = [
    (0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),  # 0-3
    (255, 250, 205), (211, 211, 211), (70, 130, 180), (127, 255, 212),  # 4-7
    (0, 100, 0), (50, 205, 50), (255, 255, 0), (245, 222, 179),  # 8-11
    (255, 140, 0), (255, 0, 0), (16, 78, 139), (144, 238, 144),  # 12-15
    (50, 205, 174), (50, 155, 250), (160, 140, 88), (213, 140, 88),  # 16-19
    (90, 140, 90), (185, 210, 205), (130, 165, 180), (225, 141, 151),  # 20-23
]
class Model:
    """Inference wrapper around Text2Human's ``SampleFromPoseModel``.

    Loads the sampling config and downloads the pretrained weights if they
    are missing. Exposes helpers that turn a pose image plus text
    descriptions into a colored parsing (label) map or a final RGB image.
    """

    def __init__(self, device: str):
        """Build the sampler on ``device`` (for example ``"cuda"`` or ``"cpu"``)."""
        self.config = self._load_config()
        self.config['device'] = device
        self._download_models()
        self.model = SampleFromPoseModel(self.config)
        # Every public helper below generates a single image per call.
        self.model.batch_size = 1

    def _load_config(self) -> dict:
        """Parse the sampling YAML config and wrap it with ``dict_to_nonedict``."""
        path = 'Text2Human/configs/sample_from_pose.yml'
        config = parse(path, is_train=False)
        return dict_to_nonedict(config)

    def _download_models(self) -> None:
        """Fetch and unpack ``pretrained_models.zip`` unless already present.

        The archive is extracted into a scratch directory. That directory is
        renamed to ``pretrained_models`` only after ``extractall`` finishes.
        An interrupted run therefore cannot leave a half-populated
        ``pretrained_models`` directory that later runs would mistake for a
        complete download (the existence check below short-circuits).
        """
        import shutil

        model_dir = pathlib.Path('pretrained_models')
        if model_dir.exists():
            return
        token = os.getenv('HF_TOKEN')
        path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
                                               'pretrained_models.zip',
                                               use_auth_token=token)
        tmp_dir = model_dir.with_name(model_dir.name + '.partial')
        if tmp_dir.exists():
            # Leftover from an earlier extraction that did not complete.
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir()
        with zipfile.ZipFile(path) as f:
            f.extractall(tmp_dir)
        tmp_dir.rename(model_dir)

    @staticmethod
    def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
        """Convert a pose image into the float tensor fed to the sampler.

        The image is resized to 256x512 (width x height) with LANCZOS
        resampling. Only the channel at index 2 is kept, rescaled as
        ``x / 12 - 1``.

        Returns:
            A float32 tensor of shape ``(1, 1, 512, 256)``.
        """
        # Normalise the mode first. Grayscale input has no channel axis and
        # would crash the slice. RGBA input would contribute a second
        # (alpha) channel and silently turn into a batch of 2 after
        # ``unsqueeze(1)``. RGB input is unaffected.
        rgb = image.convert('RGB')
        resized = rgb.resize(size=(256, 512),
                             resample=PIL.Image.Resampling.LANCZOS)
        # (512, 256, 3) -> keep channel 2 -> (1, 512, 256)
        pose = np.array(resized)[:, :, 2:3].transpose(2, 0, 1)
        pose = pose.astype(np.float32) / 12. - 1
        return torch.from_numpy(pose).unsqueeze(1)

    @staticmethod
    def process_mask(mask: np.ndarray) -> np.ndarray | None:
        """Map a color-coded label image back to per-pixel class indices.

        Args:
            mask: Array expected to have shape ``(512, 256, 3)``. Each pixel
                should be exactly one of the colors in ``COLOR_LIST``.

        Returns:
            An integer array of shape ``(512, 256)`` holding, for every pixel,
            the index of its color in ``COLOR_LIST``. Returns ``None`` if the
            shape is wrong or any pixel matches no palette color exactly.
        """
        if mask.shape != (512, 256, 3):
            return None
        seg_map = np.full(mask.shape[:-1], -1)
        for index, color in enumerate(COLOR_LIST):
            seg_map[np.all(mask == color, axis=2)] = index
        if (seg_map == -1).any():
            return None
        return seg_map

    @staticmethod
    def postprocess(result: torch.Tensor) -> np.ndarray:
        """Convert the first image of an NCHW batch into an HWC uint8 array.

        The values are scaled by 255 and clipped to ``[0, 255]`` before the
        cast. Without the clip, out-of-range floats would wrap around or take
        platform-dependent values in the uint8 conversion.

        NOTE(review): this assumes ``result`` is nominally in ``[0, 1]``;
        TODO confirm against ``sample_and_refine``.
        """
        image = result[0].permute(1, 2, 0).detach().cpu().numpy()
        image = np.clip(image * 255, 0, 255)
        return image.astype(np.uint8)

    def process_pose_image(
            self, pose_image: PIL.Image.Image | None) -> torch.Tensor | None:
        """Preprocess ``pose_image`` and feed it to the sampler.

        Returns:
            The preprocessed pose tensor, so callers can pass it back to
            ``generate_label_image`` / ``generate_human`` later. Returns
            ``None`` if no image was given.
        """
        if pose_image is None:
            return None
        data = self.preprocess_pose_image(pose_image)
        self.model.feed_pose_data(data)
        return data

    def _generate_parsing(self, pose_data: torch.Tensor,
                          shape_text: str) -> None:
        """Run the shape stage on ``self.model``.

        Feeds the pose and the shape attributes parsed from ``shape_text``,
        then generates and quantizes the parsing map. Results are stored on
        ``self.model``.
        """
        self.model.feed_pose_data(pose_data)
        shape_attributes = generate_shape_attributes(shape_text)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.model.feed_shape_attributes(shape_attributes)
        self.model.generate_parsing_map()
        self.model.generate_quantized_segm()

    def generate_label_image(self, pose_data: torch.Tensor | None,
                             shape_text: str) -> np.ndarray | None:
        """Generate a colored parsing map for ``pose_data`` and ``shape_text``.

        Returns:
            The model's ``palette_result`` for the first segmentation, or
            ``None`` if ``pose_data`` is ``None``.
        """
        if pose_data is None:
            return None
        self._generate_parsing(pose_data, shape_text)
        return self.model.palette_result(self.model.segm[0].cpu())

    def generate_human(self, pose_data: torch.Tensor | None, shape_text: str,
                       texture_text: str, sample_steps: int,
                       seed: int) -> np.ndarray | None:
        """Generate a full human image from pose, shape text and texture text.

        Args:
            pose_data: Tensor returned by ``process_pose_image``.
            shape_text: Description used to derive shape attributes.
            texture_text: Description used to derive texture attributes.
            sample_steps: Value assigned to ``self.model.sample_steps``.
            seed: Random seed. It is set after the parsing stage, so it
                governs texture generation and sampling.

        Returns:
            An HWC uint8 image (see ``postprocess``), or ``None`` if
            ``pose_data`` is ``None``.
        """
        if pose_data is None:
            return None
        self._generate_parsing(pose_data, shape_text)
        set_random_seed(seed)
        texture_attributes = generate_texture_attributes(texture_text)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.model.feed_texture_attributes(texture_attributes)
        self.model.generate_texture_map()
        self.model.sample_steps = sample_steps
        out = self.model.sample_and_refine()
        return self.postprocess(out)
if __name__ == "__main__":
    # Smoke run: generate one image from a local pose file and save it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(device)
    # Close the file handle once preprocessing is done; ``resize`` inside
    # preprocessing loads the pixel data, so the image is not needed after.
    with PIL.Image.open("./001.png") as pose_image:
        pose_data = model.process_pose_image(pose_image)
    shape_text = "A lady with a T-shirt and a skirt"
    all_res = model.generate_human(
        pose_data=pose_data,
        shape_text=shape_text,
        texture_text="A lady with a T-shirt and a skirt",
        sample_steps=10,
        seed=0,
    )
    final_im = PIL.Image.fromarray(all_res)
    final_im.save("final_image.jpg")
    print(all_res.shape)