Koke_Cacao committed
Commit 74db9af (1 parent: 0974835)

:sparkles: clean up code

.style.yapf ADDED
@@ -0,0 +1,6 @@
+ [style]
+ based_on_style = google
+ spaces_before_comment = 1
+ indent_width: 4
+ split_before_logical_operator = true
+ column_limit = 1024
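
The style file above drives yapf's formatting for the rest of this commit. As a quick, illustrative check (not part of the commit; it assumes yapf is installed and is run from the repository root), the settings can be exercised through yapf's Python API:

```python
# Illustrative sketch: reformat one of the repo's files using the committed
# .style.yapf. The file path is an example; FormatCode returns the formatted
# source plus a flag indicating whether anything changed.
from yapf.yapflib.yapf_api import FormatCode

with open("scripts/attention.py") as f:
    source = f.read()

formatted, changed = FormatCode(source, style_config=".style.yapf")
print("would reformat" if changed else "already matches the style")
```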
scripts/README.md → README.md RENAMED
@@ -1,4 +1,8 @@
- # Convert original weights to diffusers
+ # MVDream-HF
+
+ A huggingface implementation of MVDream, used for quick one-line download. See [huggingface repo](https://huggingface.co/KokeCacao/mvdream-hf/tree/main) that hosts sd-v1.5 version.
+
+ ## Convert Original Weights to Diffusers

 Download original MVDream checkpoint through one of the following sources:

@@ -14,5 +18,5 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd

 Hugging Face diffusers weights are converted by script:
 ```bash
- python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path . --original_config_file ./sd-v1.yaml
+ python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path . --original_config_file ./sd-v1.yaml --test
 ```
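
For context (not part of the diff): once the converted weights are on the linked repo, the "quick one-line download" the README mentions is a standard diffusers load. The snippet below is a hedged sketch; the repo id comes from the README link, and whether the custom MVDream pipeline class resolves automatically depends on how the Hub repo is configured.

```python
# Hypothetical usage sketch: download the converted sd-v1.5 MVDream weights.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "KokeCacao/mvdream-hf",     # repo id from the README link
    torch_dtype=torch.float16,  # optional; assumes a CUDA-capable machine
)
pipe = pipe.to("cuda")
```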
scripts/attention.py CHANGED
@@ -11,7 +11,6 @@ from einops import rearrange, repeat
11
  from typing import Optional, Any
12
  from util import checkpoint
13
 
14
-
15
  try:
16
  import xformers
17
  import xformers.ops
@@ -21,11 +20,12 @@ except:
21
 
22
  # CrossAttn precision handling
23
  import os
 
24
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
25
 
26
 
27
  def uniq(arr):
28
- return{el: True for el in arr}.keys()
29
 
30
 
31
  def default(val, d):
@@ -47,6 +47,7 @@ def init_(tensor):
47
 
48
  # feedforward
49
  class GEGLU(nn.Module):
 
50
  def __init__(self, dim_in, dim_out):
51
  super().__init__()
52
  self.proj = nn.Linear(dim_in, dim_out * 2)
@@ -57,20 +58,14 @@ class GEGLU(nn.Module):
57
 
58
 
59
  class FeedForward(nn.Module):
 
60
  def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61
  super().__init__()
62
  inner_dim = int(dim * mult)
63
  dim_out = default(dim_out, dim)
64
- project_in = nn.Sequential(
65
- nn.Linear(dim, inner_dim),
66
- nn.GELU()
67
- ) if not glu else GEGLU(dim, inner_dim)
68
-
69
- self.net = nn.Sequential(
70
- project_in,
71
- nn.Dropout(dropout),
72
- nn.Linear(inner_dim, dim_out)
73
- )
74
 
75
  def forward(self, x):
76
  return self.net(x)
@@ -90,31 +85,16 @@ def Normalize(in_channels):
90
 
91
 
92
  class SpatialSelfAttention(nn.Module):
 
93
  def __init__(self, in_channels):
94
  super().__init__()
95
  self.in_channels = in_channels
96
 
97
  self.norm = Normalize(in_channels)
98
- self.q = torch.nn.Conv2d(in_channels,
99
- in_channels,
100
- kernel_size=1,
101
- stride=1,
102
- padding=0)
103
- self.k = torch.nn.Conv2d(in_channels,
104
- in_channels,
105
- kernel_size=1,
106
- stride=1,
107
- padding=0)
108
- self.v = torch.nn.Conv2d(in_channels,
109
- in_channels,
110
- kernel_size=1,
111
- stride=1,
112
- padding=0)
113
- self.proj_out = torch.nn.Conv2d(in_channels,
114
- in_channels,
115
- kernel_size=1,
116
- stride=1,
117
- padding=0)
118
 
119
  def forward(self, x):
120
  h_ = x
@@ -124,7 +104,7 @@ class SpatialSelfAttention(nn.Module):
124
  v = self.v(h_)
125
 
126
  # compute attention
127
- b,c,h,w = q.shape
128
  q = rearrange(q, 'b c h w -> b (h w) c')
129
  k = rearrange(k, 'b c h w -> b c (h w)')
130
  w_ = torch.einsum('bij,bjk->bik', q, k)
@@ -139,26 +119,24 @@ class SpatialSelfAttention(nn.Module):
139
  h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
140
  h_ = self.proj_out(h_)
141
 
142
- return x+h_
143
 
144
 
145
  class CrossAttention(nn.Module):
 
146
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147
  super().__init__()
148
  inner_dim = dim_head * heads
149
  context_dim = default(context_dim, query_dim)
150
 
151
- self.scale = dim_head ** -0.5
152
  self.heads = heads
153
 
154
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
155
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
156
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
157
 
158
- self.to_out = nn.Sequential(
159
- nn.Linear(inner_dim, query_dim),
160
- nn.Dropout(dropout)
161
- )
162
 
163
  def forward(self, x, context=None, mask=None):
164
  h = self.heads
@@ -171,15 +149,15 @@ class CrossAttention(nn.Module):
171
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172
 
173
  # force cast to fp32 to avoid overflowing
174
- if _ATTN_PRECISION =="fp32":
175
- with autocast(enabled=False, device_type = 'cuda'):
176
  q, k = q.float(), k.float()
177
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
  else:
179
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180
-
181
  del q, k
182
-
183
  if mask is not None:
184
  mask = rearrange(mask, 'b ... -> b (...)')
185
  max_neg_value = -torch.finfo(sim.dtype).max
@@ -221,11 +199,7 @@ class MemoryEfficientCrossAttention(nn.Module):
221
 
222
  b, _, _ = q.shape
223
  q, k, v = map(
224
- lambda t: t.unsqueeze(3)
225
- .reshape(b, t.shape[1], self.heads, self.dim_head)
226
- .permute(0, 2, 1, 3)
227
- .reshape(b * self.heads, t.shape[1], self.dim_head)
228
- .contiguous(),
229
  (q, k, v),
230
  )
231
 
@@ -234,32 +208,25 @@ class MemoryEfficientCrossAttention(nn.Module):
234
 
235
  if mask is not None:
236
  raise NotImplementedError
237
- out = (
238
- out.unsqueeze(0)
239
- .reshape(b, self.heads, out.shape[1], self.dim_head)
240
- .permute(0, 2, 1, 3)
241
- .reshape(b, out.shape[1], self.heads * self.dim_head)
242
- )
243
  return self.to_out(out)
244
 
245
 
246
  class BasicTransformerBlock(nn.Module):
247
  ATTENTION_MODES = {
248
- "softmax": CrossAttention, # vanilla attention
249
  "softmax-xformers": MemoryEfficientCrossAttention
250
  }
251
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
- disable_self_attn=False):
253
  super().__init__()
254
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
  assert attn_mode in self.ATTENTION_MODES
256
  attn_cls = self.ATTENTION_MODES[attn_mode]
257
  self.disable_self_attn = disable_self_attn
258
- self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
- context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
- self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263
  self.norm1 = nn.LayerNorm(dim)
264
  self.norm2 = nn.LayerNorm(dim)
265
  self.norm3 = nn.LayerNorm(dim)
@@ -284,10 +251,8 @@ class SpatialTransformer(nn.Module):
284
  Finally, reshape to image
285
  NEW: use_linear for more efficiency instead of the 1x1 convs
286
  """
287
- def __init__(self, in_channels, n_heads, d_head,
288
- depth=1, dropout=0., context_dim=None,
289
- disable_self_attn=False, use_linear=False,
290
- use_checkpoint=True):
291
  super().__init__()
292
  assert context_dim is not None
293
  if not isinstance(context_dim, list):
@@ -296,25 +261,13 @@ class SpatialTransformer(nn.Module):
296
  inner_dim = n_heads * d_head
297
  self.norm = Normalize(in_channels)
298
  if not use_linear:
299
- self.proj_in = nn.Conv2d(in_channels,
300
- inner_dim,
301
- kernel_size=1,
302
- stride=1,
303
- padding=0)
304
  else:
305
  self.proj_in = nn.Linear(in_channels, inner_dim)
306
 
307
- self.transformer_blocks = nn.ModuleList(
308
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
309
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
310
- for d in range(depth)]
311
- )
312
  if not use_linear:
313
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
314
- in_channels,
315
- kernel_size=1,
316
- stride=1,
317
- padding=0))
318
  else:
319
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
320
  self.use_linear = use_linear
@@ -356,11 +309,9 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
356
 
357
 
358
  class SpatialTransformer3D(nn.Module):
359
- ''' 3D self-attention '''
360
- def __init__(self, in_channels, n_heads, d_head,
361
- depth=1, dropout=0., context_dim=None,
362
- disable_self_attn=False, use_linear=False,
363
- use_checkpoint=True):
364
  super().__init__()
365
  assert context_dim is not None
366
  if not isinstance(context_dim, list):
@@ -369,25 +320,13 @@ class SpatialTransformer3D(nn.Module):
369
  inner_dim = n_heads * d_head
370
  self.norm = Normalize(in_channels)
371
  if not use_linear:
372
- self.proj_in = nn.Conv2d(in_channels,
373
- inner_dim,
374
- kernel_size=1,
375
- stride=1,
376
- padding=0)
377
  else:
378
  self.proj_in = nn.Linear(in_channels, inner_dim)
379
 
380
- self.transformer_blocks = nn.ModuleList(
381
- [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
382
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
383
- for d in range(depth)]
384
- )
385
  if not use_linear:
386
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
387
- in_channels,
388
- kernel_size=1,
389
- stride=1,
390
- padding=0))
391
  else:
392
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
393
  self.use_linear = use_linear
@@ -411,4 +350,4 @@ class SpatialTransformer3D(nn.Module):
411
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
412
  if not self.use_linear:
413
  x = self.proj_out(x)
414
- return x + x_in
 
11
  from typing import Optional, Any
12
  from util import checkpoint
13
 
 
14
  try:
15
  import xformers
16
  import xformers.ops
 
20
 
21
  # CrossAttn precision handling
22
  import os
23
+
24
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
25
 
26
 
27
  def uniq(arr):
28
+ return {el: True for el in arr}.keys()
29
 
30
 
31
  def default(val, d):
 
47
 
48
  # feedforward
49
  class GEGLU(nn.Module):
50
+
51
  def __init__(self, dim_in, dim_out):
52
  super().__init__()
53
  self.proj = nn.Linear(dim_in, dim_out * 2)
 
58
 
59
 
60
  class FeedForward(nn.Module):
61
+
62
  def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
63
  super().__init__()
64
  inner_dim = int(dim * mult)
65
  dim_out = default(dim_out, dim)
66
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
67
+
68
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
 
 
 
 
 
 
69
 
70
  def forward(self, x):
71
  return self.net(x)
 
85
 
86
 
87
  class SpatialSelfAttention(nn.Module):
88
+
89
  def __init__(self, in_channels):
90
  super().__init__()
91
  self.in_channels = in_channels
92
 
93
  self.norm = Normalize(in_channels)
94
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
95
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
96
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
97
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
98
 
99
  def forward(self, x):
100
  h_ = x
 
104
  v = self.v(h_)
105
 
106
  # compute attention
107
+ b, c, h, w = q.shape
108
  q = rearrange(q, 'b c h w -> b (h w) c')
109
  k = rearrange(k, 'b c h w -> b c (h w)')
110
  w_ = torch.einsum('bij,bjk->bik', q, k)
 
119
  h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
120
  h_ = self.proj_out(h_)
121
 
122
+ return x + h_
123
 
124
 
125
  class CrossAttention(nn.Module):
126
+
127
  def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
128
  super().__init__()
129
  inner_dim = dim_head * heads
130
  context_dim = default(context_dim, query_dim)
131
 
132
+ self.scale = dim_head**-0.5
133
  self.heads = heads
134
 
135
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
136
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
137
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
138
 
139
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
 
 
140
 
141
  def forward(self, x, context=None, mask=None):
142
  h = self.heads
 
149
  q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
150
 
151
  # force cast to fp32 to avoid overflowing
152
+ if _ATTN_PRECISION == "fp32":
153
+ with autocast(enabled=False, device_type='cuda'):
154
  q, k = q.float(), k.float()
155
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
156
  else:
157
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
158
+
159
  del q, k
160
+
161
  if mask is not None:
162
  mask = rearrange(mask, 'b ... -> b (...)')
163
  max_neg_value = -torch.finfo(sim.dtype).max
 
199
 
200
  b, _, _ = q.shape
201
  q, k, v = map(
202
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, t.shape[1], self.dim_head).contiguous(),
 
 
 
 
203
  (q, k, v),
204
  )
205
 
 
208
 
209
  if mask is not None:
210
  raise NotImplementedError
211
+ out = (out.unsqueeze(0).reshape(b, self.heads, out.shape[1], self.dim_head).permute(0, 2, 1, 3).reshape(b, out.shape[1], self.heads * self.dim_head))
 
 
 
 
 
212
  return self.to_out(out)
213
 
214
 
215
  class BasicTransformerBlock(nn.Module):
216
  ATTENTION_MODES = {
217
+ "softmax": CrossAttention, # vanilla attention
218
  "softmax-xformers": MemoryEfficientCrossAttention
219
  }
220
+
221
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
222
  super().__init__()
223
  attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
224
  assert attn_mode in self.ATTENTION_MODES
225
  attn_cls = self.ATTENTION_MODES[attn_mode]
226
  self.disable_self_attn = disable_self_attn
227
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
 
228
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
229
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
 
230
  self.norm1 = nn.LayerNorm(dim)
231
  self.norm2 = nn.LayerNorm(dim)
232
  self.norm3 = nn.LayerNorm(dim)
 
251
  Finally, reshape to image
252
  NEW: use_linear for more efficiency instead of the 1x1 convs
253
  """
254
+
255
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True):
 
 
256
  super().__init__()
257
  assert context_dim is not None
258
  if not isinstance(context_dim, list):
 
261
  inner_dim = n_heads * d_head
262
  self.norm = Normalize(in_channels)
263
  if not use_linear:
264
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 
 
 
 
265
  else:
266
  self.proj_in = nn.Linear(in_channels, inner_dim)
267
 
268
+ self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) for d in range(depth)])
 
 
 
 
269
  if not use_linear:
270
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
 
 
 
 
271
  else:
272
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
273
  self.use_linear = use_linear
 
309
 
310
 
311
  class SpatialTransformer3D(nn.Module):
312
+ ''' 3D self-attention '''
313
+
314
+ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, use_checkpoint=True):
 
 
315
  super().__init__()
316
  assert context_dim is not None
317
  if not isinstance(context_dim, list):
 
320
  inner_dim = n_heads * d_head
321
  self.norm = Normalize(in_channels)
322
  if not use_linear:
323
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
 
 
 
 
324
  else:
325
  self.proj_in = nn.Linear(in_channels, inner_dim)
326
 
327
+ self.transformer_blocks = nn.ModuleList([BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) for d in range(depth)])
 
 
 
 
328
  if not use_linear:
329
+ self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
 
 
 
 
330
  else:
331
  self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
332
  self.use_linear = use_linear
 
350
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
351
  if not self.use_linear:
352
  x = self.proj_out(x)
353
+ return x + x_in
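
A rough shape sanity check for the reformatted CrossAttention (a sketch, not part of the commit; it assumes the scripts/ directory, including its util helper, is importable and torch is installed, and all sizes are illustrative):

```python
# CrossAttention maps (batch, tokens, query_dim), plus an optional text context
# of shape (batch, seq_len, context_dim), back to (batch, tokens, query_dim).
import torch
from attention import CrossAttention  # scripts/attention.py

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 4096, 320)      # image tokens: (batch, h*w, query_dim)
context = torch.randn(2, 77, 768)  # text tokens: (batch, seq_len, context_dim)
out = attn(x, context=context)
print(out.shape)                   # expected: torch.Size([2, 4096, 320])
```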
scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -3,6 +3,7 @@
3
  import argparse
4
  import torch
5
  import sys
 
6
  sys.path.insert(0, '../')
7
 
8
  from transformers import (
@@ -126,9 +127,7 @@ logger = logging.get_logger(__name__)
126
  # return config
127
 
128
 
129
- def assign_to_checkpoint(
130
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
131
- ):
132
  """
133
  This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
134
  attention layers, and takes into account additional replacements that may arise.
@@ -144,6 +143,7 @@ def assign_to_checkpoint(
144
 
145
  target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
146
 
 
147
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
148
 
149
  old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
@@ -211,6 +211,7 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
211
 
212
  return mapping
213
 
 
214
  def renew_attention_paths(old_list, n_shave_prefix_segments=0):
215
  """
216
  Updates paths inside attentions to the new naming scheme (local renaming)
@@ -231,6 +232,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
231
 
232
  return mapping
233
 
 
234
  # def convert_ldm_unet_checkpoint(
235
  # checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
236
  # ):
@@ -496,6 +498,7 @@ def create_vae_diffusers_config(original_config, image_size: int):
496
  }
497
  return config
498
 
 
499
  def convert_ldm_vae_checkpoint(checkpoint, config):
500
  # extract state dict for VAE
501
  vae_state_dict = {}
@@ -528,26 +531,18 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
528
 
529
  # Retrieves the keys for the encoder down blocks only
530
  num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
531
- down_blocks = {
532
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
533
- }
534
 
535
  # Retrieves the keys for the decoder up blocks only
536
  num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
537
- up_blocks = {
538
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
539
- }
540
 
541
  for i in range(num_down_blocks):
542
  resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
543
 
544
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
545
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
546
- f"encoder.down.{i}.downsample.conv.weight"
547
- )
548
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
549
- f"encoder.down.{i}.downsample.conv.bias"
550
- )
551
 
552
  paths = renew_vae_resnet_paths(resnets)
553
  meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
@@ -570,17 +565,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
570
 
571
  for i in range(num_up_blocks):
572
  block_id = num_up_blocks - 1 - i
573
- resnets = [
574
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
575
- ]
576
 
577
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
578
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
579
- f"decoder.up.{block_id}.upsample.conv.weight"
580
- ]
581
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
582
- f"decoder.up.{block_id}.upsample.conv.bias"
583
- ]
584
 
585
  paths = renew_vae_resnet_paths(resnets)
586
  meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
@@ -618,6 +607,7 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
618
 
619
  return mapping
620
 
 
621
  def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
622
  """
623
  Updates paths inside attentions to the new naming scheme (local renaming)
@@ -659,12 +649,8 @@ def conv_attn_to_linear(checkpoint):
659
  if checkpoint[key].ndim > 2:
660
  checkpoint[key] = checkpoint[key][:, :, 0]
661
 
662
- def convert_from_original_mvdream_ckpt(
663
- checkpoint_path,
664
- original_config_file,
665
- extract_ema,
666
- device
667
- ):
668
  checkpoint = torch.load(checkpoint_path, map_location=device)
669
  # print(f"Checkpoint: {checkpoint.keys()}")
670
  torch.cuda.empty_cache()
@@ -702,9 +688,7 @@ def convert_from_original_mvdream_ckpt(
702
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
703
  unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**original_config.model.params.unet_config.params)
704
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
705
- unet.load_state_dict({
706
- key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()
707
- })
708
  for param_name, param in unet.state_dict().items():
709
  set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
710
 
@@ -712,25 +696,21 @@ def convert_from_original_mvdream_ckpt(
712
  vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
713
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
714
 
715
- if (
716
- "model" in original_config
717
- and "params" in original_config.model
718
- and "scale_factor" in original_config.model.params
719
- ):
720
  vae_scaling_factor = original_config.model.params.scale_factor
721
  else:
722
- vae_scaling_factor = 0.18215 # default SD scaling factor
723
 
724
  vae_config["scaling_factor"] = vae_scaling_factor
725
 
726
  with init_empty_weights():
727
  vae = AutoencoderKL(**vae_config)
728
-
729
  tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
730
  text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
731
 
732
  for param_name, param in converted_vae_checkpoint.items():
733
- set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
734
 
735
  pipe = MVDreamStableDiffusionPipeline(
736
  vae=vae,
@@ -746,30 +726,20 @@ def convert_from_original_mvdream_ckpt(
746
  if __name__ == "__main__":
747
  parser = argparse.ArgumentParser()
748
 
749
- parser.add_argument(
750
- "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
751
- )
752
  parser.add_argument(
753
  "--original_config_file",
754
  default=None,
755
  type=str,
756
  help="The YAML config file corresponding to the original architecture.",
757
  )
758
- parser.add_argument(
759
- "--extract_ema",
760
- action="store_true",
761
- help=(
762
- "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
763
- " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
764
- " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
765
- ),
766
- )
767
  parser.add_argument(
768
  "--to_safetensors",
769
  action="store_true",
770
  help="Whether to store pipeline in safetensors format or not.",
771
  )
772
- parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
 
773
  parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
774
  parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
775
  args = parser.parse_args()
@@ -777,22 +747,21 @@ if __name__ == "__main__":
777
  pipe = convert_from_original_mvdream_ckpt(
778
  checkpoint_path=args.checkpoint_path,
779
  original_config_file=args.original_config_file,
780
- extract_ema=args.extract_ema,
781
  device=args.device,
782
  )
783
 
784
  if args.half:
785
  pipe.to(torch_dtype=torch.float16)
786
-
787
- images = pipe(
788
- prompt="Head of Hatsune Miku",
789
- negative_prompt="painting, bad quality, flat",
790
- output_type="pil",
791
- return_dict=False,
792
- guidance_scale=7.5,
793
- num_inference_steps=50,
794
- )
795
- for i, image in enumerate(images):
796
- image.save(f"image_{i}.png")
797
-
798
- pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
 
3
  import argparse
4
  import torch
5
  import sys
6
+
7
  sys.path.insert(0, '../')
8
 
9
  from transformers import (
 
127
  # return config
128
 
129
 
130
+ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
 
 
131
  """
132
  This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
133
  attention layers, and takes into account additional replacements that may arise.
 
143
 
144
  target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
145
 
146
+ assert config is not None
147
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
148
 
149
  old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
 
211
 
212
  return mapping
213
 
214
+
215
  def renew_attention_paths(old_list, n_shave_prefix_segments=0):
216
  """
217
  Updates paths inside attentions to the new naming scheme (local renaming)
 
232
 
233
  return mapping
234
 
235
+
236
  # def convert_ldm_unet_checkpoint(
237
  # checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
238
  # ):
 
498
  }
499
  return config
500
 
501
+
502
  def convert_ldm_vae_checkpoint(checkpoint, config):
503
  # extract state dict for VAE
504
  vae_state_dict = {}
 
531
 
532
  # Retrieves the keys for the encoder down blocks only
533
  num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
534
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
 
 
535
 
536
  # Retrieves the keys for the decoder up blocks only
537
  num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
538
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
 
 
539
 
540
  for i in range(num_down_blocks):
541
  resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
542
 
543
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
544
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
545
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
 
 
 
 
546
 
547
  paths = renew_vae_resnet_paths(resnets)
548
  meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
 
565
 
566
  for i in range(num_up_blocks):
567
  block_id = num_up_blocks - 1 - i
568
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
 
 
569
 
570
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
571
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
572
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
 
 
 
 
573
 
574
  paths = renew_vae_resnet_paths(resnets)
575
  meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
 
607
 
608
  return mapping
609
 
610
+
611
  def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
612
  """
613
  Updates paths inside attentions to the new naming scheme (local renaming)
 
649
  if checkpoint[key].ndim > 2:
650
  checkpoint[key] = checkpoint[key][:, :, 0]
651
 
652
+
653
+ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
 
 
 
 
654
  checkpoint = torch.load(checkpoint_path, map_location=device)
655
  # print(f"Checkpoint: {checkpoint.keys()}")
656
  torch.cuda.empty_cache()
 
688
  # print(f"Unet Config: {original_config.model.params.unet_config.params}")
689
  unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**original_config.model.params.unet_config.params)
690
  # print(f"Unet State Dict: {unet.state_dict().keys()}")
691
+ unet.load_state_dict({key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()})
 
 
692
  for param_name, param in unet.state_dict().items():
693
  set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
694
 
 
696
  vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
697
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
698
 
699
+ if ("model" in original_config and "params" in original_config.model and "scale_factor" in original_config.model.params):
 
 
 
 
700
  vae_scaling_factor = original_config.model.params.scale_factor
701
  else:
702
+ vae_scaling_factor = 0.18215 # default SD scaling factor
703
 
704
  vae_config["scaling_factor"] = vae_scaling_factor
705
 
706
  with init_empty_weights():
707
  vae = AutoencoderKL(**vae_config)
708
+
709
  tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
710
  text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
711
 
712
  for param_name, param in converted_vae_checkpoint.items():
713
+ set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
714
 
715
  pipe = MVDreamStableDiffusionPipeline(
716
  vae=vae,
 
726
  if __name__ == "__main__":
727
  parser = argparse.ArgumentParser()
728
 
729
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert.")
 
 
730
  parser.add_argument(
731
  "--original_config_file",
732
  default=None,
733
  type=str,
734
  help="The YAML config file corresponding to the original architecture.",
735
  )
 
 
 
 
 
 
 
 
 
736
  parser.add_argument(
737
  "--to_safetensors",
738
  action="store_true",
739
  help="Whether to store pipeline in safetensors format or not.",
740
  )
741
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
742
+ parser.add_argument("--test", action="store_true", help="Whether to test inference after conversion.")
743
  parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
744
  parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
745
  args = parser.parse_args()
 
747
  pipe = convert_from_original_mvdream_ckpt(
748
  checkpoint_path=args.checkpoint_path,
749
  original_config_file=args.original_config_file,
 
750
  device=args.device,
751
  )
752
 
753
  if args.half:
754
  pipe.to(torch_dtype=torch.float16)
755
+
756
+ if args.test:
757
+ images = pipe(
758
+ prompt="Head of Hatsune Miku",
759
+ negative_prompt="painting, bad quality, flat",
760
+ output_type="pil",
761
+ guidance_scale=7.5,
762
+ num_inference_steps=50,
763
+ )
764
+ for i, image in enumerate(images):
765
+ image.save(f"image_{i}.png") # type: ignore
766
+
767
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
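
The same conversion can be driven from Python rather than the CLI. This is a sketch under assumptions: the scripts/ directory is importable, the checkpoint and YAML paths from the README exist, and a CUDA device is available, since the converter places weights on "cuda:0".

```python
# Programmatic conversion sketch mirroring the README command.
from convert_mvdream_to_diffusers import convert_from_original_mvdream_ckpt

pipe = convert_from_original_mvdream_ckpt(
    checkpoint_path="./sd-v1.5-4view.pt",
    original_config_file="./sd-v1.yaml",
    device="cuda:0",
)
pipe.save_pretrained("./mvdream-diffusers", safe_serialization=True)  # output dir is illustrative
```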
scripts/models.py CHANGED
@@ -82,29 +82,19 @@ class Upsample(nn.Module):
82
  upsampling occurs in the inner-two dimensions.
83
  """
84
 
85
- def __init__(self,
86
- channels,
87
- use_conv,
88
- dims=2,
89
- out_channels=None,
90
- padding=1):
91
  super().__init__()
92
  self.channels = channels
93
  self.out_channels = out_channels or channels
94
  self.use_conv = use_conv
95
  self.dims = dims
96
  if use_conv:
97
- self.conv = conv_nd(dims,
98
- self.channels,
99
- self.out_channels,
100
- 3,
101
- padding=padding)
102
 
103
  def forward(self, x):
104
  assert x.shape[1] == self.channels
105
  if self.dims == 3:
106
- x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
107
- mode="nearest")
108
  else:
109
  x = F.interpolate(x, scale_factor=2, mode="nearest")
110
  if self.use_conv:
@@ -121,12 +111,7 @@ class Downsample(nn.Module):
121
  downsampling occurs in the inner-two dimensions.
122
  """
123
 
124
- def __init__(self,
125
- channels,
126
- use_conv,
127
- dims=2,
128
- out_channels=None,
129
- padding=1):
130
  super().__init__()
131
  self.channels = channels
132
  self.out_channels = out_channels or channels
@@ -134,12 +119,7 @@ class Downsample(nn.Module):
134
  self.dims = dims
135
  stride = 2 if dims != 3 else (1, 2, 2)
136
  if use_conv:
137
- self.op = conv_nd(dims,
138
- self.channels,
139
- self.out_channels,
140
- 3,
141
- stride=stride,
142
- padding=padding)
143
  else:
144
  assert self.channels == self.out_channels
145
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -208,33 +188,22 @@ class ResBlock(TimestepBlock):
208
  nn.SiLU(),
209
  linear(
210
  emb_channels,
211
- 2 * self.out_channels
212
- if use_scale_shift_norm else self.out_channels,
213
  ),
214
  )
215
  self.out_layers = nn.Sequential(
216
  normalization(self.out_channels),
217
  nn.SiLU(),
218
  nn.Dropout(p=dropout),
219
- zero_module(
220
- conv_nd(dims,
221
- self.out_channels,
222
- self.out_channels,
223
- 3,
224
- padding=1)),
225
  )
226
 
227
  if self.out_channels == channels:
228
  self.skip_connection = nn.Identity()
229
  elif use_conv:
230
- self.skip_connection = conv_nd(dims,
231
- channels,
232
- self.out_channels,
233
- 3,
234
- padding=1)
235
  else:
236
- self.skip_connection = conv_nd(dims, channels, self.out_channels,
237
- 1)
238
 
239
  def forward(self, x, emb):
240
  """
@@ -243,8 +212,7 @@ class ResBlock(TimestepBlock):
243
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
244
  :return: an [N x C x ...] Tensor of outputs.
245
  """
246
- return checkpoint(self._forward, (x, emb), self.parameters(),
247
- self.use_checkpoint)
248
 
249
  def _forward(self, x, emb):
250
  if self.updown:
@@ -289,9 +257,7 @@ class AttentionBlock(nn.Module):
289
  if num_head_channels == -1:
290
  self.num_heads = num_heads
291
  else:
292
- assert (
293
- channels % num_head_channels == 0
294
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
295
  self.num_heads = channels // num_head_channels
296
  self.use_checkpoint = use_checkpoint
297
  self.norm = normalization(channels)
@@ -306,9 +272,7 @@ class AttentionBlock(nn.Module):
306
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
307
 
308
  def forward(self, x):
309
- return checkpoint(
310
- self._forward, (x, ), self.parameters(), True
311
- ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
312
  #return pt_checkpoint(self._forward, x) # pytorch
313
 
314
  def _forward(self, x):
@@ -358,12 +322,9 @@ class QKVAttentionLegacy(nn.Module):
358
  bs, width, length = qkv.shape
359
  assert width % (3 * self.n_heads) == 0
360
  ch = width // (3 * self.n_heads)
361
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
362
- dim=1)
363
  scale = 1 / math.sqrt(math.sqrt(ch))
364
- weight = th.einsum(
365
- "bct,bcs->bts", q * scale,
366
- k * scale) # More stable with f16 than dividing afterwards
367
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
368
  a = th.einsum("bts,bcs->bct", weight, v)
369
  return a.reshape(bs, -1, length)
@@ -397,10 +358,9 @@ class QKVAttention(nn.Module):
397
  "bct,bcs->bts",
398
  (q * scale).view(bs * self.n_heads, ch, length),
399
  (k * scale).view(bs * self.n_heads, ch, length),
400
- ) # More stable with f16 than dividing afterwards
401
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
402
- a = th.einsum("bts,bcs->bct", weight,
403
- v.reshape(bs * self.n_heads, ch, length))
404
  return a.reshape(bs, -1, length)
405
 
406
  @staticmethod
@@ -450,41 +410,40 @@ class MultiViewUNetModel(nn.Module):
450
  """
451
 
452
  def __init__(
453
- self,
454
- image_size,
455
- in_channels,
456
- model_channels,
457
- out_channels,
458
- num_res_blocks,
459
- attention_resolutions,
460
- dropout=0,
461
- channel_mult=(1, 2, 4, 8),
462
- conv_resample=True,
463
- dims=2,
464
- num_classes=None,
465
- use_checkpoint=False,
466
- use_fp16=False,
467
- use_bf16=False,
468
- num_heads=-1,
469
- num_head_channels=-1,
470
- num_heads_upsample=-1,
471
- use_scale_shift_norm=False,
472
- resblock_updown=False,
473
- use_new_attention_order=False,
474
- use_spatial_transformer=False, # custom transformer support
475
- transformer_depth=1, # custom transformer support
476
- context_dim=None, # custom transformer support
477
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
478
- legacy=True,
479
- disable_self_attentions=None,
480
- num_attention_blocks=None,
481
- disable_middle_self_attn=False,
482
- use_linear_in_transformer=False,
483
- adm_in_channels=None,
484
- camera_dim=None,
485
  ):
486
  super().__init__()
487
- assert num_classes is not None
488
  if use_spatial_transformer:
489
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
490
 
@@ -511,26 +470,19 @@ class MultiViewUNetModel(nn.Module):
511
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
512
  else:
513
  if len(num_res_blocks) != len(channel_mult):
514
- raise ValueError(
515
- "provide num_res_blocks either as an int (globally constant) or "
516
- "as a list/tuple (per-level) with the same length as channel_mult"
517
- )
518
  self.num_res_blocks = num_res_blocks
519
  if disable_self_attentions is not None:
520
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
521
  assert len(disable_self_attentions) == len(channel_mult)
522
  if num_attention_blocks is not None:
523
  assert len(num_attention_blocks) == len(self.num_res_blocks)
524
- assert all(
525
- map(
526
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
527
- ],
528
- range(len(num_attention_blocks))))
529
- print(
530
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
531
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
532
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
533
- f"attention will still not be set.")
534
 
535
  self.attention_resolutions = attention_resolutions
536
  self.dropout = dropout
@@ -562,42 +514,36 @@ class MultiViewUNetModel(nn.Module):
562
 
563
  if self.num_classes is not None:
564
  if isinstance(self.num_classes, int):
565
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
566
  elif self.num_classes == "continuous":
567
  print("setting up linear c_adm embedding layer")
568
  self.label_emb = nn.Linear(1, time_embed_dim)
569
  elif self.num_classes == "sequential":
570
  assert adm_in_channels is not None
571
- self.label_emb = nn.Sequential(
572
- nn.Sequential(
573
- linear(adm_in_channels, time_embed_dim),
574
- nn.SiLU(),
575
- linear(time_embed_dim, time_embed_dim),
576
- ))
577
  else:
578
  raise ValueError()
579
 
580
- self.input_blocks = nn.ModuleList([
581
- TimestepEmbedSequential(
582
- conv_nd(dims, in_channels, model_channels, 3, padding=1))
583
- ])
584
  self._feature_size = model_channels
585
  input_block_chans = [model_channels]
586
  ch = model_channels
587
  ds = 1
588
  for level, mult in enumerate(channel_mult):
589
  for nr in range(self.num_res_blocks[level]):
590
- layers: List[Any] = [
591
- ResBlock(
592
- ch,
593
- time_embed_dim,
594
- dropout,
595
- out_channels=mult * model_channels,
596
- dims=dims,
597
- use_checkpoint=use_checkpoint,
598
- use_scale_shift_norm=use_scale_shift_norm,
599
- )
600
- ]
601
  ch = mult * model_channels
602
  if ds in attention_resolutions:
603
  if num_head_channels == -1:
@@ -613,44 +559,29 @@ class MultiViewUNetModel(nn.Module):
613
  else:
614
  disabled_sa = False
615
 
616
- if num_attention_blocks is None or nr < num_attention_blocks[
617
- level]:
618
- layers.append(
619
- AttentionBlock(
620
- ch,
621
- use_checkpoint=use_checkpoint,
622
- num_heads=num_heads,
623
- num_head_channels=dim_head,
624
- use_new_attention_order=use_new_attention_order,
625
- ) if not use_spatial_transformer else
626
- SpatialTransformer3D(
627
- ch,
628
- num_heads,
629
- dim_head,
630
- depth=transformer_depth,
631
- context_dim=context_dim,
632
- disable_self_attn=disabled_sa,
633
- use_linear=use_linear_in_transformer,
634
- use_checkpoint=use_checkpoint))
635
  self.input_blocks.append(TimestepEmbedSequential(*layers))
636
  self._feature_size += ch
637
  input_block_chans.append(ch)
638
  if level != len(channel_mult) - 1:
639
  out_ch = ch
640
- self.input_blocks.append(
641
- TimestepEmbedSequential(
642
- ResBlock(
643
- ch,
644
- time_embed_dim,
645
- dropout,
646
- out_channels=out_ch,
647
- dims=dims,
648
- use_checkpoint=use_checkpoint,
649
- use_scale_shift_norm=use_scale_shift_norm,
650
- down=True,
651
- ) if resblock_updown else Downsample(
652
- ch, conv_resample, dims=dims, out_channels=out_ch))
653
- )
654
  ch = out_ch
655
  input_block_chans.append(ch)
656
  ds *= 2
@@ -679,16 +610,8 @@ class MultiViewUNetModel(nn.Module):
679
  num_heads=num_heads,
680
  num_head_channels=dim_head,
681
  use_new_attention_order=use_new_attention_order,
682
- ) if not use_spatial_transformer else
683
- SpatialTransformer3D( # always uses a self-attn
684
- ch,
685
- num_heads,
686
- dim_head,
687
- depth=transformer_depth,
688
- context_dim=context_dim,
689
- disable_self_attn=disable_middle_self_attn,
690
- use_linear=use_linear_in_transformer,
691
- use_checkpoint=use_checkpoint),
692
  ResBlock(
693
  ch,
694
  time_embed_dim,
@@ -704,17 +627,15 @@ class MultiViewUNetModel(nn.Module):
704
  for level, mult in list(enumerate(channel_mult))[::-1]:
705
  for i in range(self.num_res_blocks[level] + 1):
706
  ich = input_block_chans.pop()
707
- layers = [
708
- ResBlock(
709
- ch + ich,
710
- time_embed_dim,
711
- dropout,
712
- out_channels=model_channels * mult,
713
- dims=dims,
714
- use_checkpoint=use_checkpoint,
715
- use_scale_shift_norm=use_scale_shift_norm,
716
- )
717
- ]
718
  ch = model_channels * mult
719
  if ds in attention_resolutions:
720
  if num_head_channels == -1:
@@ -730,39 +651,26 @@ class MultiViewUNetModel(nn.Module):
730
  else:
731
  disabled_sa = False
732
 
733
- if num_attention_blocks is None or i < num_attention_blocks[
734
- level]:
735
- layers.append(
736
- AttentionBlock(
737
- ch,
738
- use_checkpoint=use_checkpoint,
739
- num_heads=num_heads_upsample,
740
- num_head_channels=dim_head,
741
- use_new_attention_order=use_new_attention_order,
742
- ) if not use_spatial_transformer else
743
- SpatialTransformer3D(
744
- ch,
745
- num_heads,
746
- dim_head,
747
- depth=transformer_depth,
748
- context_dim=context_dim,
749
- disable_self_attn=disabled_sa,
750
- use_linear=use_linear_in_transformer,
751
- use_checkpoint=use_checkpoint))
752
- if level and i == self.num_res_blocks[level]:
753
- out_ch = ch
754
- layers.append(
755
- ResBlock(
756
  ch,
757
- time_embed_dim,
758
- dropout,
759
- out_channels=out_ch,
760
- dims=dims,
761
  use_checkpoint=use_checkpoint,
762
- use_scale_shift_norm=use_scale_shift_norm,
763
- up=True,
764
- ) if resblock_updown else Upsample(
765
- ch, conv_resample, dims=dims, out_channels=out_ch))
766
  ds //= 2
767
  self.output_blocks.append(TimestepEmbedSequential(*layers))
768
  self._feature_size += ch
@@ -770,8 +678,7 @@ class MultiViewUNetModel(nn.Module):
770
  self.out = nn.Sequential(
771
  normalization(ch),
772
  nn.SiLU(),
773
- zero_module(
774
- conv_nd(dims, model_channels, out_channels, 3, padding=1)),
775
  )
776
  if self.predict_codebook_ids:
777
  self.id_predictor = nn.Sequential(
@@ -796,14 +703,7 @@ class MultiViewUNetModel(nn.Module):
796
  self.middle_block.apply(convert_module_to_f32)
797
  self.output_blocks.apply(convert_module_to_f32)
798
 
799
- def forward(self,
800
- x,
801
- timesteps=None,
802
- context=None,
803
- y: Optional[Tensor] = None,
804
- camera=None,
805
- num_frames=1,
806
- **kwargs):
807
  """
808
  Apply the model to an input batch.
809
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
@@ -813,15 +713,10 @@ class MultiViewUNetModel(nn.Module):
813
  :param num_frames: a integer indicating number of frames for tensor reshaping.
814
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
815
  """
816
- assert x.shape[
817
- 0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!"
818
- assert (y is not None) == (
819
- self.num_classes is not None
820
- ), "must specify y if and only if the model is class-conditional"
821
  hs = []
822
- t_emb = timestep_embedding(timesteps,
823
- self.model_channels,
824
- repeat_only=False)
825
  emb = self.time_embed(t_emb)
826
 
827
  if self.num_classes is not None:
 
82
  upsampling occurs in the inner-two dimensions.
83
  """
84
 
85
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
 
 
 
 
 
86
  super().__init__()
87
  self.channels = channels
88
  self.out_channels = out_channels or channels
89
  self.use_conv = use_conv
90
  self.dims = dims
91
  if use_conv:
92
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
 
 
 
93
 
94
  def forward(self, x):
95
  assert x.shape[1] == self.channels
96
  if self.dims == 3:
97
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
 
98
  else:
99
  x = F.interpolate(x, scale_factor=2, mode="nearest")
100
  if self.use_conv:
 
111
  downsampling occurs in the inner-two dimensions.
112
  """
113
 
114
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
 
 
 
 
 
115
  super().__init__()
116
  self.channels = channels
117
  self.out_channels = out_channels or channels
 
119
  self.dims = dims
120
  stride = 2 if dims != 3 else (1, 2, 2)
121
  if use_conv:
122
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
 
 
 
 
 
123
  else:
124
  assert self.channels == self.out_channels
125
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
 
188
  nn.SiLU(),
189
  linear(
190
  emb_channels,
191
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
 
192
  ),
193
  )
194
  self.out_layers = nn.Sequential(
195
  normalization(self.out_channels),
196
  nn.SiLU(),
197
  nn.Dropout(p=dropout),
198
+ zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
 
 
 
 
 
199
  )
200
 
201
  if self.out_channels == channels:
202
  self.skip_connection = nn.Identity()
203
  elif use_conv:
204
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
 
 
 
 
205
  else:
206
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
207
 
208
  def forward(self, x, emb):
209
  """
 
212
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
213
  :return: an [N x C x ...] Tensor of outputs.
214
  """
215
+ return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
 
216
 
217
  def _forward(self, x, emb):
218
  if self.updown:
 
257
  if num_head_channels == -1:
258
  self.num_heads = num_heads
259
  else:
260
+ assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 
 
261
  self.num_heads = channels // num_head_channels
262
  self.use_checkpoint = use_checkpoint
263
  self.norm = normalization(channels)
 
272
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
273
 
274
  def forward(self, x):
275
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
 
 
276
  #return pt_checkpoint(self._forward, x) # pytorch
277
 
278
  def _forward(self, x):
 
322
  bs, width, length = qkv.shape
323
  assert width % (3 * self.n_heads) == 0
324
  ch = width // (3 * self.n_heads)
325
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
 
326
  scale = 1 / math.sqrt(math.sqrt(ch))
327
+ weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
 
 
328
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
329
  a = th.einsum("bts,bcs->bct", weight, v)
330
  return a.reshape(bs, -1, length)
 
358
  "bct,bcs->bts",
359
  (q * scale).view(bs * self.n_heads, ch, length),
360
  (k * scale).view(bs * self.n_heads, ch, length),
361
+ ) # More stable with f16 than dividing afterwards
362
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
363
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
 
364
  return a.reshape(bs, -1, length)
365
 
366
  @staticmethod
 
410
  """
411
 
412
  def __init__(
413
+ self,
414
+ image_size,
415
+ in_channels,
416
+ model_channels,
417
+ out_channels,
418
+ num_res_blocks,
419
+ attention_resolutions,
420
+ dropout=0,
421
+ channel_mult=(1, 2, 4, 8),
422
+ conv_resample=True,
423
+ dims=2,
424
+ num_classes=None,
425
+ use_checkpoint=False,
426
+ use_fp16=False,
427
+ use_bf16=False,
428
+ num_heads=-1,
429
+ num_head_channels=-1,
430
+ num_heads_upsample=-1,
431
+ use_scale_shift_norm=False,
432
+ resblock_updown=False,
433
+ use_new_attention_order=False,
434
+ use_spatial_transformer=False, # custom transformer support
435
+ transformer_depth=1, # custom transformer support
436
+ context_dim=None, # custom transformer support
437
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
438
+ legacy=True,
439
+ disable_self_attentions=None,
440
+ num_attention_blocks=None,
441
+ disable_middle_self_attn=False,
442
+ use_linear_in_transformer=False,
443
+ adm_in_channels=None,
444
+ camera_dim=None,
445
  ):
446
  super().__init__()
 
447
  if use_spatial_transformer:
448
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
449
 
 
470
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
471
  else:
472
  if len(num_res_blocks) != len(channel_mult):
473
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
474
+ "as a list/tuple (per-level) with the same length as channel_mult")
 
 
475
  self.num_res_blocks = num_res_blocks
476
  if disable_self_attentions is not None:
477
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
478
  assert len(disable_self_attentions) == len(channel_mult)
479
  if num_attention_blocks is not None:
480
  assert len(num_attention_blocks) == len(self.num_res_blocks)
481
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
482
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
483
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
484
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
485
+ f"attention will still not be set.")
 
 
 
 
 
486
 
487
  self.attention_resolutions = attention_resolutions
488
  self.dropout = dropout
 
514
 
515
  if self.num_classes is not None:
516
  if isinstance(self.num_classes, int):
517
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
518
  elif self.num_classes == "continuous":
519
  print("setting up linear c_adm embedding layer")
520
  self.label_emb = nn.Linear(1, time_embed_dim)
521
  elif self.num_classes == "sequential":
522
  assert adm_in_channels is not None
523
+ self.label_emb = nn.Sequential(nn.Sequential(
524
+ linear(adm_in_channels, time_embed_dim),
525
+ nn.SiLU(),
526
+ linear(time_embed_dim, time_embed_dim),
527
+ ))
 
528
  else:
529
  raise ValueError()
530
 
531
+ self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])
 
 
 
532
  self._feature_size = model_channels
533
  input_block_chans = [model_channels]
534
  ch = model_channels
535
  ds = 1
536
  for level, mult in enumerate(channel_mult):
537
  for nr in range(self.num_res_blocks[level]):
538
+ layers: List[Any] = [ResBlock(
539
+ ch,
540
+ time_embed_dim,
541
+ dropout,
542
+ out_channels=mult * model_channels,
543
+ dims=dims,
544
+ use_checkpoint=use_checkpoint,
545
+ use_scale_shift_norm=use_scale_shift_norm,
546
+ )]
 
 
547
  ch = mult * model_channels
548
  if ds in attention_resolutions:
549
  if num_head_channels == -1:
 
559
  else:
560
  disabled_sa = False
561
 
562
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
563
+ layers.append(AttentionBlock(
564
+ ch,
565
+ use_checkpoint=use_checkpoint,
566
+ num_heads=num_heads,
567
+ num_head_channels=dim_head,
568
+ use_new_attention_order=use_new_attention_order,
569
+ ) if not use_spatial_transformer else SpatialTransformer3D(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint))
 
 
 
 
 
 
 
 
 
 
 
570
  self.input_blocks.append(TimestepEmbedSequential(*layers))
571
  self._feature_size += ch
572
  input_block_chans.append(ch)
573
  if level != len(channel_mult) - 1:
574
  out_ch = ch
575
+ self.input_blocks.append(TimestepEmbedSequential(ResBlock(
576
+ ch,
577
+ time_embed_dim,
578
+ dropout,
579
+ out_channels=out_ch,
580
+ dims=dims,
581
+ use_checkpoint=use_checkpoint,
582
+ use_scale_shift_norm=use_scale_shift_norm,
583
+ down=True,
584
+ ) if resblock_updown else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)))
 
 
 
 
585
  ch = out_ch
586
  input_block_chans.append(ch)
587
  ds *= 2
 
610
  num_heads=num_heads,
611
  num_head_channels=dim_head,
612
  use_new_attention_order=use_new_attention_order,
613
+ ) if not use_spatial_transformer else SpatialTransformer3D( # always uses a self-attn
614
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint),
 
 
 
 
 
 
 
 
615
  ResBlock(
616
  ch,
617
  time_embed_dim,
 
627
  for level, mult in list(enumerate(channel_mult))[::-1]:
628
  for i in range(self.num_res_blocks[level] + 1):
629
  ich = input_block_chans.pop()
630
+ layers = [ResBlock(
631
+ ch + ich,
632
+ time_embed_dim,
633
+ dropout,
634
+ out_channels=model_channels * mult,
635
+ dims=dims,
636
+ use_checkpoint=use_checkpoint,
637
+ use_scale_shift_norm=use_scale_shift_norm,
638
+ )]
 
 
639
  ch = model_channels * mult
640
  if ds in attention_resolutions:
641
  if num_head_channels == -1:
 
651
  else:
652
  disabled_sa = False
653
 
654
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
655
+ layers.append(AttentionBlock(
656
  ch,
 
 
 
 
657
  use_checkpoint=use_checkpoint,
658
+ num_heads=num_heads_upsample,
659
+ num_head_channels=dim_head,
660
+ use_new_attention_order=use_new_attention_order,
661
+ ) if not use_spatial_transformer else SpatialTransformer3D(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint))
662
+ if level and i == self.num_res_blocks[level]:
663
+ out_ch = ch
664
+ layers.append(ResBlock(
665
+ ch,
666
+ time_embed_dim,
667
+ dropout,
668
+ out_channels=out_ch,
669
+ dims=dims,
670
+ use_checkpoint=use_checkpoint,
671
+ use_scale_shift_norm=use_scale_shift_norm,
672
+ up=True,
673
+ ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch))
674
  ds //= 2
675
  self.output_blocks.append(TimestepEmbedSequential(*layers))
676
  self._feature_size += ch
 
678
  self.out = nn.Sequential(
679
  normalization(ch),
680
  nn.SiLU(),
681
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
 
682
  )
683
  if self.predict_codebook_ids:
684
  self.id_predictor = nn.Sequential(
 
703
  self.middle_block.apply(convert_module_to_f32)
704
  self.output_blocks.apply(convert_module_to_f32)
705
 
706
+ def forward(self, x, timesteps=None, context=None, y: Optional[Tensor] = None, camera=None, num_frames=1, **kwargs):
 
 
 
 
 
 
 
707
  """
708
  Apply the model to an input batch.
709
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
 
713
  :param num_frames: an integer indicating the number of frames for tensor reshaping.
714
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
715
  """
716
+ assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be divisible by num_frames!"
717
+ assert (y is not None) == (self.num_classes is not None), "must specify y if and only if the model is class-conditional"
 
 
 
718
  hs = []
719
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
 
 
720
  emb = self.time_embed(t_emb)
721
 
722
  if self.num_classes is not None:
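A minimal sketch of the multi-view forward contract documented above: the batch dimension flattens N objects times F views, and each view carries its own timestep, text-context row, and flattened 4x4 camera pose. The class name `MultiViewUNetModel` appears elsewhere in this diff, but its constructor call and every tensor size below are illustrative assumptions, not values taken from the commit.

```python
import torch

num_frames, n_objects = 4, 2                 # F views per object, N objects
b = n_objects * num_frames                   # flattened (N x F) batch

x = torch.randn(b, 4, 32, 32)                # latents, [(N x F) x C x H x W]
t = torch.randint(0, 1000, (b,))             # one timestep per view
context = torch.randn(b, 77, 768)            # CLIP text embeddings, one row per view
camera = torch.randn(b, 16)                  # flattened 4x4 camera-to-world per view

# Hypothetical instantiation; the real constructor takes the many arguments
# configured in this diff (model_channels, attention_resolutions, ...).
# unet = MultiViewUNetModel(...)
# eps = unet(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)
# eps.shape == x.shape   # the noise prediction keeps the same (N x F) layout
```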
scripts/pipeline_mvdream.py CHANGED
@@ -1,9 +1,8 @@
 
 
1
  import inspect
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
-
4
- import torch
5
  from transformers import CLIPTextModel, CLIPTokenizer
6
-
7
  from diffusers import AutoencoderKL, DiffusionPipeline
8
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
9
  from diffusers.utils import (
@@ -13,22 +12,16 @@ from diffusers.utils import (
13
  logging,
14
  replace_example_docstring,
15
  )
16
-
17
- try:
18
- from diffusers import randn_tensor # old import
19
- except ImportError:
20
- from diffusers.utils.torch_utils import randn_tensor # new import
21
-
22
  from diffusers.configuration_utils import FrozenDict
23
- import numpy as np
24
  from diffusers.schedulers import DDIMScheduler
25
- from models import MultiViewUNetModel, MultiViewUNetWrapperModel
26
-
27
- EXAMPLE_DOC_STRING = ""
 
28
 
29
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
 
31
- import numpy as np
32
 
33
  def create_camera_to_world_matrix(elevation, azimuth):
34
  elevation = np.radians(elevation)
@@ -59,29 +52,21 @@ def create_camera_to_world_matrix(elevation, azimuth):
59
  def convert_opengl_to_blender(camera_matrix):
60
  if isinstance(camera_matrix, np.ndarray):
61
  # Construct transformation matrix to convert from OpenGL space to Blender space
62
- flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
63
- [0, 0, 0, 1]])
64
  camera_matrix_blender = np.dot(flip_yz, camera_matrix)
65
  else:
66
  # Construct transformation matrix to convert from OpenGL space to Blender space
67
- flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
68
- [0, 0, 0, 1]])
69
  if camera_matrix.ndim == 3:
70
  flip_yz = flip_yz.unsqueeze(0)
71
- camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix),
72
- camera_matrix)
73
  return camera_matrix_blender
74
 
75
 
76
- def get_camera(num_frames,
77
- elevation=15,
78
- azimuth_start=0,
79
- azimuth_span=360,
80
- blender_coord=True):
81
  angle_gap = azimuth_span / num_frames
82
  cameras = []
83
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start,
84
- angle_gap):
85
  camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
86
  if blender_coord:
87
  camera_matrix = convert_opengl_to_blender(camera_matrix)
@@ -101,36 +86,25 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
101
  ):
102
  super().__init__()
103
 
104
- if hasattr(scheduler.config,
105
- "steps_offset") and scheduler.config.steps_offset != 1:
106
- deprecation_message = (
107
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
108
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
109
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
110
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
111
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
112
- " file")
113
- deprecate("steps_offset!=1",
114
- "1.0.0",
115
- deprecation_message,
116
- standard_warn=False)
117
  new_config = dict(scheduler.config)
118
  new_config["steps_offset"] = 1
119
  scheduler._internal_dict = FrozenDict(new_config)
120
 
121
- if hasattr(scheduler.config,
122
- "clip_sample") and scheduler.config.clip_sample is True:
123
- deprecation_message = (
124
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
125
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
126
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
127
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
128
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
129
- )
130
- deprecate("clip_sample not set",
131
- "1.0.0",
132
- deprecation_message,
133
- standard_warn=False)
134
  new_config = dict(scheduler.config)
135
  new_config["clip_sample"] = False
136
  scheduler._internal_dict = FrozenDict(new_config)
@@ -142,8 +116,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
142
  tokenizer=tokenizer,
143
  text_encoder=text_encoder,
144
  )
145
- self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) -
146
- 1)
147
  self.register_to_config(requires_safety_checker=False)
148
 
149
  def enable_vae_slicing(self):
@@ -189,16 +162,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
189
  if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
190
  from accelerate import cpu_offload
191
  else:
192
- raise ImportError(
193
- "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
194
- )
195
 
196
  device = torch.device(f"cuda:{gpu_id}")
197
 
198
  if self.device.type != "cpu":
199
  self.to("cpu", silence_dtype_warnings=True)
200
- torch.cuda.empty_cache(
201
- ) # otherwise we don't see the memory savings (but they probably exist)
202
 
203
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
204
  cpu_offload(cpu_offloaded_model, device)
@@ -210,26 +180,20 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
210
  method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
211
  `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
212
  """
213
- if is_accelerate_available() and is_accelerate_version(
214
- ">=", "0.17.0.dev0"):
215
  from accelerate import cpu_offload_with_hook
216
  else:
217
- raise ImportError(
218
- "`enable_model_offload` requires `accelerate v0.17.0` or higher."
219
- )
220
 
221
  device = torch.device(f"cuda:{gpu_id}")
222
 
223
  if self.device.type != "cpu":
224
  self.to("cpu", silence_dtype_warnings=True)
225
- torch.cuda.empty_cache(
226
- ) # otherwise we don't see the memory savings (but they probably exist)
227
 
228
  hook = None
229
  for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
230
- _, hook = cpu_offload_with_hook(cpu_offloaded_model,
231
- device,
232
- prev_module_hook=hook)
233
 
234
  # We'll offload the last model manually.
235
  self.final_offload_hook = hook
@@ -244,9 +208,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
244
  if not hasattr(self.unet, "_hf_hook"):
245
  return self.device
246
  for module in self.unet.modules():
247
- if (hasattr(module, "_hf_hook")
248
- and hasattr(module._hf_hook, "execution_device")
249
- and module._hf_hook.execution_device is not None):
250
  return torch.device(module._hf_hook.execution_device)
251
  return self.device
252
 
@@ -257,8 +219,6 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
257
  num_images_per_prompt,
258
  do_classifier_free_guidance: bool,
259
  negative_prompt=None,
260
- prompt_embeds: Optional[torch.FloatTensor] = None,
261
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
262
  ):
263
  r"""
264
  Encodes the prompt into text encoder hidden states.
@@ -289,67 +249,55 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
289
  elif prompt is not None and isinstance(prompt, list):
290
  batch_size = len(prompt)
291
  else:
292
- batch_size = prompt_embeds.shape[0]
 
 
 
 
 
 
 
 
 
 
293
 
294
- if prompt_embeds is None:
295
- text_inputs = self.tokenizer(
296
- prompt,
297
- padding="max_length",
298
- max_length=self.tokenizer.model_max_length,
299
- truncation=True,
300
- return_tensors="pt",
301
- )
302
- text_input_ids = text_inputs.input_ids
303
- untruncated_ids = self.tokenizer(prompt,
304
- padding="longest",
305
- return_tensors="pt").input_ids
306
-
307
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
308
- -1] and not torch.equal(text_input_ids, untruncated_ids):
309
- removed_text = self.tokenizer.batch_decode(
310
- untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
311
- logger.warning(
312
- "The following part of your input was truncated because CLIP can only handle sequences up to"
313
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
314
- )
315
 
316
- if hasattr(self.text_encoder.config, "use_attention_mask"
317
- ) and self.text_encoder.config.use_attention_mask:
318
- attention_mask = text_inputs.attention_mask.to(device)
319
- else:
320
- attention_mask = None
321
 
322
- prompt_embeds = self.text_encoder(
323
- text_input_ids.to(device),
324
- attention_mask=attention_mask,
325
- )
326
- prompt_embeds = prompt_embeds[0]
327
 
328
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype,
329
- device=device)
330
 
331
  bs_embed, seq_len, _ = prompt_embeds.shape
332
  # duplicate text embeddings for each generation per prompt, using mps friendly method
333
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
334
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
335
- seq_len, -1)
336
 
337
  # get unconditional embeddings for classifier free guidance
338
- if do_classifier_free_guidance and negative_prompt_embeds is None:
339
  uncond_tokens: List[str]
340
  if negative_prompt is None:
341
  uncond_tokens = [""] * batch_size
342
  elif type(prompt) is not type(negative_prompt):
343
- raise TypeError(
344
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
345
- f" {type(prompt)}.")
346
  elif isinstance(negative_prompt, str):
347
  uncond_tokens = [negative_prompt]
348
  elif batch_size != len(negative_prompt):
349
- raise ValueError(
350
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
351
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
352
- " the batch size of `prompt`.")
353
  else:
354
  uncond_tokens = negative_prompt
355
 
@@ -362,8 +310,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
362
  return_tensors="pt",
363
  )
364
 
365
- if hasattr(self.text_encoder.config, "use_attention_mask"
366
- ) and self.text_encoder.config.use_attention_mask:
367
  attention_mask = uncond_input.attention_mask.to(device)
368
  else:
369
  attention_mask = None
@@ -374,17 +321,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
374
  )
375
  negative_prompt_embeds = negative_prompt_embeds[0]
376
 
377
- if do_classifier_free_guidance:
378
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
379
  seq_len = negative_prompt_embeds.shape[1]
380
 
381
- negative_prompt_embeds = negative_prompt_embeds.to(
382
- dtype=self.text_encoder.dtype, device=device)
383
 
384
- negative_prompt_embeds = negative_prompt_embeds.repeat(
385
- 1, num_images_per_prompt, 1)
386
- negative_prompt_embeds = negative_prompt_embeds.view(
387
- batch_size * num_images_per_prompt, seq_len, -1)
388
 
389
  # For classifier free guidance, we need to do two forward passes.
390
  # Here we concatenate the unconditional and text embeddings into a single batch
@@ -407,42 +350,25 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
407
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
408
  # and should be between [0, 1]
409
 
410
- accepts_eta = "eta" in set(
411
- inspect.signature(self.scheduler.step).parameters.keys())
412
  extra_step_kwargs = {}
413
  if accepts_eta:
414
  extra_step_kwargs["eta"] = eta
415
 
416
  # check if the scheduler accepts generator
417
- accepts_generator = "generator" in set(
418
- inspect.signature(self.scheduler.step).parameters.keys())
419
  if accepts_generator:
420
  extra_step_kwargs["generator"] = generator
421
  return extra_step_kwargs
422
 
423
- def prepare_latents(self,
424
- batch_size,
425
- num_channels_latents,
426
- height,
427
- width,
428
- dtype,
429
- device,
430
- generator,
431
- latents=None):
432
- shape = (batch_size, num_channels_latents,
433
- height // self.vae_scale_factor,
434
- width // self.vae_scale_factor)
435
  if isinstance(generator, list) and len(generator) != batch_size:
436
- raise ValueError(
437
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
438
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
439
- )
440
 
441
  if latents is None:
442
- latents = randn_tensor(shape,
443
- generator=generator,
444
- device=device,
445
- dtype=dtype)
446
  else:
447
  latents = latents.to(device)
448
 
@@ -451,7 +377,6 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
451
  return latents
452
 
453
  @torch.no_grad()
454
- @replace_example_docstring(EXAMPLE_DOC_STRING)
455
  def __call__(
456
  self,
457
  height: int = 256,
@@ -462,87 +387,11 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
462
  negative_prompt: str = "bad quality",
463
  num_images_per_prompt: int = 1,
464
  eta: float = 0.0,
465
- generator: Optional[Union[torch.Generator,
466
- List[torch.Generator]]] = None,
467
  output_type: Optional[str] = "pil",
468
- return_dict: bool = True,
469
- callback: Optional[Callable[[int, int, torch.FloatTensor],
470
- None]] = None,
471
  callback_steps: int = 1,
472
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
473
- controlnet_conditioning_scale: float = 1.0,
474
  ):
475
- r"""
476
- Function invoked when calling the pipeline for generation.
477
-
478
- Args:
479
- input_imgs (`PIL` or `List[PIL]`, *optional*):
480
- The single input image for each 3D object
481
- prompt_imgs (`PIL` or `List[PIL]`, *optional*):
482
- Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature
483
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
484
- The height in pixels of the generated image.
485
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
486
- The width in pixels of the generated image.
487
- num_inference_steps (`int`, *optional*, defaults to 50):
488
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
489
- expense of slower inference.
490
- guidance_scale (`float`, *optional*, defaults to 7.5):
491
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
492
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
493
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
494
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
495
- usually at the expense of lower image quality.
496
- negative_prompt (`str` or `List[str]`, *optional*):
497
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
498
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
499
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
500
- num_images_per_prompt (`int`, *optional*, defaults to 1):
501
- The number of images to generate per prompt.
502
- eta (`float`, *optional*, defaults to 0.0):
503
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
504
- [`schedulers.DDIMScheduler`], will be ignored for others.
505
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
506
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
507
- to make generation deterministic.
508
- latents (`torch.FloatTensor`, *optional*):
509
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
510
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
511
- tensor will ge generated by sampling using the supplied random `generator`.
512
- prompt_embeds (`torch.FloatTensor`, *optional*):
513
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
514
- provided, text embeddings will be generated from `prompt` input argument.
515
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
516
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
517
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
518
- argument.
519
- output_type (`str`, *optional*, defaults to `"pil"`):
520
- The output format of the generate image. Choose between
521
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
522
- return_dict (`bool`, *optional*, defaults to `True`):
523
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
524
- plain tuple.
525
- callback (`Callable`, *optional*):
526
- A function that will be called every `callback_steps` steps during inference. The function will be
527
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
528
- callback_steps (`int`, *optional*, defaults to 1):
529
- The frequency at which the `callback` function will be called. If not specified, the callback will be
530
- called at every step.
531
- cross_attention_kwargs (`dict`, *optional*):
532
- A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
533
- `self.processor` in
534
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
535
-
536
- Examples:
537
-
538
- Returns:
539
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
540
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
541
- When returning a tuple, the first element is a list with the generated images, and the second element is a
542
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
543
- (nsfw) content, according to the `safety_checker`.
544
- """
545
- # 0. Default height and width to unet
546
  batch_size = 4
547
  device = torch.device("cuda:0")
548
 
@@ -553,7 +402,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
553
  # corresponds to doing no classifier free guidance.
554
  do_classifier_free_guidance = guidance_scale > 1.0
555
 
556
- # 4. Prepare timesteps
557
  self.scheduler.set_timesteps(num_inference_steps, device=device)
558
  timesteps = self.scheduler.timesteps
559
 
@@ -563,26 +412,10 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
563
  num_images_per_prompt=num_images_per_prompt,
564
  do_classifier_free_guidance=do_classifier_free_guidance,
565
  negative_prompt=negative_prompt,
566
- ) # type: ignore
567
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
568
-
569
- _, prompt_embeds_pos_2 = self._encode_prompt(
570
- prompt="watermellon",
571
- device=device,
572
- num_images_per_prompt=num_images_per_prompt,
573
- do_classifier_free_guidance=do_classifier_free_guidance,
574
- negative_prompt=negative_prompt,
575
- ).chunk(2) # type: ignore
576
-
577
- _, prompt_embeds_pos_4 = self._encode_prompt(
578
- prompt="long hair",
579
- device=device,
580
- num_images_per_prompt=num_images_per_prompt,
581
- do_classifier_free_guidance=do_classifier_free_guidance,
582
- negative_prompt=negative_prompt,
583
- ).chunk(2) # type: ignore
584
 
585
- # 5. Prepare latent variables
586
  latents: torch.Tensor = self.prepare_latents(
587
  batch_size * num_images_per_prompt,
588
  4,
@@ -594,33 +427,23 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
594
  None,
595
  )
596
 
597
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
598
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
599
 
600
- # 7. Denoising loop
601
- num_warmup_steps = len(
602
- timesteps) - num_inference_steps * self.scheduler.order
603
  with self.progress_bar(total=num_inference_steps) as progress_bar:
604
  for i, t in enumerate(timesteps):
605
  # expand the latents if we are doing classifier free guidance
606
  multiplier = 2 if do_classifier_free_guidance else 1
607
  latent_model_input = torch.cat([latents] * multiplier)
608
- latent_model_input = self.scheduler.scale_model_input(
609
- latent_model_input, t)
610
 
611
  # predict the noise residual
612
- # print(
613
- # f"shape of latent_model_input: {latent_model_input.shape}"
614
- # ) # [2*4, 4, 32, 32]
615
- # print(f"shape of prompt_embeds: {prompt_embeds.shape}"
616
- # ) # [2*4, 77, 768]
617
- # print(f"shape of camera: {camera.shape}") # [4, 16]
618
  noise_pred = self.unet.forward(
619
  x=latent_model_input,
620
- timesteps=torch.tensor([t] * 4 * multiplier,
621
- device=device),
622
- context=torch.cat([prompt_embeds_neg] * 4 +
623
- [prompt_embeds_pos, prompt_embeds_pos_2, prompt_embeds_pos, prompt_embeds_pos_4]),
624
  num_frames=4,
625
  camera=torch.cat([camera] * multiplier),
626
  )
@@ -628,46 +451,29 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
628
  # perform guidance
629
  if do_classifier_free_guidance:
630
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
631
- noise_pred = noise_pred_uncond + guidance_scale * (
632
- noise_pred_text - noise_pred_uncond)
633
 
634
  # compute the previous noisy sample x_t -> x_t-1
635
  # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
636
- latents: torch.Tensor = self.scheduler.step(
637
- noise_pred,
638
- t,
639
- latents,
640
- **extra_step_kwargs,
641
- return_dict=False)[0]
642
 
643
  # call the callback, if provided
644
- if i == len(timesteps) - 1 or (
645
- (i + 1) > num_warmup_steps and
646
- (i + 1) % self.scheduler.order == 0):
647
  progress_bar.update()
648
  if callback is not None and i % callback_steps == 0:
649
- callback(i, t, latents) # type: ignore
650
 
651
- # 8. Post-processing
652
  if output_type == "latent":
653
  image = latents
654
  elif output_type == "pil":
655
- # 8. Post-processing
656
  image = self.decode_latents(latents)
657
- # 10. Convert to PIL
658
  image = self.numpy_to_pil(image)
659
  else:
660
- # 8. Post-processing
661
  image = self.decode_latents(latents)
662
 
663
  # Offload last model to CPU
664
- if hasattr(
665
- self,
666
- "final_offload_hook") and self.final_offload_hook is not None:
667
  self.final_offload_hook.offload()
668
 
669
- if not return_dict:
670
- return image
671
-
672
- return StableDiffusionPipelineOutput(images=image,
673
- nsfw_content_detected=None)
 
1
+ import torch
2
+ import numpy as np
3
  import inspect
4
  from typing import Any, Callable, Dict, List, Optional, Union
 
 
5
  from transformers import CLIPTextModel, CLIPTokenizer
 
6
  from diffusers import AutoencoderKL, DiffusionPipeline
7
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
8
  from diffusers.utils import (
 
12
  logging,
13
  replace_example_docstring,
14
  )
 
 
 
 
 
 
15
  from diffusers.configuration_utils import FrozenDict
 
16
  from diffusers.schedulers import DDIMScheduler
17
+ try:
18
+ from diffusers import randn_tensor # old import # type: ignore
19
+ except ImportError:
20
+ from diffusers.utils.torch_utils import randn_tensor # new import # type: ignore
21
 
22
+ from models import MultiViewUNetWrapperModel
23
 
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
 
26
  def create_camera_to_world_matrix(elevation, azimuth):
27
  elevation = np.radians(elevation)
 
52
  def convert_opengl_to_blender(camera_matrix):
53
  if isinstance(camera_matrix, np.ndarray):
54
  # Construct transformation matrix to convert from OpenGL space to Blender space
55
+ flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
 
56
  camera_matrix_blender = np.dot(flip_yz, camera_matrix)
57
  else:
58
  # Construct transformation matrix to convert from OpenGL space to Blender space
59
+ flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
 
60
  if camera_matrix.ndim == 3:
61
  flip_yz = flip_yz.unsqueeze(0)
62
+ camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
 
63
  return camera_matrix_blender
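As a quick illustration of the conversion implemented in `convert_opengl_to_blender` above (a sketch, not part of the commit): left-multiplying by `flip_yz` maps OpenGL coordinates (x, y, z) to Blender coordinates (x, -z, y), so a camera sitting at distance d on the OpenGL +z axis ends up at (0, -d, 0) in Blender space.

```python
import numpy as np

d = 2.0
c2w_opengl = np.eye(4)
c2w_opengl[:3, 3] = [0.0, 0.0, d]   # camera positioned on the OpenGL +z axis

flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
c2w_blender = flip_yz @ c2w_opengl  # same product as convert_opengl_to_blender

assert np.allclose(c2w_blender[:3, 3], [0.0, -d, 0.0])
```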
64
 
65
 
66
+ def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True):
 
 
 
 
67
  angle_gap = azimuth_span / num_frames
68
  cameras = []
69
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
 
70
  camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
71
  if blender_coord:
72
  camera_matrix = convert_opengl_to_blender(camera_matrix)
 
86
  ):
87
  super().__init__()
88
 
89
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
90
+ deprecation_message = (f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
91
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
92
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
93
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
94
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
95
+ " file")
96
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
 
 
 
 
 
97
  new_config = dict(scheduler.config)
98
  new_config["steps_offset"] = 1
99
  scheduler._internal_dict = FrozenDict(new_config)
100
 
101
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
102
+ deprecation_message = (f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
103
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
104
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
105
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
106
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")
107
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
 
 
 
 
 
 
108
  new_config = dict(scheduler.config)
109
  new_config["clip_sample"] = False
110
  scheduler._internal_dict = FrozenDict(new_config)
 
116
  tokenizer=tokenizer,
117
  text_encoder=text_encoder,
118
  )
119
+ self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
 
120
  self.register_to_config(requires_safety_checker=False)
121
 
122
  def enable_vae_slicing(self):
 
162
  if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
163
  from accelerate import cpu_offload
164
  else:
165
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
 
 
166
 
167
  device = torch.device(f"cuda:{gpu_id}")
168
 
169
  if self.device.type != "cpu":
170
  self.to("cpu", silence_dtype_warnings=True)
171
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
 
172
 
173
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
174
  cpu_offload(cpu_offloaded_model, device)
 
180
  method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
181
  `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
182
  """
183
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
 
184
  from accelerate import cpu_offload_with_hook
185
  else:
186
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
 
 
187
 
188
  device = torch.device(f"cuda:{gpu_id}")
189
 
190
  if self.device.type != "cpu":
191
  self.to("cpu", silence_dtype_warnings=True)
192
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
 
193
 
194
  hook = None
195
  for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
196
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
 
 
197
 
198
  # We'll offload the last model manually.
199
  self.final_offload_hook = hook
 
208
  if not hasattr(self.unet, "_hf_hook"):
209
  return self.device
210
  for module in self.unet.modules():
211
+ if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None):
 
 
212
  return torch.device(module._hf_hook.execution_device)
213
  return self.device
214
 
 
219
  num_images_per_prompt,
220
  do_classifier_free_guidance: bool,
221
  negative_prompt=None,
 
 
222
  ):
223
  r"""
224
  Encodes the prompt into text encoder hidden states.
 
249
  elif prompt is not None and isinstance(prompt, list):
250
  batch_size = len(prompt)
251
  else:
252
+ raise ValueError(f"`prompt` should be either a string or a list of strings, but got {type(prompt)}.")
253
+
254
+ text_inputs = self.tokenizer(
255
+ prompt,
256
+ padding="max_length",
257
+ max_length=self.tokenizer.model_max_length,
258
+ truncation=True,
259
+ return_tensors="pt",
260
+ )
261
+ text_input_ids = text_inputs.input_ids
262
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
263
 
264
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
266
+ logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
267
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
270
+ attention_mask = text_inputs.attention_mask.to(device)
271
+ else:
272
+ attention_mask = None
 
273
 
274
+ prompt_embeds = self.text_encoder(
275
+ text_input_ids.to(device),
276
+ attention_mask=attention_mask,
277
+ )
278
+ prompt_embeds = prompt_embeds[0]
279
 
280
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
281
 
282
  bs_embed, seq_len, _ = prompt_embeds.shape
283
  # duplicate text embeddings for each generation per prompt, using mps friendly method
284
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
285
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
286
 
287
  # get unconditional embeddings for classifier free guidance
288
+ if do_classifier_free_guidance:
289
  uncond_tokens: List[str]
290
  if negative_prompt is None:
291
  uncond_tokens = [""] * batch_size
292
  elif type(prompt) is not type(negative_prompt):
293
+ raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
294
+ f" {type(prompt)}.")
 
295
  elif isinstance(negative_prompt, str):
296
  uncond_tokens = [negative_prompt]
297
  elif batch_size != len(negative_prompt):
298
+ raise ValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
299
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
300
+ " the batch size of `prompt`.")
 
301
  else:
302
  uncond_tokens = negative_prompt
303
 
 
310
  return_tensors="pt",
311
  )
312
 
313
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
 
314
  attention_mask = uncond_input.attention_mask.to(device)
315
  else:
316
  attention_mask = None
 
321
  )
322
  negative_prompt_embeds = negative_prompt_embeds[0]
323
 
 
324
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
325
  seq_len = negative_prompt_embeds.shape[1]
326
 
327
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
 
328
 
329
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
330
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
331
 
332
  # For classifier free guidance, we need to do two forward passes.
333
  # Here we concatenate the unconditional and text embeddings into a single batch
 
350
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
  # and should be between [0, 1]
352
 
353
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
354
  extra_step_kwargs = {}
355
  if accepts_eta:
356
  extra_step_kwargs["eta"] = eta
357
 
358
  # check if the scheduler accepts generator
359
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
360
  if accepts_generator:
361
  extra_step_kwargs["generator"] = generator
362
  return extra_step_kwargs
363
 
364
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
365
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
 
 
 
 
 
 
 
 
 
366
  if isinstance(generator, list) and len(generator) != batch_size:
367
+ raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
368
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
 
 
369
 
370
  if latents is None:
371
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
 
372
  else:
373
  latents = latents.to(device)
374
 
 
377
  return latents
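With the defaults used by `__call__` below (height of 256 and, presumably, the same width) and an SD 1.5-style VAE whose config lists four `block_out_channels`, the `self.vae_scale_factor` computed in `__init__` is 2**3 = 8, so the latents prepared here come out at 32 x 32, matching the `[2*4, 4, 32, 32]` shape noted in the debug comments this commit removes. A quick check; the VAE channel list is an assumption about the converted checkpoint:

```python
block_out_channels = (128, 256, 512, 512)               # SD 1.5 VAE config (assumed)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # -> 8, as in __init__ above

height = width = 256                                    # __call__ defaults
latent_hw = (height // vae_scale_factor, width // vae_scale_factor)
assert vae_scale_factor == 8 and latent_hw == (32, 32)
```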
378
 
379
  @torch.no_grad()
 
380
  def __call__(
381
  self,
382
  height: int = 256,
 
387
  negative_prompt: str = "bad quality",
388
  num_images_per_prompt: int = 1,
389
  eta: float = 0.0,
390
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 
391
  output_type: Optional[str] = "pil",
392
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
 
393
  callback_steps: int = 1,
 
 
394
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  batch_size = 4
396
  device = torch.device("cuda:0")
397
 
 
402
  # corresponds to doing no classifier free guidance.
403
  do_classifier_free_guidance = guidance_scale > 1.0
404
 
405
+ # Prepare timesteps
406
  self.scheduler.set_timesteps(num_inference_steps, device=device)
407
  timesteps = self.scheduler.timesteps
408
 
 
412
  num_images_per_prompt=num_images_per_prompt,
413
  do_classifier_free_guidance=do_classifier_free_guidance,
414
  negative_prompt=negative_prompt,
415
+ ) # type: ignore
416
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
+ # Prepare latent variables
419
  latents: torch.Tensor = self.prepare_latents(
420
  batch_size * num_images_per_prompt,
421
  4,
 
427
  None,
428
  )
429
 
430
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
431
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
432
 
433
+ # Denoising loop
434
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
435
  with self.progress_bar(total=num_inference_steps) as progress_bar:
436
  for i, t in enumerate(timesteps):
437
  # expand the latents if we are doing classifier free guidance
438
  multiplier = 2 if do_classifier_free_guidance else 1
439
  latent_model_input = torch.cat([latents] * multiplier)
440
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
441
 
442
  # predict the noise residual
 
 
 
 
 
 
443
  noise_pred = self.unet.forward(
444
  x=latent_model_input,
445
+ timesteps=torch.tensor([t] * 4 * multiplier, device=device),
446
+ context=torch.cat([prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4),
 
 
447
  num_frames=4,
448
  camera=torch.cat([camera] * multiplier),
449
  )
 
451
  # perform guidance
452
  if do_classifier_free_guidance:
453
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
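  # classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond);
  # guidance_scale = 1 recovers the conditional prediction, larger values extrapolate
  # further along the text-conditioned direction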
454
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
455
 
456
  # compute the previous noisy sample x_t -> x_t-1
457
  # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
458
+ latents: torch.Tensor = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
 
 
 
 
459
 
460
  # call the callback, if provided
461
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
 
462
  progress_bar.update()
463
  if callback is not None and i % callback_steps == 0:
464
+ callback(i, t, latents) # type: ignore
465
 
466
+ # Post-processing
467
  if output_type == "latent":
468
  image = latents
469
  elif output_type == "pil":
 
470
  image = self.decode_latents(latents)
 
471
  image = self.numpy_to_pil(image)
472
  else:
 
473
  image = self.decode_latents(latents)
474
 
475
  # Offload last model to CPU
476
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
 
 
477
  self.final_offload_hook.offload()
478
 
479
+ return image
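A hedged end-to-end sketch of driving this pipeline. It assumes the converted weights produced by the command in the README sit in the current directory, that the modules under `scripts/` are importable, and that the arguments elided from the `__call__` signature above include `prompt`, `width`, `num_inference_steps`, and `guidance_scale`; none of this is guaranteed by the commit itself.

```python
import torch
from pipeline_mvdream import MVDreamStableDiffusionPipeline  # scripts/ on PYTHONPATH (assumption)

# Load the converted diffusers-format weights from the conversion step in the README.
pipe = MVDreamStableDiffusionPipeline.from_pretrained(".", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # __call__ hard-codes device cuda:0 and a batch of 4 latents (one per view)

images = pipe(
    prompt="a DSLR photo of a corgi wearing a top hat",  # hypothetical prompt
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_type="pil",
)

# With output_type="pil" the pipeline returns one image per generated view.
for i, img in enumerate(images):
    img.save(f"view_{i}.png")
```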
 
 
 
 
vae/diffusion_pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3f1b909aa85cc520a2986d6fc379478e0c46c41f853f9a7c73c0150b2c9c9b8b
3
  size 334716034
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:660d2d3c357697e87aded9b7d821dd726977291c049be64489132cd442ce6477
3
  size 334716034