from __future__ import annotations

import os
import pathlib
import sys
import zipfile

import huggingface_hub
import numpy as np
import PIL.Image
import torch

sys.path.insert(0, 'Text2Human')

from models.sample_model import SampleFromPoseModel
from utils.language_utils import (generate_shape_attributes,
                                  generate_texture_attributes)
from utils.options import dict_to_nonedict, parse
from utils.util import set_random_seed

# RGB palette used by the Text2Human parsing maps; each entry corresponds to
# one segmentation label.
COLOR_LIST = [
    (0, 0, 0),
    (255, 250, 250),
    (220, 220, 220),
    (250, 235, 215),
    (255, 250, 205),
    (211, 211, 211),
    (70, 130, 180),
    (127, 255, 212),
    (0, 100, 0),
    (50, 205, 50),
    (255, 255, 0),
    (245, 222, 179),
    (255, 140, 0),
    (255, 0, 0),
    (16, 78, 139),
    (144, 238, 144),
    (50, 205, 174),
    (50, 155, 250),
    (160, 140, 88),
    (213, 140, 88),
    (90, 140, 90),
    (185, 210, 205),
    (130, 165, 180),
    (225, 141, 151),
]


class Model:
    def __init__(self, device: str):
        self.config = self._load_config()
        self.config['device'] = device
        self._download_models()
        self.model = SampleFromPoseModel(self.config)
        self.model.batch_size = 1

    def _load_config(self) -> dict:
        path = 'Text2Human/configs/sample_from_pose.yml'
        config = parse(path, is_train=False)
        config = dict_to_nonedict(config)
        return config

    def _download_models(self) -> None:
        model_dir = pathlib.Path('pretrained_models')
        if model_dir.exists():
            return
        token = os.getenv('HF_TOKEN')
        path = huggingface_hub.hf_hub_download('yumingj/Text2Human_SSHQ',
                                               'pretrained_models.zip',
                                               use_auth_token=token)
        model_dir.mkdir()
        with zipfile.ZipFile(path) as f:
            f.extractall(model_dir)

    @staticmethod
    def preprocess_pose_image(image: PIL.Image.Image) -> torch.Tensor:
        # Resize to 256x512, keep only the third channel, and rescale to the
        # value range the pose encoder expects.
        image = np.array(
            image.resize(
                size=(256, 512),
                resample=PIL.Image.Resampling.LANCZOS))[:, :, 2:].transpose(
                    2, 0, 1).astype(np.float32)
        image = image / 12. - 1
        data = torch.from_numpy(image).unsqueeze(1)
        return data

    @staticmethod
    def process_mask(mask: np.ndarray) -> np.ndarray | None:
        # Map a colored 512x256 label image back to integer label indices.
        # Returns None if the shape is wrong or any pixel color is not in
        # COLOR_LIST.
        if mask.shape != (512, 256, 3):
            return None
        seg_map = np.full(mask.shape[:-1], -1)
        for index, color in enumerate(COLOR_LIST):
            seg_map[np.sum(mask == color, axis=2) == 3] = index
        if not (seg_map != -1).all():
            return None
        return seg_map

    # Alternative version that returns a recolored RGB image instead of an
    # index map:
    # def process_mask(self, mask: np.ndarray) -> np.ndarray:
    #     if mask.shape != (512, 256, 3):
    #         return None
    #     seg_map = np.full(mask.shape[:-1], -1)
    #     for index, color in enumerate(COLOR_LIST):
    #         seg_map[np.sum(mask == color, axis=2) == 3] = index
    #     # Create a new 3-channel image for the output
    #     result = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    #     # Assign matched pixels their corresponding palette colors
    #     for index, color in enumerate(COLOR_LIST):
    #         result[seg_map == index] = color
    #     # Set unmatched pixels to white
    #     result[seg_map == -1] = (255, 250, 250)
    #     return result

    @staticmethod
    def postprocess(result: torch.Tensor) -> np.ndarray:
        # Convert the sampled NCHW tensor to a uint8 HWC image.
        result = result.permute(0, 2, 3, 1)
        result = result.detach().cpu().numpy()
        result = result * 255
        result = np.asarray(result[0, :, :, :], dtype=np.uint8)
        return result

    def process_pose_image(
            self, pose_image: PIL.Image.Image) -> torch.Tensor | None:
        if pose_image is None:
            return
        data = self.preprocess_pose_image(pose_image)
        self.model.feed_pose_data(data)
        return data

    def generate_label_image(self, pose_data: torch.Tensor,
                             shape_text: str) -> np.ndarray | None:
        if pose_data is None:
            return
        self.model.feed_pose_data(pose_data)
        shape_attributes = generate_shape_attributes(shape_text)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.model.feed_shape_attributes(shape_attributes)
        self.model.generate_parsing_map()
        self.model.generate_quantized_segm()
        colored_segm = self.model.palette_result(self.model.segm[0].cpu())
        return colored_segm
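
    # A minimal sanity-check sketch, not part of the original demo: it builds
    # a synthetic label image from COLOR_LIST and confirms that process_mask
    # recovers the exact index map. The helper name `check_mask_roundtrip`
    # is hypothetical.
    @staticmethod
    def check_mask_roundtrip() -> bool:
        # Tile the palette indices over a 512x256 grid and colorize them.
        indices = np.arange(512 * 256).reshape(512, 256) % len(COLOR_LIST)
        colored = np.array(COLOR_LIST, dtype=np.uint8)[indices]
        # process_mask should invert the colorization exactly.
        seg_map = Model.process_mask(colored)
        return seg_map is not None and bool((seg_map == indices).all())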

    # Earlier variant that starts from an (optionally hand-edited) label
    # image instead of regenerating the parsing map from text:
    # def generate_human(self, label_image: np.ndarray, texture_text: str,
    #                    sample_steps: int, seed: int) -> np.ndarray:
    #     if label_image is None:
    #         return
    #     mask = label_image.copy()
    #     seg_map = self.process_mask(mask)
    #     if seg_map is None:
    #         return
    #     self.model.segm = torch.from_numpy(seg_map).unsqueeze(0).unsqueeze(
    #         0).to(self.model.device)
    #     self.model.generate_quantized_segm()
    #     set_random_seed(seed)
    #     texture_attributes = generate_texture_attributes(texture_text)
    #     texture_attributes = torch.LongTensor(texture_attributes)
    #     self.model.feed_texture_attributes(texture_attributes)
    #     self.model.generate_texture_map()
    #     self.model.sample_steps = sample_steps
    #     out = self.model.sample_and_refine()
    #     res = self.postprocess(out)
    #     return res

    def generate_human(self, pose_data: torch.Tensor, shape_text: str,
                       texture_text: str, sample_steps: int,
                       seed: int) -> np.ndarray | None:
        # Full pipeline: pose -> parsing map from shape_text -> texture from
        # texture_text -> iterative sampling and refinement.
        if pose_data is None:
            return
        self.model.feed_pose_data(pose_data)
        shape_attributes = generate_shape_attributes(shape_text)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.model.feed_shape_attributes(shape_attributes)
        self.model.generate_parsing_map()
        self.model.generate_quantized_segm()
        set_random_seed(seed)
        texture_attributes = generate_texture_attributes(texture_text)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.model.feed_texture_attributes(texture_attributes)
        self.model.generate_texture_map()
        self.model.sample_steps = sample_steps
        out = self.model.sample_and_refine()
        res = self.postprocess(out)
        return res


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model(device)
    pose_image = PIL.Image.open("./001.png")
    input_image = model.process_pose_image(pose_image)
    shape_text = "A lady with a T-shirt and a skirt"
    # Intermediate step, kept for reference: generate and save only the
    # parsing (label) map.
    # res = model.generate_label_image(pose_data=input_image,
    #                                  shape_text=shape_text)
    # im = PIL.Image.fromarray(res)
    # im.save("label_image.jpg")
    # print(res.shape)
    all_res = model.generate_human(pose_data=input_image,
                                   shape_text=shape_text,
                                   texture_text="A lady with a T-shirt and a skirt",
                                   sample_steps=10,
                                   seed=0)
    final_im = PIL.Image.fromarray(all_res)
    final_im.save("final_image.jpg")
    print(all_res.shape)
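
    # A small usage sketch under assumed conventions (the output filenames
    # and the seed range are illustrative, not from the original script):
    # reuse the same pose and prompts while sweeping the random seed to get
    # several candidate outputs.
    for s in range(3):
        res = model.generate_human(pose_data=input_image,
                                   shape_text=shape_text,
                                   texture_text="A lady with a T-shirt and a skirt",
                                   sample_steps=10,
                                   seed=s)
        PIL.Image.fromarray(res).save(f"final_image_seed{s}.jpg")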