jiuntian committed
Commit f6b575c
1 Parent(s): 032ddf2
unet/config.json CHANGED
@@ -1,5 +1,5 @@
  {
-   "_class_name": "UNet2DConditionModel",
+   "_class_name": "InteractDiffusionUNet2DConditionModel",
    "_diffusers_version": "0.27.0.dev0",
    "act_fn": "silu",
    "addition_embed_type": null,
unet/interactdiffusion_unet_2d_condition.py ADDED
@@ -0,0 +1,230 @@
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from diffusers.models import UNet2DConditionModel
+ from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox
+
+
+ class AbsolutePositionalEmbedding(nn.Module):
+     """Learned absolute positional embedding over a fixed maximum sequence length."""
+
+     def __init__(self, dim, max_seq_len):
+         super().__init__()
+         self.emb = nn.Embedding(max_seq_len, dim)
+         self.init_()
+
+     def init_(self):
+         nn.init.normal_(self.emb.weight, std=0.02)
+
+     def forward(self, x):
+         # x: (B, N, C) -> per-position embeddings of shape (1, N, dim)
+         n = torch.arange(x.shape[1], device=x.device)
+         return self.emb(n)[None, :, :]
+
+
+ class InteractDiffusionInteractionProjection(nn.Module):
+     def __init__(self, in_dim, out_dim, fourier_freqs=8):
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         # Normalize a tuple out_dim (e.g. per-block cross-attention dims) before
+         # it is used to size the embeddings and projection heads below.
+         if isinstance(out_dim, tuple):
+             out_dim = out_dim[0]
+
+         self.fourier_embedder_dim = fourier_freqs
+         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
+         self.interaction_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=30)
+         self.position_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=3)
+
+         self.linears = nn.Sequential(
+             nn.Linear(self.in_dim + self.position_dim, 512),
+             nn.SiLU(),
+             nn.Linear(512, 512),
+             nn.SiLU(),
+             nn.Linear(512, out_dim),
+         )
+
+         self.linear_action = nn.Sequential(
+             nn.Linear(self.in_dim + self.position_dim, 512),
+             nn.SiLU(),
+             nn.Linear(512, 512),
+             nn.SiLU(),
+             nn.Linear(512, out_dim),
+         )
+
+         self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+         self.null_action_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
+         self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+     def get_between_box(self, bbox1, bbox2):
+         """Between-set operation: the box spanned between box A and box B
+         (based on Prof. Jiang's idea).
+         """
+         # Sort all x and all y coordinates of both boxes; the two middle
+         # values on each axis bound the region between them.
+         all_x = torch.cat([bbox1[:, :, 0::2], bbox2[:, :, 0::2]], dim=-1)
+         all_y = torch.cat([bbox1[:, :, 1::2], bbox2[:, :, 1::2]], dim=-1)
+         all_x, _ = all_x.sort()
+         all_y, _ = all_y.sort()
+         return torch.stack([all_x[:, :, 1], all_y[:, :, 1], all_x[:, :, 2], all_y[:, :, 2]], 2)
+
+     def forward(
+         self,
+         subject_boxes, object_boxes,
+         masks,
+         subject_positive_embeddings, object_positive_embeddings, action_positive_embeddings,
+     ):
+         masks = masks.unsqueeze(-1)
+
+         # embed positions (entries may be padding placeholders)
+         action_boxes = self.get_between_box(subject_boxes, object_boxes)
+         subject_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, subject_boxes)  # B*N*4 --> B*N*C
+         object_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, object_boxes)  # B*N*4 --> B*N*C
+         action_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, action_boxes)  # B*N*4 --> B*N*C
+
+         # learnable null embeddings
+         positive_null = self.null_positive_feature.view(1, 1, -1)
+         xyxy_null = self.null_position_feature.view(1, 1, -1)
+         action_null = self.null_action_feature.view(1, 1, -1)
+
+         # replace padding with the learnable null embeddings
+         subject_positive_embeddings = subject_positive_embeddings * masks + (1 - masks) * positive_null
+         object_positive_embeddings = object_positive_embeddings * masks + (1 - masks) * positive_null
+
+         subject_xyxy_embedding = subject_xyxy_embedding * masks + (1 - masks) * xyxy_null
+         object_xyxy_embedding = object_xyxy_embedding * masks + (1 - masks) * xyxy_null
+         action_xyxy_embedding = action_xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+         action_positive_embeddings = action_positive_embeddings * masks + (1 - masks) * action_null
+
+         # project the input embeddings
+         objs_subject = self.linears(torch.cat([subject_positive_embeddings, subject_xyxy_embedding], dim=-1))
+         objs_object = self.linears(torch.cat([object_positive_embeddings, object_xyxy_embedding], dim=-1))
+         objs_action = self.linear_action(torch.cat([action_positive_embeddings, action_xyxy_embedding], dim=-1))
+
+         # impose the interaction embedding
+         objs_subject = objs_subject + self.interaction_embedding(objs_subject)
+         objs_object = objs_object + self.interaction_embedding(objs_object)
+         objs_action = objs_action + self.interaction_embedding(objs_action)
+
+         # impose the instance (subject/action/object) embedding
+         objs_subject = objs_subject + self.position_embedding.emb(torch.tensor(0, device=objs_subject.device))
+         objs_object = objs_object + self.position_embedding.emb(torch.tensor(1, device=objs_object.device))
+         objs_action = objs_action + self.position_embedding.emb(torch.tensor(2, device=objs_action.device))
+
+         objs = torch.cat([objs_subject, objs_action, objs_object], dim=1)
+
+         return objs
+
+
+ class InteractDiffusionUNet2DConditionModel(UNet2DConditionModel):
+     def __init__(
+         self,
+         sample_size: Optional[int] = None,
+         in_channels: int = 4,
+         out_channels: int = 4,
+         center_input_sample: bool = False,
+         flip_sin_to_cos: bool = True,
+         freq_shift: int = 0,
+         down_block_types: Tuple[str] = (
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "DownBlock2D",
+         ),
+         mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+         only_cross_attention: Union[bool, Tuple[bool]] = False,
+         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+         layers_per_block: Union[int, Tuple[int]] = 2,
+         downsample_padding: int = 1,
+         mid_block_scale_factor: float = 1,
+         dropout: float = 0.0,
+         act_fn: str = "silu",
+         norm_num_groups: Optional[int] = 32,
+         norm_eps: float = 1e-5,
+         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+         transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+         reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+         encoder_hid_dim: Optional[int] = None,
+         encoder_hid_dim_type: Optional[str] = None,
+         attention_head_dim: Union[int, Tuple[int]] = 8,
+         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+         dual_cross_attention: bool = False,
+         use_linear_projection: bool = False,
+         class_embed_type: Optional[str] = None,
+         addition_embed_type: Optional[str] = None,
+         addition_time_embed_dim: Optional[int] = None,
+         num_class_embeds: Optional[int] = None,
+         upcast_attention: bool = False,
+         resnet_time_scale_shift: str = "default",
+         resnet_skip_time_act: bool = False,
+         resnet_out_scale_factor: float = 1.0,
+         time_embedding_type: str = "positional",
+         time_embedding_dim: Optional[int] = None,
+         time_embedding_act_fn: Optional[str] = None,
+         timestep_post_act: Optional[str] = None,
+         time_cond_proj_dim: Optional[int] = None,
+         conv_in_kernel: int = 3,
+         conv_out_kernel: int = 3,
+         projection_class_embeddings_input_dim: Optional[int] = None,
+         attention_type: str = "default",
+         class_embeddings_concat: bool = False,
+         mid_block_only_cross_attention: Optional[bool] = None,
+         cross_attention_norm: Optional[str] = None,
+         addition_embed_type_num_heads: int = 64,
+     ):
+         super().__init__(
+             sample_size=sample_size,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             center_input_sample=center_input_sample,
+             flip_sin_to_cos=flip_sin_to_cos,
+             freq_shift=freq_shift,
+             down_block_types=down_block_types,
+             mid_block_type=mid_block_type,
+             up_block_types=up_block_types,
+             only_cross_attention=only_cross_attention,
+             block_out_channels=block_out_channels,
+             layers_per_block=layers_per_block,
+             downsample_padding=downsample_padding,
+             mid_block_scale_factor=mid_block_scale_factor,
+             dropout=dropout,
+             act_fn=act_fn,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             cross_attention_dim=cross_attention_dim,
+             transformer_layers_per_block=transformer_layers_per_block,
+             reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+             encoder_hid_dim=encoder_hid_dim,
+             encoder_hid_dim_type=encoder_hid_dim_type,
+             attention_head_dim=attention_head_dim,
+             num_attention_heads=num_attention_heads,
+             dual_cross_attention=dual_cross_attention,
+             use_linear_projection=use_linear_projection,
+             class_embed_type=class_embed_type,
+             addition_embed_type=addition_embed_type,
+             addition_time_embed_dim=addition_time_embed_dim,
+             num_class_embeds=num_class_embeds,
+             upcast_attention=upcast_attention,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             resnet_skip_time_act=resnet_skip_time_act,
+             resnet_out_scale_factor=resnet_out_scale_factor,
+             time_embedding_type=time_embedding_type,
+             time_embedding_dim=time_embedding_dim,
+             time_embedding_act_fn=time_embedding_act_fn,
+             timestep_post_act=timestep_post_act,
+             time_cond_proj_dim=time_cond_proj_dim,
+             conv_in_kernel=conv_in_kernel,
+             conv_out_kernel=conv_out_kernel,
+             projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+             attention_type=attention_type,
+             class_embeddings_concat=class_embeddings_concat,
+             mid_block_only_cross_attention=mid_block_only_cross_attention,
+             cross_attention_norm=cross_attention_norm,
+             addition_embed_type_num_heads=addition_embed_type_num_heads,
+         )
+
+         # build the interaction projection that plays the role of position_net
+         positive_len = 768
+         if isinstance(self.config.cross_attention_dim, int):
+             positive_len = self.config.cross_attention_dim
+         elif isinstance(self.config.cross_attention_dim, (tuple, list)):
+             positive_len = self.config.cross_attention_dim[0]
+
+         self.position_net = InteractDiffusionInteractionProjection(
+             in_dim=positive_len, out_dim=self.config.cross_attention_dim
+         )
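
As a quick sanity check on the module above, here is a minimal sketch (not part of the commit) that drives InteractDiffusionInteractionProjection with dummy tensors; it assumes this repo is cloned locally and Python runs from its root, so the import path is illustrative:

import torch
from unet.interactdiffusion_unet_2d_condition import InteractDiffusionInteractionProjection

# Boxes are B x N x 4 in normalized xyxy; masks flag real entries vs. padding;
# phrase embeddings are B x N x in_dim. Subject, action and object tokens are
# concatenated along the sequence axis, so the output is B x 3N x out_dim.
proj = InteractDiffusionInteractionProjection(in_dim=768, out_dim=768)
B, N = 2, 5
boxes = torch.rand(B, N, 4)
masks = torch.ones(B, N)
embeds = torch.randn(B, N, 768)

out = proj(boxes, boxes, masks, embeds, embeds, embeds)
print(out.shape)  # torch.Size([2, 15, 768])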
vae/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.27.0.dev0",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 512,
+   "scaling_factor": 0.18215,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
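
This is the standard Stable Diffusion v1 KL autoencoder configuration (4 latent channels, 8x spatial downsampling, scaling factor 0.18215). A minimal round-trip sketch, assuming a local clone of this repo (the "." path is illustrative):

import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained(".", subfolder="vae")

image = torch.randn(1, 3, 512, 512)  # in_channels=3, sample_size=512
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    recon = vae.decode(latents / vae.config.scaling_factor).sample
print(latents.shape)  # torch.Size([1, 4, 64, 64]): latent_channels=4, 8x downsample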
vae/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fbcf0ebe55a0984f5a5e00d8c4521d52359af7229bb4d81890039d2aa16dd7c
+ size 167335342
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4d2b5932bb4151e54e694fd31ccf51fca908223c9485bd56cd0e1d83ad94c49
+ size 334643268