Update modeling_jamba.py - LoRA support in Mamba (#6)
- Update modeling_jamba.py - LoRA support in Mamba (409c904957803838229e49676ec3958c2205783d)
- modeling_jamba.py +12 -4
modeling_jamba.py
CHANGED
@@ -943,14 +943,22 @@ class JambaMambaMixer(nn.Module):
         # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
         # linear layers, and requires to call the forward pass directly.
         # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
-        time_proj_bias = self.dt_proj.bias
-        self.dt_proj.bias = None
+        if hasattr(self.dt_proj, "base_layer"):
+            # In case of LoRA, we need to access the base layer to get the weight
+            time_proj_bias = self.dt_proj.base_layer.bias
+            self.dt_proj.base_layer.bias = None
+        else:
+            time_proj_bias = self.dt_proj.bias
+            self.dt_proj.bias = None
         discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
-        self.dt_proj.bias = time_proj_bias
+        if hasattr(self.dt_proj, "base_layer"):
+            self.dt_proj.base_layer.bias = time_proj_bias
+        else:
+            self.dt_proj.bias = time_proj_bias
 
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+        time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
         if cache_params is not None and cache_params.seqlen_offset > 0:
             scan_outputs = selective_state_update(
                 cache_params.ssm_states[self.layer_idx],
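For context, here is a minimal, self-contained sketch of the pattern this diff introduces: the dt_proj bias must not be applied during the linear forward pass (the selective-scan kernel adds it separately), and when dt_proj is wrapped by a LoRA adapter the bias lives on `dt_proj.base_layer` rather than on `dt_proj` itself, which is what the `hasattr(self.dt_proj, "base_layer")` check detects. `LoRAWrappedLinear` below is a simplified, hypothetical stand-in for a PEFT-style wrapper, and `apply_dt_proj_without_bias` is an illustrative helper, not part of modeling_jamba.py.

import torch
import torch.nn as nn


class LoRAWrappedLinear(nn.Module):
    """Simplified stand-in for a PEFT-style LoRA wrapper: the frozen
    projection (and its bias) lives on `base_layer`, and a low-rank
    update is added on top of its output."""

    def __init__(self, base_layer: nn.Linear, rank: int = 8):
        super().__init__()
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base_layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_B(self.lora_A(x))


def apply_dt_proj_without_bias(dt_proj: nn.Module, time_step: torch.Tensor):
    """Mirror the pattern in the diff: temporarily detach the bias from
    whichever module actually owns it (the LoRA base layer if present,
    otherwise the layer itself), run the forward pass, then restore it.
    The bias is returned separately because the selective-scan kernel
    adds it on its own and it must not be applied twice."""
    owner = dt_proj.base_layer if hasattr(dt_proj, "base_layer") else dt_proj

    time_proj_bias = owner.bias
    owner.bias = None  # nn.Module allows re-assigning a registered parameter to None
    try:
        discrete_time_step = dt_proj(time_step).transpose(1, 2)
    finally:
        owner.bias = time_proj_bias  # always restore, even if the forward raises

    return discrete_time_step, time_proj_bias


if __name__ == "__main__":
    plain = nn.Linear(16, 32)
    wrapped = LoRAWrappedLinear(nn.Linear(16, 32))
    x = torch.randn(2, 10, 16)

    out_plain, bias_plain = apply_dt_proj_without_bias(plain, x)
    out_lora, bias_lora = apply_dt_proj_without_bias(wrapped, x)
    print(out_plain.shape, out_lora.shape)    # torch.Size([2, 32, 10]) twice
    print(bias_plain.shape, bias_lora.shape)  # torch.Size([32]) twice

The detection via `base_layer` works because LoRA wrappers such as those in PEFT expose the original frozen module under that attribute, so the same save/strip/restore trick keeps working whether or not an adapter is attached, while still calling the wrapper's forward pass (and thus the quantized or LoRA-augmented computation) rather than touching the weight matrix directly.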