update model code
Browse files- config.json +1 -1
- modeling_plamo.py +152 -59
config.json
CHANGED
@@ -30,7 +30,7 @@
|
|
30 |
"mamba_num_heads": 32,
|
31 |
"mamba_step": 2,
|
32 |
"max_position_embeddings": 10485760,
|
33 |
-
"model_type": "
|
34 |
"n_expert": null,
|
35 |
"num_attention_heads": 16,
|
36 |
"num_hidden_layers": 16,
|
|
|
30 |
"mamba_num_heads": 32,
|
31 |
"mamba_step": 2,
|
32 |
"max_position_embeddings": 10485760,
|
33 |
+
"model_type": "plamo2",
|
34 |
"n_expert": null,
|
35 |
"num_attention_heads": 16,
|
36 |
"num_hidden_layers": 16,
|
modeling_plamo.py
CHANGED
@@ -551,6 +551,68 @@ def _ssd_chunk_scan_combined_naive(
|
|
551 |
return torch.cat(ys, dim=1), ssm_state
|
552 |
|
553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
def ssd_chunk_scan_combined(
|
555 |
x: torch.Tensor,
|
556 |
dt: torch.Tensor,
|
@@ -587,19 +649,19 @@ def ssd_chunk_scan_combined(
|
|
587 |
To avoid updating state, we set dt to -inf and x to 0
|
588 |
because `softplus(-inf) = 0` and `exp(0) = 1`
|
589 |
"""
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
|
599 |
-
|
600 |
-
length = x.shape[1]
|
601 |
-
assert length % chunk_size == 0, (length, chunk_size)
|
602 |
|
|
|
|
|
|
|
|
|
603 |
dtype = _get_trition_dtype(x.dtype)
|
604 |
out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
|
605 |
x.to(dtype),
|
@@ -622,19 +684,75 @@ def ssd_chunk_scan_combined(
|
|
622 |
assert isinstance(out, torch.Tensor)
|
623 |
return out[:, pad:]
|
624 |
else:
|
625 |
-
if ssm_state is None:
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
if return_final_states:
|
633 |
return tmp
|
634 |
else:
|
635 |
return tmp[0]
|
636 |
|
637 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
638 |
def _causal_conv1d(
|
639 |
conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
|
640 |
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
@@ -670,52 +788,27 @@ def _causal_conv1d(
|
|
670 |
else:
|
671 |
x = tmp
|
672 |
else:
|
673 |
-
if
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
)
|
687 |
-
|
688 |
-
x = out
|
689 |
if return_final_states:
|
690 |
return x, conv_state
|
691 |
else:
|
692 |
return x, None
|
693 |
|
694 |
|
695 |
-
def _causal_conv1d_update(
    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run a single causal-conv1d update step with SiLU activation.

    `xBC` and `weight` are cast to the dtype of `conv_state` before the call.
    The kernel is `causal_conv1d.causal_conv1d_update` when `conv_state` lives
    on a CUDA device and `causal_conv1d.causal_conv1d_update_ref` otherwise.

    Returns:
        The kernel output together with the `conv_state` object that was
        passed in (presumably updated in place by the kernel -- confirm
        against the `causal_conv1d` package).
    """
    state_dtype = conv_state.dtype
    if conv_state.is_cuda:
        update_fn = causal_conv1d.causal_conv1d_update
    else:
        update_fn = causal_conv1d.causal_conv1d_update_ref
    # The kernels receive the weight with its middle axis (index 0) dropped.
    out = update_fn(
        x=xBC.to(state_dtype),
        conv_state=conv_state,
        weight=weight.to(state_dtype)[:, 0, :],
        activation="silu",
    )
    return out, conv_state
|
717 |
-
|
718 |
-
|
719 |
class Mamba(torch.nn.Module):
|
720 |
def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
|
721 |
super().__init__()
|
|
|
551 |
return torch.cat(ys, dim=1), ssm_state
|
552 |
|
553 |
|
554 |
+
def _ssd_chunk_scan_combined_cpu(
    x: torch.Tensor,
    dt: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    chunk_size: int,
    D: torch.Tensor,
    z: torch.Tensor,
    dt_bias: torch.Tensor,
    dt_softplus: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Chunked SSD scan for CPU, assembled from the `mamba_ssm` reference ops.

    The per-chunk states are computed with an explicit broadcast multiply and a
    batched matmul instead of `chunk_state_ref`, because the `einsum` in that
    reference function is too slow on CPU. State passing and the final scan are
    delegated to `state_passing_ref` and `chunk_scan_ref`.

    Axis 1 of `x`, `B` and `dt` is split into chunks of `chunk_size`. `B` must
    already be expanded so that `B.shape[2]` equals the head count of `x`.

    Returns:
        `(out, last_state)`: the output of `chunk_scan_ref`, and the final
        state from `state_passing_ref` with its last axis unflattened so the
        trailing dimension equals `B.shape[-1]`.
    """
    # (bsize, nhead, nchunk, chunk_size); float32 keeps the cumsum accurate.
    dt = dt.float()
    dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size))  # type: ignore
    if dt_bias is not None:
        dt = dt + dt_bias[None, :, None, None]
    if dt_softplus:
        dt = F.softplus(dt)
    dA_cumsum = torch.cumsum(dt * A[None, :, None, None], dim=-1)

    _, _, n_heads, _ = x.shape
    d_state = B.shape[-1]

    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
        # Equivalent to
        # `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`,
        # i.e. einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B, decay, dt, x),
        # spelled out as multiply + matmul.
        x_chunks = torch.unflatten(x, 1, (-1, chunk_size))
        assert B.shape[2] == n_heads  # B should be already expanded
        # (bsize, nchunk, chunk_size, nheads, dstate)
        B_chunks = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype)
        decay = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)

        scale = dt.to(x.dtype) * decay  # bhcl
        scale = scale.permute(0, 2, 1, 3)[:, :, :, None]  # bch1l
        scaled_B = B_chunks.permute(0, 1, 3, 4, 2) * scale  # bchnl
        chunk_states = scaled_B @ x_chunks.permute(0, 1, 3, 2, 4)  # bchnl @ bchlp -> bchnp
        chunk_states = chunk_states.permute(0, 1, 2, 4, 3)  # bchpn

    # State passing runs in at least float32; the original dtype is restored after.
    chunk_states_dtype = chunk_states.dtype
    if chunk_states_dtype not in (torch.float32, torch.float64):
        chunk_states = chunk_states.to(torch.float32)
    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
        passed, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
            chunk_states.flatten(start_dim=-2, end_dim=-1),
            dA_cumsum[:, :, :, -1],
        )
    prev_states = torch.unflatten(passed, -1, (-1, d_state)).to(chunk_states_dtype)
    last_state = torch.unflatten(last_state, -1, (-1, d_state))

    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
        out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(
            B, C, x, dt, dA_cumsum, prev_states, D=D, z=z
        )

    return out, last_state
|
613 |
+
|
614 |
+
|
615 |
+
@torch.profiler.record_function("ssd_chunk_scan_combined")
|
616 |
def ssd_chunk_scan_combined(
|
617 |
x: torch.Tensor,
|
618 |
dt: torch.Tensor,
|
|
|
649 |
To avoid updating state, we set dt to -inf and x to 0
|
650 |
because `softplus(-inf) = 0` and `exp(0) = 1`
|
651 |
"""
|
652 |
+
pad = (chunk_size - length % chunk_size) % chunk_size
|
653 |
+
x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
654 |
+
dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
|
655 |
+
B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
656 |
+
C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
657 |
+
z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
658 |
+
if seq_idx is not None:
|
659 |
+
seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
|
|
|
|
|
|
|
|
|
660 |
|
661 |
+
length = x.shape[1]
|
662 |
+
assert length % chunk_size == 0, (length, chunk_size)
|
663 |
+
|
664 |
+
if dt.is_cuda:
|
665 |
dtype = _get_trition_dtype(x.dtype)
|
666 |
out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
|
667 |
x.to(dtype),
|
|
|
684 |
assert isinstance(out, torch.Tensor)
|
685 |
return out[:, pad:]
|
686 |
else:
|
687 |
+
if ssm_state is None and seq_idx is None:
|
688 |
+
tmp = _ssd_chunk_scan_combined_cpu(
|
689 |
+
x,
|
690 |
+
dt,
|
691 |
+
A,
|
692 |
+
B,
|
693 |
+
C,
|
694 |
+
chunk_size,
|
695 |
+
D=D,
|
696 |
+
z=z,
|
697 |
+
dt_bias=dt_bias.float(),
|
698 |
+
dt_softplus=dt_softplus,
|
699 |
+
)
|
700 |
+
else:
|
701 |
+
if ssm_state is None:
|
702 |
+
bsize, _, num_heads, channel = x.shape
|
703 |
+
state = B.shape[-1]
|
704 |
+
ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
|
705 |
+
tmp = _ssd_chunk_scan_combined_naive(
|
706 |
+
x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
|
707 |
+
)
|
708 |
+
tmp = (tmp[0][:, pad:], tmp[1])
|
709 |
if return_final_states:
|
710 |
return tmp
|
711 |
else:
|
712 |
return tmp[0]
|
713 |
|
714 |
|
715 |
+
def _causal_conv1d_update(
    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run a single causal-conv1d update step with SiLU activation.

    `xBC` and `weight` are cast to the dtype of `conv_state` before the call.
    The kernel is `causal_conv1d.causal_conv1d_update` when `conv_state` lives
    on a CUDA device and `causal_conv1d.causal_conv1d_update_ref` otherwise.

    Returns:
        The kernel output together with the `conv_state` object that was
        passed in (presumably updated in place by the kernel -- confirm
        against the `causal_conv1d` package).
    """
    state_dtype = conv_state.dtype
    if conv_state.is_cuda:
        update_fn = causal_conv1d.causal_conv1d_update
    else:
        update_fn = causal_conv1d.causal_conv1d_update_ref
    # The kernels receive the weight with its middle axis (index 0) dropped.
    out = update_fn(
        x=xBC.to(state_dtype),
        conv_state=conv_state,
        weight=weight.to(state_dtype)[:, 0, :],
        activation="silu",
    )
    return out, conv_state
|
737 |
+
|
738 |
+
|
739 |
+
def _causal_conv1d_naive(
    conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply the causal convolution one position at a time along the last axis of `x`.

    Each position is fed through `_causal_conv1d_update`. When `seq_idx` is
    given, a batch row whose `seq_idx` value differs from the previous
    position has its convolution state replaced by zeros before that step, so
    no state carries over between sequences.

    Returns:
        `(out, conv_state)`: a tensor created with `torch.zeros_like(x)` and
        filled position by position, and the state after the last step.
    """
    result = torch.zeros_like(x)
    for step in range(x.shape[-1]):
        # A boundary can only be detected from the second position onwards.
        check_boundary = step != 0 and seq_idx is not None
        if check_boundary:
            starts_new_sequence = seq_idx[:, step - 1] != seq_idx[:, step]
            conv_state = torch.where(
                starts_new_sequence[:, None, None],
                torch.zeros_like(conv_state),
                conv_state,
            )
        current = slice(step, step + 1)
        result[:, :, current], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, current])
    return result, conv_state
|
753 |
+
|
754 |
+
|
755 |
+
@torch.profiler.record_function("causal_conv1d")
|
756 |
def _causal_conv1d(
|
757 |
conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
|
758 |
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
|
788 |
else:
|
789 |
x = tmp
|
790 |
else:
|
791 |
+
if seq_idx is None:
|
792 |
+
x, conv_state = causal_conv1d.causal_conv1d_ref(
|
793 |
+
x=x,
|
794 |
+
initial_states=conv_state,
|
795 |
+
return_final_states=True,
|
796 |
+
weight=weight[:, 0, :],
|
797 |
+
activation="silu",
|
798 |
+
)
|
799 |
+
else:
|
800 |
+
if conv_state is None:
|
801 |
+
bsize = x.shape[0]
|
802 |
+
dim = weight.shape[0]
|
803 |
+
d_conv = weight.shape[-1]
|
804 |
+
conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
|
805 |
+
x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
|
|
|
806 |
if return_final_states:
|
807 |
return x, conv_state
|
808 |
else:
|
809 |
return x, None
|
810 |
|
811 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
812 |
class Mamba(torch.nn.Module):
|
813 |
def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
|
814 |
super().__init__()
|