frankleeeee committed
Commit 6ac94e3
Parent: 5c162ac

Upload STDiT2

config.json CHANGED
@@ -10,7 +10,7 @@
   "class_dropout_prob": 0.1,
   "depth": 28,
   "drop_path": 0.0,
-  "enable_flashattn": false,
+  "enable_flash_attn": false,
   "enable_layernorm_kernel": false,
   "enable_sequence_parallelism": false,
   "freeze": null,
configuration_stdit2.py CHANGED
@@ -24,7 +24,7 @@ class STDiT2Config(PretrainedConfig):
     model_max_length=120,
     freeze=None,
     qk_norm=False,
-    enable_flashattn=False,
+    enable_flash_attn=False,
     enable_layernorm_kernel=False,
     enable_sequence_parallelism=False,
     **kwargs,
@@ -45,7 +45,7 @@ class STDiT2Config(PretrainedConfig):
     self.model_max_length = model_max_length
     self.freeze = freeze
     self.qk_norm = qk_norm
-    self.enable_flashattn = enable_flashattn
+    self.enable_flash_attn = enable_flash_attn
     self.enable_layernorm_kernel = enable_layernorm_kernel
     self.enable_sequence_parallelism = enable_sequence_parallelism
     super().__init__(**kwargs)
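
A small sketch of the renamed constructor keyword, assuming configuration_stdit2.py from this repo is on the import path and the other arguments keep the defaults shown above:

from configuration_stdit2 import STDiT2Config  # local module from this repo

# The keyword is now enable_flash_attn; the old spelling would fall into **kwargs
# instead of setting the attribute the model reads.
config = STDiT2Config(enable_flash_attn=True, qk_norm=True)
print(config.enable_flash_attn)  # -> True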
layers.py CHANGED
@@ -30,7 +30,7 @@ class STDiT2Block(nn.Module):
     num_heads,
     mlp_ratio=4.0,
     drop_path=0.0,
-    enable_flashattn=False,
+    enable_flash_attn=False,
     enable_layernorm_kernel=False,
     enable_sequence_parallelism=False,
     rope=None,
@@ -38,7 +38,7 @@ class STDiT2Block(nn.Module):
 ):
     super().__init__()
     self.hidden_size = hidden_size
-    self.enable_flashattn = enable_flashattn
+    self.enable_flash_attn = enable_flash_attn
     self._enable_sequence_parallelism = enable_sequence_parallelism

     assert not self._enable_sequence_parallelism, "Sequence parallelism is not supported."
@@ -55,7 +55,7 @@ class STDiT2Block(nn.Module):
     hidden_size,
     num_heads=num_heads,
     qkv_bias=True,
-    enable_flashattn=enable_flashattn,
+    enable_flash_attn=enable_flash_attn,
     qk_norm=qk_norm,
 )
 self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)
@@ -76,7 +76,7 @@ class STDiT2Block(nn.Module):
     hidden_size,
     num_heads=num_heads,
     qkv_bias=True,
-    enable_flashattn=self.enable_flashattn,
+    enable_flash_attn=self.enable_flash_attn,
     rope=rope,
     qk_norm=qk_norm,
 )
@@ -196,7 +196,7 @@ class Attention(nn.Module):
     attn_drop: float = 0.0,
     proj_drop: float = 0.0,
     norm_layer: nn.Module = LlamaRMSNorm,
-    enable_flashattn: bool = False,
+    enable_flash_attn: bool = False,
     rope=None,
 ) -> None:
     super().__init__()
@@ -205,7 +205,7 @@ class Attention(nn.Module):
     self.num_heads = num_heads
     self.head_dim = dim // num_heads
     self.scale = self.head_dim**-0.5
-    self.enable_flashattn = enable_flashattn
+    self.enable_flash_attn = enable_flash_attn

     self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
     self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -222,7 +222,7 @@ class Attention(nn.Module):
 def forward(self, x: torch.Tensor) -> torch.Tensor:
     B, N, C = x.shape
     # flash attn is not memory efficient for small sequences, this is empirical
-    enable_flashattn = self.enable_flashattn and (N > B)
+    enable_flash_attn = self.enable_flash_attn and (N > B)
     qkv = self.qkv(x)
     qkv_shape = (B, N, 3, self.num_heads, self.head_dim)

@@ -233,7 +233,7 @@ class Attention(nn.Module):
     k = self.rotary_emb(k)
     q, k = self.q_norm(q), self.k_norm(k)

-    if enable_flashattn:
+    if enable_flash_attn:
         from flash_attn import flash_attn_func

         # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
@@ -258,7 +258,7 @@ class Attention(nn.Module):
     x = attn @ v

     x_output_shape = (B, N, C)
-    if not enable_flashattn:
+    if not enable_flash_attn:
         x = x.transpose(1, 2)
     x = x.reshape(x_output_shape)
     x = self.proj(x)
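
The rename runs through Attention.forward, where the flag is combined with the empirical N > B cutoff before choosing between flash-attn and the manual softmax path. Below is an illustrative, self-contained sketch of that dispatch, not the module itself; it assumes the optional flash-attn package and fp16/bf16 CUDA tensors when the fast path is taken:

import torch

def attend(q, k, v, enable_flash_attn: bool) -> torch.Tensor:
    # q, k, v: (B, num_heads, N, head_dim), as produced after the qkv split above.
    B, H, N, D = q.shape
    use_flash = enable_flash_attn and (N > B)  # empirical cutoff from the diff
    if use_flash:
        from flash_attn import flash_attn_func  # needs fp16/bf16 CUDA tensors

        # flash_attn_func expects (B, N, num_heads, head_dim)
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    else:
        attn = (q @ k.transpose(-2, -1)) * D**-0.5
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2)  # back to (B, N, H, D)
    return out.reshape(B, N, H * D)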
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e777ae49713478957c48f97eda4405e392e3fd12580e01be944465b741c6521c
+oid sha256:033319fd49c2ff9bc57836e2fcaacea5ac64f6efa357f603f91f06a9164d0e1c
 size 3071846872
modeling_stdit2.py CHANGED
@@ -38,7 +38,7 @@ class STDiT2(PreTrainedModel):
     self.no_temporal_pos_emb = config.no_temporal_pos_emb
     self.depth = config.depth
     self.mlp_ratio = config.mlp_ratio
-    self.enable_flashattn = config.enable_flashattn
+    self.enable_flash_attn = config.enable_flash_attn
     self.enable_layernorm_kernel = config.enable_layernorm_kernel
     self.enable_sequence_parallelism = config.enable_sequence_parallelism

@@ -69,7 +69,7 @@ class STDiT2(PreTrainedModel):
     self.num_heads,
     mlp_ratio=self.mlp_ratio,
     drop_path=drop_path[i],
-    enable_flashattn=self.enable_flashattn,
+    enable_flash_attn=self.enable_flash_attn,
     enable_layernorm_kernel=self.enable_layernorm_kernel,
     enable_sequence_parallelism=self.enable_sequence_parallelism,
     rope=self.rope.rotate_queries_or_keys,
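
Putting it together, a hedged loading sketch: it assumes the repo wires STDiT2Config and STDiT2 into auto_map so that trust_remote_code resolves them, and uses a placeholder repo path. Flipping the flag on the config before the model is built means every STDiT2Block receives enable_flash_attn=True:

from transformers import AutoConfig, AutoModel

repo = "path/to/STDiT2-checkpoint"  # placeholder, not a real repo id
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
config.enable_flash_attn = True  # renamed flag; defaults to False in config.json

model = AutoModel.from_pretrained(repo, config=config, trust_remote_code=True)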