# -*- coding: utf-8 -*-
"""Load the ``Model`` and generate a label image from a pose image and text.

Exported from a Colab notebook: ``# @title``, ``# @markdown`` and ``# @param``
comments are Colab form annotations, and bare string literals are former
text cells. The final human-image generation step is kept commented out.
"""

# @title Load modules
import os
import random
from collections import namedtuple

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from model import Model

# @title Patch PIL
# Pillow releases without the ``Image.Resampling`` enum get a minimal stand-in
# exposing only ``LANCZOS`` (presumably what ``model`` relies on -- confirm).
# Only patch when the attribute is missing: assigning it unconditionally would
# replace the real enum on modern Pillow and break every other member
# (``NEAREST``, ``BICUBIC``, ...). Done before the model is constructed so any
# use during loading also sees it.
if not hasattr(Image, "Resampling"):
    Image.Resampling = namedtuple("Patch", ["LANCZOS"])(Image.LANCZOS)

# @title Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(device)
# Counts the parameters of the three sub-modules listed below only.
print(
    "Model loaded. Parameters:",
    sum(
        x.numel()
        for y in [
            model.model.shape_attr_embedder.parameters(),
            model.model.shape_parsing_encoder.parameters(),
            model.model.shape_parsing_decoder.parameters(),
        ]
        for x in y
    ),
)

"""# Usage"""

# @title Generation parameters
# @markdown Path to a local pose image file (Image.open does not fetch URLs;
# @markdown download or upload the image first)
pose_image = Image.open("./001.png")

# @markdown Shape text for the general shape, texture text for the color texture
shape_text = "A lady with a T-shirt and a skirt"  # @param {type: "string"}
texture_text = "Lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt."  # @param {type: "string"}
steps = 50  # @param {type: "slider", min: 10, max:300, step: 10}
seed = -1  # @param {type: "integer"}

# -1 means "pick a random 16-bit seed".
if seed == -1:
    seed = random.getrandbits(16)
print("Seed:", seed)
# Apply the seed to the global RNGs so the printed value actually reproduces
# this run (previously it was only printed). torch.manual_seed also seeds all
# CUDA devices; NumPy only accepts seeds in [0, 2**32), hence the modulus.
random.seed(seed)
np.random.seed(seed % 2**32)
torch.manual_seed(seed)

# %%time
# @title Generate label image
print("Pose image:")
display(pose_image)
print(type(pose_image))
print(pose_image.size)
print("Shape description:", shape_text)
label_image = model.generate_label_image(
    pose_data=model.process_pose_image(pose_image), shape_text=shape_text
)
print("Label image:")
# Debug count of entries equal to -1 (presumably an "unlabelled" sentinel).
# NOTE(review): if label_image is an unsigned dtype this is always 0 -- confirm.
print(np.sum(label_image == -1))
display(Image.fromarray(label_image).resize((128, 256)))

# Commented out IPython magic to ensure Python compatibility.
# %%time
# #@title Generate human image
# print("Label mask:")
# display(Image.fromarray(label_image).resize((128, 256)))
# print("Texture text:", texture_text)
# print("Generation steps:", steps)
# result = model.generate_human(label_image=label_image,
#                               texture_text=texture_text,
#                               sample_steps=steps,
#                               seed=seed)  # was seed=0, which ignored the seed parameter
# print("Resulting image:")
# display(Image.fromarray(result))