ssuzuki65 and yhirokawa committed (verified)
Commit: d6085e9 · Parent: 47b769d

update model code (#2)


- update model code (27be45fa8bcf584ce06cb229606f20793f67f152)


Co-authored-by: Yuta HIROKAWA <[email protected]>

Files changed (2)
  1. config.json +1 -1
  2. modeling_plamo.py +152 -59
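
Note: the config.json change below is only the rename of "model_type" from "plamo" to "plamo2", keeping the config consistent with the updated modeling code. A quick local check of the renamed field (an assumed usage sketch, not part of the commit):

import json

with open("config.json") as f:
    config = json.load(f)

assert config["model_type"] == "plamo2"
assert config["max_position_embeddings"] == 10485760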
config.json CHANGED
@@ -30,7 +30,7 @@
   "mamba_num_heads": 32,
   "mamba_step": 2,
   "max_position_embeddings": 10485760,
-  "model_type": "plamo",
+  "model_type": "plamo2",
   "n_expert": null,
   "num_attention_heads": 16,
   "num_hidden_layers": 16,
modeling_plamo.py CHANGED
@@ -551,6 +551,68 @@ def _ssd_chunk_scan_combined_naive(
     return torch.cat(ys, dim=1), ssm_state


+def _ssd_chunk_scan_combined_cpu(
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    chunk_size: int,
+    D: torch.Tensor,
+    z: torch.Tensor,
+    dt_bias: torch.Tensor,
+    dt_softplus: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # (bsize, nhead, nchunk, chunk_size)
+    dt = dt.float()  # We want high precision for this before cumsum
+    dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size))  # type: ignore
+    if dt_bias is not None:
+        dt = dt + dt_bias[None, :, None, None]
+    if dt_softplus:
+        dt = F.softplus(dt)
+    dA = dt * A[None, :, None, None]
+    dA_cumsum = torch.cumsum(dA, dim=-1)
+
+    _, _, nheads, _ = x.shape
+    dstate = B.shape[-1]
+    _ = dt.shape[2]
+
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
+        # Following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`
+        # But `einsum` in the above function is too slow in CPU.
+        x_ = torch.unflatten(x, 1, (-1, chunk_size))
+        assert B.shape[2] == nheads  # B should be already expanded
+        B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype)  # (bsize, nchunk, chunk_size, nheads, dstate)
+        decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
+        dt_ = dt.to(x.dtype)
+
+        # einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
+        B_ = B_.permute(0, 1, 3, 4, 2)  # bchnl
+        tmp = dt_ * decay_states  # bhcl
+        tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None]  # bch1l
+        tmp = B_ * tmp  # bchnl
+        x_ = x_.permute(0, 1, 3, 2, 4)  # bchlp
+        tmp = tmp @ x_  # bchnp
+        states = tmp.permute(0, 1, 2, 4, 3)  # bchpn
+
+    states_dtype = states.dtype
+    if states.dtype not in [torch.float32, torch.float64]:
+        states = states.to(torch.float32)
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
+        out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
+            states.flatten(start_dim=-2, end_dim=-1),
+            dA_cumsum[:, :, :, -1],
+        )
+    states = torch.unflatten(out, -1, (-1, dstate))
+    last_state = torch.unflatten(last_state, -1, (-1, dstate))
+    states = states.to(states_dtype)
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
+        out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
+
+    return out, last_state
+
+
+@torch.profiler.record_function("ssd_chunk_scan_combined")
 def ssd_chunk_scan_combined(
     x: torch.Tensor,
     dt: torch.Tensor,
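
The new `_ssd_chunk_scan_combined_cpu` above builds the per-chunk states with a permute/matmul chain instead of the einsum used by `chunk_state_ref`, which the in-diff comment notes is too slow on CPU. A small self-contained check that the chain reproduces `einsum("bclhn,bhcl,bhcl,bclhp->bchpn", ...)` (illustrative shapes, not part of the commit):

import torch

b, c, l, h, n, p = 2, 3, 4, 5, 6, 7
B_ = torch.randn(b, c, l, h, n)
decay_states = torch.randn(b, h, c, l)
dt_ = torch.randn(b, h, c, l)
x_ = torch.randn(b, c, l, h, p)

ref = torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)

tmp = (dt_ * decay_states).permute(0, 2, 1, 3)[:, :, :, None]   # bch1l
tmp = B_.permute(0, 1, 3, 4, 2) * tmp                           # bchnl
out = (tmp @ x_.permute(0, 1, 3, 2, 4)).permute(0, 1, 2, 4, 3)  # bchpn

assert torch.allclose(ref, out, atol=1e-5)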
@@ -587,19 +649,19 @@ def ssd_chunk_scan_combined(
     To avoid updating state, we set dt to -inf and x to 0
     because `softplus(-inf) = 0` and `exp(0) = 1`
     """
-    if dt.is_cuda:
-        pad = (chunk_size - length % chunk_size) % chunk_size
-        x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
-        dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
-        B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
-        C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
-        z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
-        if seq_idx is not None:
-            seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
-
-        length = x.shape[1]
-        assert length % chunk_size == 0, (length, chunk_size)
+    pad = (chunk_size - length % chunk_size) % chunk_size
+    x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
+    B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    if seq_idx is not None:
+        seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)

+    length = x.shape[1]
+    assert length % chunk_size == 0, (length, chunk_size)
+
+    if dt.is_cuda:
         dtype = _get_trition_dtype(x.dtype)
         out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined(  # type: ignore
             x.to(dtype),
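
This hunk moves the front padding out of the CUDA-only branch, so both the Triton path and the CPU paths now see a length that is a multiple of chunk_size. The docstring's claim that padded steps leave the recurrent state untouched can be checked on a toy scalar form of the discretized update (a simplified sketch, not the model code):

import torch
import torch.nn.functional as F

dt = torch.tensor(float("-inf"))
A = torch.tensor(-1.0)
assert F.softplus(dt) == 0.0                 # softplus(-inf) = 0
assert torch.exp(F.softplus(dt) * A) == 1.0  # exp(0) = 1

state = torch.randn(4)
B, x = torch.randn(4), torch.tensor(0.0)     # padded input is zero
new_state = torch.exp(F.softplus(dt) * A) * state + F.softplus(dt) * B * x
assert torch.allclose(new_state, state)      # the padded step is a no-op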
@@ -622,19 +684,75 @@
         assert isinstance(out, torch.Tensor)
         return out[:, pad:]
     else:
-        if ssm_state is None:
-            bsize, _, num_heads, channel = x.shape
-            state = B.shape[-1]
-            ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
-        tmp = _ssd_chunk_scan_combined_naive(
-            x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
-        )
+        if ssm_state is None and seq_idx is None:
+            tmp = _ssd_chunk_scan_combined_cpu(
+                x,
+                dt,
+                A,
+                B,
+                C,
+                chunk_size,
+                D=D,
+                z=z,
+                dt_bias=dt_bias.float(),
+                dt_softplus=dt_softplus,
+            )
+        else:
+            if ssm_state is None:
+                bsize, _, num_heads, channel = x.shape
+                state = B.shape[-1]
+                ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
+            tmp = _ssd_chunk_scan_combined_naive(
+                x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
+            )
+        tmp = (tmp[0][:, pad:], tmp[1])
         if return_final_states:
             return tmp
         else:
             return tmp[0]


+def _causal_conv1d_update(
+    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = conv_state.dtype
+    xBC = xBC.to(dtype)
+    weight = weight.to(dtype)
+    if conv_state.is_cuda:
+        x = causal_conv1d.causal_conv1d_update(
+            x=xBC,
+            conv_state=conv_state,
+            weight=weight[:, 0, :],
+            activation="silu",
+        )
+        return x, conv_state
+    else:
+        x = causal_conv1d.causal_conv1d_update_ref(
+            x=xBC,
+            conv_state=conv_state,
+            weight=weight[:, 0, :],
+            activation="silu",
+        )
+        return x, conv_state
+
+
+def _causal_conv1d_naive(
+    conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    length = x.shape[-1]
+    out = torch.zeros_like(x)
+    for i in range(length):
+        if i != 0 and seq_idx is not None:
+            conv_state = torch.where(
+                (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
+                torch.zeros_like(conv_state),
+                conv_state,
+            )
+        out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
+    return out, conv_state
+
+
+@torch.profiler.record_function("causal_conv1d")
 def _causal_conv1d(
     conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
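
`_causal_conv1d_naive`, added above, carries the rolling convolution state across time steps and zeroes it per batch element wherever `seq_idx` changes, so packed sequences do not leak context across document boundaries. A toy illustration of that reset (illustrative shapes, not part of the commit):

import torch

bsize, dim, d_conv = 2, 3, 4
conv_state = torch.ones(bsize, dim, d_conv - 1)
seq_idx = torch.tensor([[0, 0, 1], [0, 0, 0]])  # batch 0 starts a new sequence at step 2
i = 2
reset = (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None]
conv_state = torch.where(reset, torch.zeros_like(conv_state), conv_state)

assert conv_state[0].sum() == 0                   # cleared for batch 0
assert conv_state[1].sum() == dim * (d_conv - 1)  # untouched for batch 1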
@@ -670,52 +788,27 @@ def _causal_conv1d(
         else:
             x = tmp
     else:
-        if conv_state is None:
-            bsize = x.shape[0]
-            dim = weight.shape[0]
-            d_conv = weight.shape[-1]
-            conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
-        length = x.shape[-1]
-        out = torch.zeros_like(x)
-        for i in range(length):
-            if i != 0 and seq_idx is not None:
-                conv_state = torch.where(
-                    (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
-                    torch.zeros_like(conv_state),
-                    conv_state,
-                )
-            out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
-        x = out
+        if seq_idx is None:
+            x, conv_state = causal_conv1d.causal_conv1d_ref(
+                x=x,
+                initial_states=conv_state,
+                return_final_states=True,
+                weight=weight[:, 0, :],
+                activation="silu",
+            )
+        else:
+            if conv_state is None:
+                bsize = x.shape[0]
+                dim = weight.shape[0]
+                d_conv = weight.shape[-1]
+                conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
+            x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
     if return_final_states:
         return x, conv_state
     else:
         return x, None


-def _causal_conv1d_update(
-    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
-    dtype = conv_state.dtype
-    xBC = xBC.to(dtype)
-    weight = weight.to(dtype)
-    if conv_state.is_cuda:
-        x = causal_conv1d.causal_conv1d_update(
-            x=xBC,
-            conv_state=conv_state,
-            weight=weight[:, 0, :],
-            activation="silu",
-        )
-        return x, conv_state
-    else:
-        x = causal_conv1d.causal_conv1d_update_ref(
-            x=xBC,
-            conv_state=conv_state,
-            weight=weight[:, 0, :],
-            activation="silu",
-        )
-        return x, conv_state
-
-
 class Mamba(torch.nn.Module):
     def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
         super().__init__()
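
On the non-CUDA path, `_causal_conv1d` now delegates sequences without `seq_idx` to `causal_conv1d.causal_conv1d_ref` and keeps the stepwise `_causal_conv1d_naive` only for the packed-sequence case. For reference, a dependency-free sketch of the depthwise causal convolution these paths compute (a simplification for illustration, not the library implementation):

import torch
import torch.nn.functional as F

bsize, dim, length, d_conv = 2, 8, 16, 4
x = torch.randn(bsize, dim, length)
weight = torch.randn(dim, d_conv)     # one filter per channel

x_padded = F.pad(x, (d_conv - 1, 0))  # left padding keeps the convolution causal
y = F.conv1d(x_padded, weight[:, None, :], groups=dim)
y = F.silu(y)                         # "silu" activation, as in the diff
assert y.shape == (bsize, dim, length)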
 