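"""Run Stable Diffusion inference guided by one or more composed adapter conditions.

For every condition enumerated by ExtraCondition, the script registers CLI flags
for a condition image, an input type, an adapter checkpoint, and a fusion weight;
the activated adapters' features are weight-summed and used to guide sampling.

Example invocation (assuming `sketch` is among the conditions in ExtraCondition;
substitute the actual script name and checkpoint path):

    python test_composable_adapters.py --sketch_path sketch.png \
        --sketch_adapter_ckpt models/t2iadapter_sketch.pth
"""
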
import cv2
import os
import torch
from pytorch_lightning import seed_everything
from torch import autocast

from basicsr.utils import tensor2img
from ldm.inference_base import diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import ExtraCondition, get_adapter_feature, get_cond_model

# the script is inference-only, so disable gradient tracking globally
torch.set_grad_enabled(False)


def main():
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
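    # each supported condition contributes four CLI flags: a condition image
    # path, an input type, an adapter checkpoint path, and a fusion weight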
    for cond_name in supported_cond:
        parser.add_argument(
            f'--{cond_name}_path',
            type=str,
            default=None,
            help=f'condition image path for {cond_name}',
        )
        parser.add_argument(
            f'--{cond_name}_inp_type',
            type=str,
            default='image',
            help=f'type of the input condition: a plain image to extract the condition from, or a pre-extracted {cond_name} map',
            choices=['image', cond_name],
        )
        parser.add_argument(
            f'--{cond_name}_adapter_ckpt',
            type=str,
            default=None,
            help=f'path to the checkpoint of the {cond_name} adapter; '
                 f'required whenever {cond_name}_path is given',
        )
        parser.add_argument(
            f'--{cond_name}_weight',
            type=float,
            default=1.0,
            help=f'the {cond_name} adapter features are scaled by this weight before being summed with the other adapter features',
        )
    opt = parser.parse_args()

    # collect the activated conditions together with their input paths and adapter checkpoints
    activated_conds = []
    cond_paths = []
    adapter_ckpts = []
    for cond_name in supported_cond:
        if getattr(opt, f'{cond_name}_path') is None:
            continue
        assert getattr(opt, f'{cond_name}_adapter_ckpt') is not None, f'{cond_name}_adapter_ckpt must be specified when {cond_name}_path is given'
        activated_conds.append(cond_name)
        cond_paths.append(getattr(opt, f'{cond_name}_path'))
        adapter_ckpts.append(getattr(opt, f'{cond_name}_adapter_ckpt'))
    assert len(activated_conds) != 0, 'at least one condition input must be provided'

    if opt.outdir is None:
        opt.outdir = 'outputs/test-composable-adapters'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(f"you don't specify the resize_shot_edge, so the maximum resolution is set to {opt.max_resolution}")
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # prepare models
    adapters = []
    cond_models = []
    cond_inp_types = []
    process_cond_modules = []
    for cond_name in activated_conds:
        adapters.append(get_adapters(opt, getattr(ExtraCondition, cond_name)))
        cond_inp_type = getattr(opt, f'{cond_name}_inp_type', 'image')
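        # a raw image must first go through a condition-extraction model;
        # an already-extracted condition map can be fed to the adapter directly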
        if cond_inp_type == 'image':
            cond_models.append(get_cond_model(opt, getattr(ExtraCondition, cond_name)))
        else:
            cond_models.append(None)
        cond_inp_types.append(cond_inp_type)
        process_cond_modules.append(getattr(api, f'get_cond_{cond_name}'))
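    # load the base Stable Diffusion model and its sampler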
    sd_model, sampler = get_sd_models(opt)

    # inference: sample under inference mode, EMA weights, and CUDA autocast
    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        seed_everything(opt.seed)
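        # preprocess each activated condition input into the tensor its adapter consumes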
        conds = []
        for cond_idx, cond_name in enumerate(activated_conds):
            conds.append(process_cond_modules[cond_idx](
                opt, cond_paths[cond_idx], cond_inp_types[cond_idx], cond_models[cond_idx],
            ))
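        # fuse the per-condition features; each adapter's output is scaled by
        # its weight before the features are summed (see the _weight CLI flags)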
        adapter_features, append_to_context = get_adapter_feature(conds, adapters)
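        # draw n_samples images; each output file is numbered by the current
        # file count in the output directory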
        for v_idx in range(opt.n_samples):
            result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
            base_count = len(os.listdir(opt.outdir))
            cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))


if __name__ == '__main__':
    main()