# -*- coding: utf-8 -*-

# @title Load modules
import os
import random

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from model import Model

# @title Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model(device)
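# Count only the parameters of the shape-generation submodules
# (attribute embedder plus the parsing encoder and decoder).
print(
    "Model loaded. Shape-module parameters:",
    sum(
        p.numel()
        for module in (
            model.model.shape_attr_embedder,
            model.model.shape_parsing_encoder,
            model.model.shape_parsing_decoder,
        )
        for p in module.parameters()
    ),
)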

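# @title Patch PIL
# Pillow releases before 9.1 lack the `Image.Resampling` enum. Add a minimal
# stand-in only when it is missing, so the real enum on newer Pillow is not
# replaced by a namedtuple that exposes only LANCZOS.
from collections import namedtuple

if not hasattr(Image, "Resampling"):
    Image.Resampling = namedtuple("Resampling", ["LANCZOS"])(Image.LANCZOS)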
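"""A minimal sketch of a reusable shim that attaches a namespace such as `Image.Resampling` only when it is missing. `ensure_namespace` is a hypothetical helper, not part of Pillow. It is demonstrated here on stand-in objects rather than the real `PIL.Image` module. Exposing extra members such as `NEAREST` is an assumption about which filters downstream code might need.

```python
from collections import namedtuple
from types import SimpleNamespace


def ensure_namespace(module, name, **members):
    # Attach `module.<name>` as a namedtuple instance holding `members`,
    # but only when `module` does not already define `name`.
    # Returns whatever `module.<name>` is afterwards.
    if not hasattr(module, name):
        setattr(module, name, namedtuple(name, list(members))(**members))
    return getattr(module, name)


# Demo on stand-in objects instead of the real PIL.Image module.
old_pillow = SimpleNamespace(LANCZOS=1, NEAREST=0)
shim = ensure_namespace(old_pillow, "Resampling", LANCZOS=1, NEAREST=0)

new_pillow = SimpleNamespace(Resampling="real enum")
kept = ensure_namespace(new_pillow, "Resampling", LANCZOS=1)
```

On a real install this could be called as `ensure_namespace(Image, "Resampling", LANCZOS=Image.LANCZOS, NEAREST=Image.NEAREST)`. On Pillow 9.1 or later the genuine `Image.Resampling` enum is left untouched.
"""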

"""# Usage"""

# @title Generation parameters
# @markdown Path to a local pose image (upload your own file to use it); `Image.open` does not fetch URLs
pose_image = Image.open("./001.png")
# @markdown `shape_text` describes the overall clothing shape; `texture_text` describes its colors and textures
shape_text = "A lady with a T-shirt and a skirt"  # @param {type: "string"}
texture_text = "Lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt."  # @param {type: "string"}
steps = 50  # @param {type: "slider", min: 10, max:300, step: 10}
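"""`Image.open` accepts a file path or an open file object, not a URL. The sketch below accepts either a URL or a local path, assuming the Python standard library (`urllib`) is acceptable for downloads. `read_image_bytes` is a hypothetical helper name.

```python
import os
import tempfile
import urllib.parse
import urllib.request
from pathlib import Path


def read_image_bytes(source: str, timeout: float = 30.0) -> bytes:
    # Strings with an http, https or file scheme are fetched with urllib;
    # everything else (including Windows drive paths) is read as a local file.
    scheme = urllib.parse.urlparse(source).scheme.lower()
    if scheme in ("http", "https", "file"):
        with urllib.request.urlopen(source, timeout=timeout) as response:
            return response.read()
    return Path(source).read_bytes()


# Demo: round-trip a small temporary file through the path and file:// branches.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
    tmp.write(b"placeholder bytes")
from_path = read_image_bytes(tmp.name)
from_url = read_image_bytes(Path(tmp.name).as_uri())
os.remove(tmp.name)
```

In the notebook, the bytes can be wrapped for Pillow with `pose_image = Image.open(io.BytesIO(read_image_bytes("./001.png")))` after `import io`. A URL string works the same way.
"""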

seed = -1  # @param {type: "integer"}
if seed == -1:
    seed = random.getrandbits(16)
print("Seed:", seed)
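"""The seed cell treats `-1` as a request to pick a random 16-bit seed. A minimal sketch of that rule as a testable function follows. `resolve_seed` and its `rng` parameter are hypothetical additions for illustration. The printed seed makes a run repeatable only if it is also passed to the generation call (the `seed` argument of `model.generate_human`).

```python
import random
from typing import Optional


def resolve_seed(seed: int, bits: int = 16, rng: Optional[random.Random] = None) -> int:
    # -1 means draw a fresh random seed with `bits` random bits;
    # any other value is returned unchanged so a run can be repeated.
    if seed == -1:
        return (rng or random).getrandbits(bits)
    return seed


fixed = resolve_seed(1234)
drawn = resolve_seed(-1, rng=random.Random(0))
drawn_again = resolve_seed(-1, rng=random.Random(0))
```

Passing an explicit `random.Random` makes even the random branch reproducible, which is convenient for testing.
"""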

# %%time
# @title Generate label image
print("Pose image:")
display(pose_image)
print("Pose image type:", type(pose_image).__name__)
print("Pose image size:", pose_image.size)
print("Shape description:", shape_text)
label_image = model.generate_label_image(
    pose_data=model.process_pose_image(pose_image), shape_text=shape_text
)
print("Label image:")
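print("Pixels with label -1:", np.sum(label_image == -1))
# Nearest-neighbour resizing keeps each pixel's label value intact; Pillow's
# default filter for non-palette images is bicubic, which would blend labels.
display(Image.fromarray(label_image).resize((128, 256), resample=Image.NEAREST))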
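"""Shown as grayscale, a label map of small integer class ids is mostly dark, so classes are hard to tell apart. Below is a minimal sketch that maps each id to a color from a fixed random palette. `colorize_labels` is a hypothetical helper, and the number of classes is an assumption that must match the model's parsing labels. Values outside that range (including the `-1` pixels counted above, whose meaning is not documented here) are drawn black.

```python
import numpy as np


def colorize_labels(label_map: np.ndarray, num_classes: int, seed: int = 0) -> np.ndarray:
    # Map an (H, W) integer label map to an (H, W, 3) uint8 RGB image using a
    # fixed random palette; labels outside [0, num_classes), such as -1, stay black.
    rng = np.random.default_rng(seed)
    palette = rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)
    labels = np.asarray(label_map).astype(np.int64)
    valid = (labels >= 0) & (labels < num_classes)
    rgb = np.zeros(labels.shape + (3,), dtype=np.uint8)
    rgb[valid] = palette[labels[valid]]
    return rgb


demo = np.array([[0, 1], [-1, 1]])
rgb = colorize_labels(demo, num_classes=2)
```

In the notebook this could be used as `display(Image.fromarray(colorize_labels(label_image, int(label_image.max()) + 1)).resize((128, 256), resample=Image.NEAREST))`. Because the palette is random, two classes can occasionally get similar colors; a hand-picked palette avoids that.
"""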

# Commented out IPython magic to ensure Python compatibility.
# %%time
# # @title Generate human image
# print("Label mask:")
# display(Image.fromarray(label_image).resize((128, 256), resample=Image.NEAREST))
# print("Texture text:", texture_text)
# print("Generation steps:", steps)
# result = model.generate_human(label_image=label_image,
#                               texture_text=texture_text,
#                               sample_steps=steps,
#                               seed=seed)
# print("Resulting image:")
# display(Image.fromarray(result))
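"""The generation cell is commented out because `%%time` is IPython-only syntax. Below is a minimal plain-Python stand-in, assuming only wall-clock time is needed (`%%time` also reports CPU time). `timed` is a hypothetical helper, not an IPython API.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str):
    # Measure the wall-clock duration of the `with` body, store it in the
    # yielded dict under "seconds", and print it when the body finishes.
    result = {"seconds": 0.0}
    start = time.perf_counter()
    try:
        yield result
    finally:
        result["seconds"] = time.perf_counter() - start
        print(f"{label}: {result['seconds']:.3f} s wall time")


with timed("demo") as timing:
    total = sum(range(100_000))
```

Uncommented, the generation cell's body could run inside `with timed("Generate human image"):` instead of relying on the magic.
"""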