hysts HF staff commited on
Commit
945eea6
·
1 Parent(s): 6cca304
Files changed (7) hide show
  1. .gitignore +162 -0
  2. .pre-commit-config.yaml +37 -0
  3. .style.yapf +5 -0
  4. app.py +89 -0
  5. model.py +354 -0
  6. requirements.txt +11 -0
  7. style.css +3 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ELITE/
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+
9
+ from model import Model
10
+
11
+ repo_dir = pathlib.Path(__file__).parent
12
+
13
+
14
+ def create_demo():
15
+ DESCRIPTION = '# [ELITE](https://github.com/csyxwei/ELITE)'
16
+
17
+ model = Model()
18
+
19
+ with gr.Blocks(css=repo_dir / 'style.css') as demo:
20
+ gr.Markdown(DESCRIPTION)
21
+ with gr.Row():
22
+ with gr.Column():
23
+ with gr.Box():
24
+ image = gr.Image(label='Input', tool='sketch', type='pil')
25
+ gr.Markdown('Draw a mask on your object.')
26
+ prompt = gr.Text(
27
+ label='Prompt',
28
+ placeholder='e.g. "A photo of S", "S wearing sunglasses"',
29
+ info='Use "S" for your concept.')
30
+ lambda_ = gr.Slider(
31
+ label='Lambda',
32
+ minimum=0,
33
+ maximum=1,
34
+ step=0.1,
35
+ value=0.6,
36
+ info=
37
+ 'The larger the lambda, the more consistency between the generated image and the input image, but less editability.'
38
+ )
39
+ run_button = gr.Button('Run')
40
+ with gr.Accordion(label='Advanced options', open=False):
41
+ seed = gr.Slider(
42
+ label='Seed',
43
+ minimum=-1,
44
+ maximum=1000000,
45
+ step=1,
46
+ value=-1,
47
+ info=
48
+ 'If set to -1, a different seed will be used each time.'
49
+ )
50
+ guidance_scale = gr.Slider(label='Guidance scale',
51
+ minimum=0,
52
+ maximum=50,
53
+ step=0.1,
54
+ value=5.0)
55
+ num_steps = gr.Slider(
56
+ label='Steps',
57
+ minimum=1,
58
+ maximum=100,
59
+ step=1,
60
+ value=20,
61
+ info=
62
+ 'In the paper, the number of steps is set to 100, but in this demo the default value is 20 to reduce inference time.'
63
+ )
64
+ with gr.Column():
65
+ result = gr.Image(label='Result')
66
+
67
+ paths = sorted([
68
+ path.as_posix()
69
+ for path in (repo_dir / 'ELITE/test_datasets').glob('*')
70
+ if 'bg' not in path.stem
71
+ ])
72
+ gr.Examples(examples=paths, inputs=image, examples_per_page=20)
73
+
74
+ inputs = [
75
+ image,
76
+ prompt,
77
+ seed,
78
+ guidance_scale,
79
+ lambda_,
80
+ num_steps,
81
+ ]
82
+ prompt.submit(fn=model.run, inputs=inputs, outputs=result)
83
+ run_button.click(fn=model.run, inputs=inputs, outputs=result)
84
+ return demo
85
+
86
+
87
+ if __name__ == '__main__':
88
+ demo = create_demo()
89
+ demo.queue(api_open=False).launch()
model.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import random
6
+ import sys
7
+ from typing import Any
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torchvision.transforms as T
16
+ import tqdm.auto
17
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
18
+ from huggingface_hub import hf_hub_download, snapshot_download
19
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
20
+
21
+ HF_TOKEN = os.getenv('HF_TOKEN')
22
+
23
+ repo_dir = pathlib.Path(__file__).parent
24
+ submodule_dir = repo_dir / 'ELITE'
25
+ snapshot_download('ELITE-library/ELITE',
26
+ repo_type='model',
27
+ local_dir=submodule_dir.as_posix(),
28
+ token=HF_TOKEN)
29
+ sys.path.insert(0, submodule_dir.as_posix())
30
+
31
+ from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
32
+ inj_forward_text, th2image)
33
+
34
+
35
+ def get_tensor_clip(normalize=True, toTensor=True):
36
+ transform_list = []
37
+ if toTensor:
38
+ transform_list += [T.ToTensor()]
39
+ if normalize:
40
+ transform_list += [
41
+ T.Normalize((0.48145466, 0.4578275, 0.40821073),
42
+ (0.26862954, 0.26130258, 0.27577711))
43
+ ]
44
+ return T.Compose(transform_list)
45
+
46
+
47
+ def process(image: np.ndarray, size: int = 512) -> torch.Tensor:
48
+ image = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
49
+ image = np.array(image).astype(np.float32)
50
+ image = image / 127.5 - 1.0
51
+ return torch.from_numpy(image).permute(2, 0, 1)
52
+
53
+
54
+ class Model:
55
+ def __init__(self):
56
+ self.device = torch.device(
57
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
58
+
59
+ (self.vae, self.unet, self.text_encoder, self.tokenizer,
60
+ self.image_encoder, self.mapper, self.mapper_local,
61
+ self.scheduler) = self.load_model()
62
+
63
+ def download_mappers(self) -> tuple[str, str]:
64
+ global_mapper_path = hf_hub_download('ELITE-library/ELITE',
65
+ 'global_mapper.pt',
66
+ subfolder='checkpoints',
67
+ repo_type='model',
68
+ token=HF_TOKEN)
69
+ local_mapper_path = hf_hub_download('ELITE-library/ELITE',
70
+ 'local_mapper.pt',
71
+ subfolder='checkpoints',
72
+ repo_type='model',
73
+ token=HF_TOKEN)
74
+ return global_mapper_path, local_mapper_path
75
+
76
+ def load_model(
77
+ self,
78
+ scheduler_type=LMSDiscreteScheduler
79
+ ) -> tuple[UNet2DConditionModel, CLIPTextModel, CLIPTokenizer,
80
+ AutoencoderKL, CLIPVisionModel, Mapper, MapperLocal,
81
+ LMSDiscreteScheduler, ]:
82
+ diffusion_model_id = 'CompVis/stable-diffusion-v1-4'
83
+
84
+ vae = AutoencoderKL.from_pretrained(
85
+ diffusion_model_id,
86
+ subfolder='vae',
87
+ torch_dtype=torch.float16,
88
+ )
89
+
90
+ tokenizer = CLIPTokenizer.from_pretrained(
91
+ 'openai/clip-vit-large-patch14',
92
+ torch_dtype=torch.float16,
93
+ )
94
+ text_encoder = CLIPTextModel.from_pretrained(
95
+ 'openai/clip-vit-large-patch14',
96
+ torch_dtype=torch.float16,
97
+ )
98
+ image_encoder = CLIPVisionModel.from_pretrained(
99
+ 'openai/clip-vit-large-patch14',
100
+ torch_dtype=torch.float16,
101
+ )
102
+
103
+ # Load models and create wrapper for stable diffusion
104
+ for _module in text_encoder.modules():
105
+ if _module.__class__.__name__ == 'CLIPTextTransformer':
106
+ _module.__class__.__call__ = inj_forward_text
107
+
108
+ unet = UNet2DConditionModel.from_pretrained(
109
+ diffusion_model_id,
110
+ subfolder='unet',
111
+ torch_dtype=torch.float16,
112
+ )
113
+ inj_forward_crossattention
114
+ mapper = Mapper(input_dim=1024, output_dim=768)
115
+
116
+ mapper_local = MapperLocal(input_dim=1024, output_dim=768)
117
+
118
+ for _name, _module in unet.named_modules():
119
+ if _module.__class__.__name__ == 'CrossAttention':
120
+ if 'attn1' in _name:
121
+ continue
122
+ _module.__class__.__call__ = inj_forward_crossattention
123
+
124
+ shape = _module.to_k.weight.shape
125
+ to_k_global = nn.Linear(shape[1], shape[0], bias=False)
126
+ mapper.add_module(f'{_name.replace(".", "_")}_to_k',
127
+ to_k_global)
128
+
129
+ shape = _module.to_v.weight.shape
130
+ to_v_global = nn.Linear(shape[1], shape[0], bias=False)
131
+ mapper.add_module(f'{_name.replace(".", "_")}_to_v',
132
+ to_v_global)
133
+
134
+ to_v_local = nn.Linear(shape[1], shape[0], bias=False)
135
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_v',
136
+ to_v_local)
137
+
138
+ to_k_local = nn.Linear(shape[1], shape[0], bias=False)
139
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
140
+ to_k_local)
141
+
142
+ #global_mapper_path, local_mapper_path = self.download_mappers()
143
+ global_mapper_path = submodule_dir / 'checkpoints/global_mapper.pt'
144
+ local_mapper_path = submodule_dir / 'checkpoints/local_mapper.pt'
145
+
146
+ mapper.load_state_dict(
147
+ torch.load(global_mapper_path, map_location='cpu'))
148
+ mapper.half()
149
+
150
+ mapper_local.load_state_dict(
151
+ torch.load(local_mapper_path, map_location='cpu'))
152
+ mapper_local.half()
153
+
154
+ for _name, _module in unet.named_modules():
155
+ if 'attn1' in _name:
156
+ continue
157
+ if _module.__class__.__name__ == 'CrossAttention':
158
+ _module.add_module(
159
+ 'to_k_global',
160
+ mapper.__getattr__(f'{_name.replace(".", "_")}_to_k'))
161
+ _module.add_module(
162
+ 'to_v_global',
163
+ mapper.__getattr__(f'{_name.replace(".", "_")}_to_v'))
164
+ _module.add_module(
165
+ 'to_v_local',
166
+ getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))
167
+ _module.add_module(
168
+ 'to_k_local',
169
+ getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))
170
+
171
+ vae.eval().to(self.device)
172
+ unet.eval().to(self.device)
173
+ text_encoder.eval().to(self.device)
174
+ image_encoder.eval().to(self.device)
175
+ mapper.eval().to(self.device)
176
+ mapper_local.eval().to(self.device)
177
+
178
+ scheduler = scheduler_type(
179
+ beta_start=0.00085,
180
+ beta_end=0.012,
181
+ beta_schedule='scaled_linear',
182
+ num_train_timesteps=1000,
183
+ )
184
+ return (vae, unet, text_encoder, tokenizer, image_encoder, mapper,
185
+ mapper_local, scheduler)
186
+
187
+ def prepare_data(self,
188
+ image: PIL.Image.Image,
189
+ mask: PIL.Image.Image,
190
+ text: str,
191
+ placeholder_string: str = 'S') -> dict[str, Any]:
192
+ data: dict[str, Any] = {}
193
+
194
+ data['text'] = text
195
+
196
+ placeholder_index = 0
197
+ words = text.strip().split(' ')
198
+ for idx, word in enumerate(words):
199
+ if word == placeholder_string:
200
+ placeholder_index = idx + 1
201
+
202
+ data['index'] = torch.tensor(placeholder_index)
203
+
204
+ data['input_ids'] = self.tokenizer(
205
+ text,
206
+ padding='max_length',
207
+ truncation=True,
208
+ max_length=self.tokenizer.model_max_length,
209
+ return_tensors='pt',
210
+ ).input_ids[0]
211
+
212
+ image = image.convert('RGB')
213
+ mask = mask.convert('RGB')
214
+ mask = np.array(mask) / 255.0
215
+
216
+ image_np = np.array(image)
217
+ object_tensor = image_np * mask
218
+ data['pixel_values'] = process(image_np)
219
+
220
+ ref_object_tensor = PIL.Image.fromarray(
221
+ object_tensor.astype('uint8')).resize(
222
+ (224, 224), resample=PIL.Image.Resampling.BICUBIC)
223
+ ref_image_tenser = PIL.Image.fromarray(
224
+ image_np.astype('uint8')).resize(
225
+ (224, 224), resample=PIL.Image.Resampling.BICUBIC)
226
+ data['pixel_values_obj'] = get_tensor_clip()(ref_object_tensor)
227
+ data['pixel_values_clip'] = get_tensor_clip()(ref_image_tenser)
228
+
229
+ ref_seg_tensor = PIL.Image.fromarray(mask.astype('uint8') * 255)
230
+ ref_seg_tensor = get_tensor_clip(normalize=False)(ref_seg_tensor)
231
+ data['pixel_values_seg'] = F.interpolate(ref_seg_tensor.unsqueeze(0),
232
+ size=(128, 128),
233
+ mode='nearest').squeeze(0)
234
+
235
+ device = torch.device('cuda:0')
236
+ data['pixel_values'] = data['pixel_values'].to(device)
237
+ data['pixel_values_clip'] = data['pixel_values_clip'].to(device).half()
238
+ data['pixel_values_obj'] = data['pixel_values_obj'].to(device).half()
239
+ data['pixel_values_seg'] = data['pixel_values_seg'].to(device).half()
240
+ data['input_ids'] = data['input_ids'].to(device)
241
+ data['index'] = data['index'].to(device).long()
242
+
243
+ for key, value in list(data.items()):
244
+ if isinstance(value, torch.Tensor):
245
+ data[key] = value.unsqueeze(0)
246
+
247
+ return data
248
+
249
+ @torch.inference_mode()
250
+ def run(
251
+ self,
252
+ image: dict[str, PIL.Image.Image],
253
+ text: str,
254
+ seed: int,
255
+ guidance_scale: float,
256
+ lambda_: float,
257
+ num_steps: int,
258
+ ) -> PIL.Image.Image:
259
+ data = self.prepare_data(image['image'], image['mask'], text)
260
+
261
+ uncond_input = self.tokenizer(
262
+ [''] * data['pixel_values'].shape[0],
263
+ padding='max_length',
264
+ max_length=self.tokenizer.model_max_length,
265
+ return_tensors='pt',
266
+ )
267
+ uncond_embeddings = self.text_encoder(
268
+ {'input_ids': uncond_input.input_ids.to(self.device)})[0]
269
+
270
+ if seed == -1:
271
+ seed = random.randint(0, 1000000)
272
+ generator = torch.Generator().manual_seed(seed)
273
+ latents = torch.randn(
274
+ (data['pixel_values'].shape[0], self.unet.in_channels, 64, 64),
275
+ generator=generator,
276
+ )
277
+
278
+ latents = latents.to(data['pixel_values_clip'])
279
+ self.scheduler.set_timesteps(num_steps)
280
+ latents = latents * self.scheduler.init_noise_sigma
281
+
282
+ placeholder_idx = data['index']
283
+
284
+ image = F.interpolate(data['pixel_values_clip'], (224, 224),
285
+ mode='bilinear')
286
+ image_features = self.image_encoder(image, output_hidden_states=True)
287
+ image_embeddings = [
288
+ image_features[0],
289
+ image_features[2][4],
290
+ image_features[2][8],
291
+ image_features[2][12],
292
+ image_features[2][16],
293
+ ]
294
+ image_embeddings = [emb.detach() for emb in image_embeddings]
295
+ inj_embedding = self.mapper(image_embeddings)
296
+
297
+ inj_embedding = inj_embedding[:, 0:1, :]
298
+ encoder_hidden_states = self.text_encoder({
299
+ 'input_ids':
300
+ data['input_ids'],
301
+ 'inj_embedding':
302
+ inj_embedding,
303
+ 'inj_index':
304
+ placeholder_idx,
305
+ })[0]
306
+
307
+ image_obj = F.interpolate(data['pixel_values_obj'], (224, 224),
308
+ mode='bilinear')
309
+ image_features_obj = self.image_encoder(image_obj,
310
+ output_hidden_states=True)
311
+ image_embeddings_obj = [
312
+ image_features_obj[0],
313
+ image_features_obj[2][4],
314
+ image_features_obj[2][8],
315
+ image_features_obj[2][12],
316
+ image_features_obj[2][16],
317
+ ]
318
+ image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]
319
+
320
+ inj_embedding_local = self.mapper_local(image_embeddings_obj)
321
+ mask = F.interpolate(data['pixel_values_seg'], (16, 16),
322
+ mode='nearest')
323
+ mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
324
+ inj_embedding_local = inj_embedding_local * mask
325
+
326
+ for t in tqdm.auto.tqdm(self.scheduler.timesteps):
327
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
328
+ noise_pred_text = self.unet(latent_model_input,
329
+ t,
330
+ encoder_hidden_states={
331
+ 'CONTEXT_TENSOR':
332
+ encoder_hidden_states,
333
+ 'LOCAL': inj_embedding_local,
334
+ 'LOCAL_INDEX':
335
+ placeholder_idx.detach(),
336
+ 'LAMBDA': lambda_
337
+ }).sample
338
+ latent_model_input = self.scheduler.scale_model_input(latents, t)
339
+
340
+ noise_pred_uncond = self.unet(latent_model_input,
341
+ t,
342
+ encoder_hidden_states={
343
+ 'CONTEXT_TENSOR':
344
+ uncond_embeddings,
345
+ }).sample
346
+ noise_pred = noise_pred_uncond + guidance_scale * (
347
+ noise_pred_text - noise_pred_uncond)
348
+
349
+ # compute the previous noisy sample x_t -> x_t-1
350
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
351
+
352
+ _latents = 1 / 0.18215 * latents.clone()
353
+ images = self.vae.decode(_latents).sample
354
+ return th2image(images[0])
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.16.0
2
+ albumentations==1.3.0
3
+ diffusers==0.11.1
4
+ gradio==3.20.1
5
+ huggingface-hub==0.13.0
6
+ opencv-python-headless==4.7.0.68
7
+ Pillow==9.4.0
8
+ torch==1.13.1
9
+ torchvision==0.14.1
10
+ tqdm==4.65.0
11
+ transformers==4.26.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }