File size: 10,776 Bytes
f6b575c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
from typing import Optional, Tuple, Union
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox
import torch
import torch.nn as nn
class AbsolutePositionalEmbedding(nn.Module):
    """Learned embedding table looked up by position along dimension 1.

    The table has ``max_seq_len`` rows of width ``dim`` and is initialised
    from a normal distribution with standard deviation 0.02.
    """

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=max_seq_len, embedding_dim=dim)
        self.init_()

    def init_(self):
        """(Re-)initialise the table weights in place with N(0, 0.02^2)."""
        nn.init.normal_(self.emb.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Return the embeddings for positions ``0 .. x.shape[1] - 1``.

        Only the size of dimension 1 and the device of ``x`` are used; its
        values are ignored. The result has a leading singleton dimension,
        i.e. shape ``(1, x.shape[1], dim)``, so it can be broadcast-added to
        a batched tensor.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        table_rows = self.emb(positions)
        return table_rows.unsqueeze(0)
class InteractDiffusionInteractionProjection(nn.Module):
    """Turn (subject, action, object) triplets into conditioning tokens.

    Each triplet member is described by a feature vector of width ``in_dim``
    and a bounding box. The box is Fourier-embedded, concatenated with the
    feature, and passed through an MLP. Subject and object share one MLP
    (``linears``); the action has its own (``linear_action``) and its box is
    derived from the subject and object boxes via :meth:`get_between_box`.

    Args:
        in_dim: Width of the incoming feature vectors.
        out_dim: Width of the produced tokens. May be an int, or a
            tuple/list whose first entry is used.
        fourier_freqs: Number of Fourier frequencies passed to the box
            embedder; the box embedding width is ``fourier_freqs * 2 * 4``.
    """

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim  # stored exactly as passed by the caller

        # Collapse a tuple/list to a single int BEFORE any submodule is
        # built: the embedding tables below need a scalar width, and doing
        # this afterwards made construction fail for non-int ``out_dim``.
        out_dim = self._first_dim(out_dim)

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        self.interaction_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=30)
        self.position_embedding = AbsolutePositionalEmbedding(dim=out_dim, max_seq_len=3)

        # MLP shared by subject and object tokens.
        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )
        # Separate MLP for action tokens.
        self.linear_action = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        # Learnable stand-ins used wherever the mask marks an entry as padding.
        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
        self.null_action_feature = torch.nn.Parameter(torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    @staticmethod
    def _first_dim(dim):
        """Return ``dim[0]`` for a tuple/list, otherwise ``dim`` unchanged."""
        if isinstance(dim, (tuple, list)):
            return dim[0]
        return dim

    @staticmethod
    def _blend_with_null(values, masks, null):
        """Keep ``values`` where ``masks`` is 1 and use ``null`` where it is 0."""
        return values * masks + (1 - masks) * null

    def get_between_box(self, bbox1, bbox2):
        """ Between Set Operation
        Operation of Box A between Box B from Prof. Jiang idea

        The four x coordinates of both boxes (entries 0 and 2 of the last
        dimension) are sorted, and likewise the four y coordinates (entries
        1 and 3). The two middle values of each sorted set form the result
        ``[x_mid_lo, y_mid_lo, x_mid_hi, y_mid_hi]``.

        Both inputs are indexed as 3-D tensors with 4 box coordinates in the
        last dimension; the output has the same leading two dimensions and
        a last dimension of 4.
        """
        xs = torch.cat([bbox1[:, :, 0::2], bbox2[:, :, 0::2]], dim=-1)
        ys = torch.cat([bbox1[:, :, 1::2], bbox2[:, :, 1::2]], dim=-1)
        xs, _ = xs.sort()  # ascending along the last dimension
        ys, _ = ys.sort()
        return torch.stack([xs[:, :, 1], ys[:, :, 1], xs[:, :, 2], ys[:, :, 2]], 2)

    def forward(
        self,
        subject_boxes, object_boxes,
        masks,
        subject_positive_embeddings, object_positive_embeddings, action_positive_embeddings
    ):
        """Build the conditioning tokens for a batch of interactions.

        Args:
            subject_boxes, object_boxes: Box tensors, indexed as 3-D with 4
                coordinates last (presumably ``(B, N, 4)`` xyxy - confirm
                with the caller).
            masks: Multiplied against the features after a trailing
                dimension is added; 1 keeps an entry, 0 swaps in the learned
                null feature. Presumably ``(B, N)`` - confirm with caller.
            subject_positive_embeddings, object_positive_embeddings,
            action_positive_embeddings: Feature tensors whose last dimension
                must be ``in_dim``.

        Returns:
            Subject, action and object tokens concatenated, in that order,
            along dimension 1.
        """
        masks = masks.unsqueeze(-1)

        # embedding position (it may include padding as placeholder)
        action_boxes = self.get_between_box(subject_boxes, object_boxes)
        subject_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, subject_boxes)  # B*N*4 --> B*N*C
        object_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, object_boxes)  # B*N*4 --> B*N*C
        action_xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, action_boxes)  # B*N*4 --> B*N*C

        # learnable null embedding, reshaped so it broadcasts over the first two dims
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)
        action_null = self.null_action_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        subject_positive_embeddings = self._blend_with_null(subject_positive_embeddings, masks, positive_null)
        object_positive_embeddings = self._blend_with_null(object_positive_embeddings, masks, positive_null)
        subject_xyxy_embedding = self._blend_with_null(subject_xyxy_embedding, masks, xyxy_null)
        object_xyxy_embedding = self._blend_with_null(object_xyxy_embedding, masks, xyxy_null)
        action_xyxy_embedding = self._blend_with_null(action_xyxy_embedding, masks, xyxy_null)
        action_positive_embeddings = self._blend_with_null(action_positive_embeddings, masks, action_null)

        # project the input embeddings
        objs_subject = self.linears(torch.cat([subject_positive_embeddings, subject_xyxy_embedding], dim=-1))
        objs_object = self.linears(torch.cat([object_positive_embeddings, object_xyxy_embedding], dim=-1))
        objs_action = self.linear_action(torch.cat([action_positive_embeddings, action_xyxy_embedding], dim=-1))

        # impose instance embedding: row i of the table is added to the i-th
        # entry along dim 1, identically for subject, object and action
        objs_subject = objs_subject + self.interaction_embedding(objs_subject)
        objs_object = objs_object + self.interaction_embedding(objs_object)
        objs_action = objs_action + self.interaction_embedding(objs_action)

        # impose role embedding: rows 0 / 1 / 2 of the 3-row table are added
        # to every subject / object / action token respectively. The index
        # tensor is created directly on the target device.
        role_ids = torch.arange(3, device=objs_subject.device)
        role_rows = self.position_embedding.emb(role_ids)
        objs_subject = objs_subject + role_rows[0]
        objs_object = objs_object + role_rows[1]
        objs_action = objs_action + role_rows[2]

        objs = torch.cat([objs_subject, objs_action, objs_object], dim=1)
        return objs
