# Text-human / demo.py
# Committed by: yitianlian
# Commit message: update demo
# Commit: 24be7a2
# -*- coding: utf-8 -*-
# @title Load modules
import os
import random
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from model import Model
# @title Load model
# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(device)

# Report the combined parameter count of three submodules of the wrapped
# network: the shape-attribute embedder and the shape-parsing encoder/decoder.
_shape_submodules = (
    model.model.shape_attr_embedder,
    model.model.shape_parsing_encoder,
    model.model.shape_parsing_decoder,
)
_shape_param_count = sum(
    param.numel()
    for submodule in _shape_submodules
    for param in submodule.parameters()
)
print("Model loaded. Parameters:", _shape_param_count)
# @title Patch PIL
# Pillow added the ``Image.Resampling`` enum in 9.1. This shim presumably
# exists so that code using ``Image.Resampling.<FILTER>`` also runs on older
# Pillow releases. Install the stand-in only when the real enum is missing:
# overwriting it unconditionally would replace the full enum on modern Pillow
# with a partial substitute and break code that uses other filters.
from collections import namedtuple

if not hasattr(Image, "Resampling"):
    # On old Pillow these filters exist as module-level constants; expose them
    # under the new-style attribute names.
    Image.Resampling = namedtuple(
        "Patch", ["NEAREST", "BILINEAR", "BICUBIC", "LANCZOS"]
    )(Image.NEAREST, Image.BILINEAR, Image.BICUBIC, Image.LANCZOS)
"""# Usage"""
# @title Generation parameters
# @markdown Local path to the pose image (e.g. a file you uploaded); PIL's Image.open reads files, not URLs
pose_image = Image.open("./001.png")
# @markdown Shape text for the general shape, texture text for the color texture
shape_text = "A lady with a T-shirt and a skirt" # @param {type: "string"}
texture_text = "Lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt." # @param {type: "string"}
steps = 50 # @param {type: "slider", min: 10, max:300, step: 10}
seed = -1 # @param {type: "integer"}
# A seed of -1 means "pick one at random": draw a fresh 16-bit value
# (0..65535) and print it so the run can be reproduced.
# NOTE(review): in the active code of this file `seed` is only printed; the
# commented-out human-generation cell passes seed=0 instead of this value.
if seed == -1:
    seed = random.getrandbits(16)
print("Seed:", seed)
# %%time
# @title Generate label image
# Show the inputs, then produce a label image from the pose image and the
# shape description:
#   - `model.process_pose_image` preprocesses the pose image;
#   - `model.generate_label_image` consumes that result together with
#     `shape_text`.
print("Pose image:")
display(pose_image)
print("Pose image type:", type(pose_image))
print("Pose image size:", pose_image.size)
print("Shape description:", shape_text)
label_image = model.generate_label_image(
    pose_data=model.process_pose_image(pose_image), shape_text=shape_text
)
# Number of entries equal to -1; presumably a sentinel for unassigned pixels.
# TODO confirm against Model.generate_label_image.
print("Pixels labelled -1:", np.sum(label_image == -1))
print("Label image:")
# Resize to 128x256 (PIL sizes are width x height) for display only; the
# `label_image` array itself is not modified.
# NOTE(review): if `label_image` holds discrete class ids, passing
# resample=Image.NEAREST would avoid blending neighbouring labels in the preview.
display(Image.fromarray(label_image).resize((128, 256)))
# Commented out IPython magic to ensure Python compatibility.
# %%time
# #@title Generate human image
# print("Label mask:")
# display(Image.fromarray(label_image).resize((128, 256)))
# print("Texture text:", texture_text)
# print("Generation steps:", steps)
# result = model.generate_human(label_image=label_image,
# texture_text=texture_text,
# sample_steps=steps,
# seed=0)
# print("Resulting image:")
# display(Image.fromarray(result))