File size: 11,952 Bytes
8fd2f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from diffusers import DDPMScheduler, DiffusionPipeline
from typing import List, Any, Union, Type
from utils.loader import get_class
from copy import deepcopy
from modules.loader.module_loader_config import ModuleLoaderConfig
import torch
import pytorch_lightning as pl
import jsonargparse


class bcolors:
    """ANSI escape sequences for styling terminal output.

    Prefix text with one of the codes and append ``ENDC`` to reset, e.g.
    ``bcolors.WARNING + "message" + bcolors.ENDC``.
    """

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Bright foreground colors.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Resets all attributes back to the terminal default.
    ENDC = "\033[0m"


class GenericModuleLoader():
    """Build the modules a trainer needs and attach them to it as attributes.

    Modules listed in ``module_names`` are processed in order, so a module may
    depend on modules listed before it. A module without an entry in
    ``module_config`` is taken from the diffusers pipeline (which requires
    ``pipeline_obj`` and ``pipeline_repo`` to be set). A ``module_config``
    entry is either an already-built object (e.g. instantiated by
    jsonargparse), used as-is, or a ``ModuleLoaderConfig`` describing how to
    construct the module and, optionally, which checkpoint to restore into it.

    Args:
        pipeline_repo: Path or hub id passed to
            ``DiffusionPipeline.from_pretrained`` and, for the scheduler,
            ``DDPMScheduler.from_pretrained(..., subfolder="scheduler")``.
        pipeline_obj: If set, the pipeline is loaded and stored on the trainer
            under this attribute name. The pipeline is only loaded when set.
        set_prediction_type: If non-empty, overrides ``prediction_type`` of the
            scheduler taken from ``pipeline_repo``.
        module_names: Modules to load, in dependency order.
        module_config: Per-module overrides keyed by module name.
        fast_dev_run: When truthy, use ``cls_func_fast_dev_run`` /
            ``model_params_fast_dev_run`` where configured and skip checkpoint
            loading.
        root_cls: If set, every module built from a ``ModuleLoaderConfig`` must
            be an instance of this class.
    """

    def __init__(self,
                 pipeline_repo: str = None,
                 pipeline_obj: str = None,
                 set_prediction_type: str = "",
                 module_names: List[str] = [
                     "scheduler", "text_encoder", "tokenizer", "vae", "unet",],
                 module_config: dict[str,
                                     Union[ModuleLoaderConfig,  torch.nn.Module, Any]] = None,
                 fast_dev_run: Union[int, bool] = False,
                 root_cls: Type[Any] = None,
                 ) -> None:
        self.module_config = module_config
        self.pipeline_repo = pipeline_repo
        self.pipeline_obj = pipeline_obj
        self.set_prediction_type = set_prediction_type
        # Copy so neither the shared default list nor the caller's list is aliased.
        self.module_names = list(module_names)
        self.fast_dev_run = fast_dev_run
        self.root_cls = root_cls

    def load_custom_scheduler(self):
        """Load a ``DDPMScheduler`` from ``pipeline_repo``, optionally overriding its prediction type."""
        module_obj = DDPMScheduler.from_pretrained(
            self.pipeline_repo, subfolder="scheduler")

        if self.set_prediction_type:
            scheduler_config = module_obj.load_config(
                self.pipeline_repo, subfolder="scheduler")
            scheduler_config["prediction_type"] = self.set_prediction_type
            module_obj = module_obj.from_config(scheduler_config)
        return module_obj

    def load_pipeline(self):
        """Return the diffusers pipeline from ``pipeline_repo``, or None if no repo is set."""
        return DiffusionPipeline.from_pretrained(self.pipeline_repo) if self.pipeline_repo is not None else None

    def __call__(self, trainer: pl.LightningModule, diff_trainer_params):
        """Build every module in ``module_names`` and attach it to ``trainer``.

        Args:
            trainer: Receives each module via ``setattr``; modules attached
                earlier are available as dependencies of later ones.
            diff_trainer_params: Object whose attributes can be forwarded as
                constructor kwargs via ``kwargs_diff_trainer_params``.

        Raises:
            KeyError: A module has no ``module_config`` entry and no pipeline
                was loaded to take it from.
        """
        module_config = self.module_config if self.module_config is not None else {}

        # The diffusers pipeline is only loaded when it is to be stored on the trainer.
        pipe = self.load_pipeline() if self.pipeline_obj is not None else None
        if pipe is not None:
            # Store the instance loaded above instead of loading the whole
            # pipeline a second time; modules taken from it below are the same
            # objects as the pipeline's components.
            setattr(trainer, self.pipeline_obj, pipe)

        for module_name in self.module_names:
            print(f" --- START:  Loading module: {module_name} ---")
            if module_name not in module_config and pipe is not None:
                if module_name == "scheduler":
                    # Replaced by a DDPMScheduler from pipeline_repo so that
                    # set_prediction_type can be applied.
                    module = self.load_custom_scheduler()
                else:
                    module = getattr(pipe, module_name)
            else:
                module = self._load_configured_module(
                    module_name, module_config, trainer, diff_trainer_params)
            setattr(trainer, module_name, module)
            print(f" --- FINISHED:  Loading module: {module_name} ---")

    def _load_configured_module(self, module_name, module_config, trainer, diff_trainer_params):
        """Return the module described by ``module_config[module_name]``."""
        if module_name not in module_config:
            raise KeyError(
                f"Module {module_name} has no entry in module_config and no diffusers pipeline "
                f"is loaded to take it from (set pipeline_obj and pipeline_repo).")
        entry = module_config[module_name]
        if isinstance(entry, ModuleLoaderConfig):
            module = self._instantiate(
                module_name, entry, trainer, diff_trainer_params)
            if entry.state_dict_path and not self.fast_dev_run:
                # TODO extend loading to hf spaces
                self._load_checkpoint(module_name, module, entry)
        else:
            # Already instantiated (e.g. by jsonargparse); used as-is.
            # TODO we want to be able to load ckpt still.
            module = entry

        if isinstance(module, (jsonargparse.Namespace, dict)):
            print(bcolors.WARNING +
                  f"Warning: Seems object {module_name} was not built correctly." + bcolors.ENDC)
        return module

    def _resolve_loader(self, module_name, config_obj):
        """Return the callable (class or class method) that constructs the module."""
        loader_cls = get_class(config_obj.loader_cls_path)
        if config_obj.cls_func == "":
            # No factory method configured: call the class itself.
            return loader_cls
        # A faster alternative method (e.g. diffusers from_config instead of
        # from_pretrained) may be configured for quick testing runs.
        if self.fast_dev_run and config_obj.cls_func_fast_dev_run != "":
            print(
                f"Model {module_name}: loading fast_dev_run class loader")
            return getattr(loader_cls, config_obj.cls_func_fast_dev_run)
        return getattr(loader_cls, config_obj.cls_func)

    def _get_dependency(self, trainer, dependent_module, module_name):
        """Return the already-loaded module ``dependent_module`` from ``trainer``."""
        assert hasattr(
            trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
        return getattr(trainer, dependent_module)

    @staticmethod
    def _params_to_dict(params):
        """Return ``params`` as a dict: None -> {}, dicts as-is, otherwise ``params.to_dict()``."""
        if params is None:
            return {}
        if isinstance(params, dict):
            return params
        return params.to_dict()

    @staticmethod
    def _resolve_torch_dtype(value):
        """Map a string such as ``"torch.float16"`` to the matching ``torch.dtype``; other values pass through."""
        if isinstance(value, str) and value.startswith("torch."):
            candidate = getattr(torch, value[len("torch."):], None)
            if isinstance(candidate, torch.dtype):
                return candidate
        return value

    def _build_kwargs(self, module_name, config_obj, trainer, diff_trainer_params):
        """Collect the constructor kwargs for a ``ModuleLoaderConfig`` module.

        On key collisions later sources win, in this order: diffusers kwargs,
        dependent modules (shared, then cloned), values forwarded from
        ``diff_trainer_params``, ``model_params``, ``model_params_fast_dev_run``.
        None of the config's dicts are modified.
        """
        # Model kwargs: a dict or a params object exposing to_dict().
        model_kwargs = self._params_to_dict(config_obj.model_params)

        fast_dev_run_kwargs = {}
        if self.fast_dev_run and config_obj.model_params_fast_dev_run is not None:
            print(
                f"Module {module_name}: loading fast_dev_run params")
            fast_dev_run_kwargs = self._params_to_dict(
                config_obj.model_params_fast_dev_run)

        loaded_modules = {}
        for key, dependent_module in (config_obj.dependent_modules or {}).items():
            loaded_modules[key] = self._get_dependency(
                trainer, dependent_module, module_name)
        # Clone the module object itself so the new module gets its own copy.
        for key, dependent_module in (config_obj.dependent_modules_cloned or {}).items():
            loaded_modules[key] = deepcopy(self._get_dependency(
                trainer, dependent_module, module_name))

        # Each entry names an attribute of diff_trainer_params; None forwards
        # diff_trainer_params itself.
        trainer_param_kwargs = {
            key: diff_trainer_params if attr_name is None else getattr(diff_trainer_params, attr_name)
            for key, attr_name in (config_obj.kwargs_diff_trainer_params or {}).items()
        }

        diffusers_kwargs = dict(config_obj.kwargs_diffusers or {})
        if "torch_dtype" in diffusers_kwargs:
            diffusers_kwargs["torch_dtype"] = self._resolve_torch_dtype(
                diffusers_kwargs["torch_dtype"])

        return diffusers_kwargs | loaded_modules | trainer_param_kwargs | model_kwargs | fast_dev_run_kwargs

    def _instantiate(self, module_name, config_obj, trainer, diff_trainer_params):
        """Construct the module described by ``config_obj`` and check it against ``root_cls``."""
        cls_func = self._resolve_loader(module_name, config_obj)
        kwargs = self._build_kwargs(
            module_name, config_obj, trainer, diff_trainer_params)
        module = cls_func(*config_obj.args, **kwargs)
        if self.root_cls is not None:
            assert isinstance(module, self.root_cls), \
                f"Module {module_name} is {type(module)}, expected an instance of {self.root_cls}"
        return module

    @staticmethod
    def _matches_filter(name, filter_str):
        """Return True if the ``*``-separated parts of ``filter_str`` occur in ``name`` in order.

        Matching is by substring and not anchored: ``"unet*attn"`` matches
        ``"unet.down.attn1.weight"``.
        """
        position = 0
        for part in filter_str.split("*"):
            found = name.find(part, position)
            if found < 0:
                return False
            position = found + len(part)
        return True

    def _load_checkpoint(self, module_name, module, config_obj):
        """Restore ``module`` from the checkpoint at ``config_obj.state_dict_path``.

        Only entries of the checkpoint's ``"state_dict"`` whose key starts with
        ``"<module_name>."`` are used; that prefix is stripped before loading.
        If ``state_dict_filters`` is non-empty (requires non-strict loading),
        only entries whose full key matches at least one filter are used.
        """
        print(
            f"             * Loading checkpoint {config_obj.state_dict_path} - STARTED")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(
            config_obj.state_dict_path, map_location=torch.device("cpu"))
        state_dict = checkpoint["state_dict"]

        filters = config_obj.state_dict_filters or []
        if len(filters) > 0:
            assert not config_obj.strict_loading

        # Match on the exact "<module_name>." prefix so e.g. "unet_lora.*"
        # is not picked up for module "unet".
        prefix = module_name + "."
        ckpt_params_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
            and (len(filters) == 0 or any(self._matches_filter(key, f) for f in filters))
        }

        if len(ckpt_params_dict) > 0:
            missing, unexpected = module.load_state_dict(
                ckpt_params_dict, strict=config_obj.strict_loading)
            assert len(
                unexpected) == 0, f"Unexpected parameters in checkpoint: {unexpected}"
            if len(missing) > 0:
                print(
                    f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.")
                print(missing)
        else:
            print(bcolors.WARNING +
                  f"Warning: checkpoint {config_obj.state_dict_path} has no entries for module {module_name}." + bcolors.ENDC)
        print(
            f"             * Loading checkpoint {config_obj.state_dict_path} - FINISHED")