class InteractDiffusionUNet2DConditionModel(UNet2DConditionModel):
    """``UNet2DConditionModel`` with an interaction-aware ``position_net``.

    Every constructor argument is forwarded unchanged, by keyword, to the
    parent constructor. Afterwards ``self.position_net`` is set to an
    ``InteractDiffusionInteractionProjection`` sized from
    ``self.config.cross_attention_dim``.
    """

    def __init__(self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
    ):
        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            down_block_types=down_block_types,
            mid_block_type=mid_block_type,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            downsample_padding=downsample_padding,
            mid_block_scale_factor=mid_block_scale_factor,
            dropout=dropout,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            class_embed_type=class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=num_class_embeds,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            time_embedding_act_fn=time_embedding_act_fn,
            timestep_post_act=timestep_post_act,
            time_cond_proj_dim=time_cond_proj_dim,
            conv_in_kernel=conv_in_kernel,
            conv_out_kernel=conv_out_kernel,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            attention_type=attention_type,
            class_embeddings_concat=class_embeddings_concat,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
        )

        # Build position_net.
        # NOTE(review): self.config is presumably populated by the parent
        # constructor (not visible in this file) - confirm.
        cross_dim = self.config.cross_attention_dim
        positive_len = 768  # fallback when cross_dim is neither int nor tuple/list
        if isinstance(cross_dim, int):
            positive_len = cross_dim
        elif isinstance(cross_dim, (tuple, list)):
            positive_len = cross_dim[0]
            # Hand the projection a scalar width: its embedding tables are
            # built from out_dim directly, so a tuple/list would make
            # construction fail. Other value types are passed through as-is.
            cross_dim = positive_len

        # NOTE(review): this assignment replaces any ``position_net`` the
        # parent constructor may have created - confirm that is intended.
        self.position_net = InteractDiffusionInteractionProjection(
            in_dim=positive_len, out_dim=cross_dim
        )
|