frankleeeee
commited on
Upload STDiT
Browse files- modeling_stdit.py +2 -1
modeling_stdit.py
CHANGED
@@ -148,7 +148,8 @@ class STDiT(PreTrainedModel):
|
|
148 |
tpe = self.pos_embed_temporal
|
149 |
else:
|
150 |
tpe = None
|
151 |
-
x =
|
|
|
152 |
|
153 |
if self.enable_sequence_parallelism:
|
154 |
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|
|
|
148 |
tpe = self.pos_embed_temporal
|
149 |
else:
|
150 |
tpe = None
|
151 |
+
x = block(x, y, t0, y_lens, tpe)
|
152 |
+
# x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
|
153 |
|
154 |
if self.enable_sequence_parallelism:
|
155 |
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|