Commit
•
2cc8472
1
Parent(s):
3eceb33
2-adapter-tuning-initial-impl (#30)
Browse files- 2 adapter tuning (3fd28cf83a7aeb3b39b4da99337ae29c84f1b424)
Co-authored-by: Jack Min Ong <[email protected]>
- block.py +11 -1
- embedding.py +26 -4
- mha.py +37 -5
- mlp.py +21 -3
- modeling_lora.py +0 -1
- modeling_xlm_roberta.py +18 -5
block.py
CHANGED
@@ -233,7 +233,17 @@ class Block(nn.Module):
|
|
233 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
234 |
)
|
235 |
if not isinstance(self.mlp, nn.Identity):
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
if self.return_residual: # mlp out is actually a pair here
|
238 |
mlp_out, hidden_states = mlp_out
|
239 |
if not self.fused_dropout_add_ln:
|
|
|
233 |
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
234 |
)
|
235 |
if not isinstance(self.mlp, nn.Identity):
|
236 |
+
task_type = mixer_kwargs.get('task_type')
|
237 |
+
if task_type:
|
238 |
+
if isinstance(task_type, tuple):
|
239 |
+
assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
|
240 |
+
split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
|
241 |
+
split = mixer_kwargs['cu_seqlens'][split_index]
|
242 |
+
mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
|
243 |
+
else:
|
244 |
+
mlp_out = self.mlp(hidden_states, task_type=task_type)
|
245 |
+
else:
|
246 |
+
mlp_out = self.mlp(hidden_states)
|
247 |
if self.return_residual: # mlp out is actually a pair here
|
248 |
mlp_out, hidden_states = mlp_out
|
249 |
if not self.fused_dropout_add_ln:
|
embedding.py
CHANGED
@@ -47,8 +47,18 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
47 |
token_type_ids: (batch, seqlen)
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
if self.max_position_embeddings > 0:
|
53 |
if position_ids is None:
|
54 |
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
@@ -58,6 +68,18 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
58 |
if self.type_vocab_size > 0:
|
59 |
if token_type_ids is None:
|
60 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
return embeddings
|
|
|
47 |
token_type_ids: (batch, seqlen)
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
+
if isinstance(task_type, tuple):
|
51 |
+
assert input_ids.shape[0] % 9 == 0
|
52 |
+
split = int(input_ids.shape[0] / 9)
|
53 |
+
tensor1 = input_ids[:split, :]
|
54 |
+
tensor2 = input_ids[split:, :]
|
55 |
+
emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
|
56 |
+
emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
|
57 |
+
embeddings = torch.cat((emb1, emb2), dim=0)
|
58 |
+
else:
|
59 |
+
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
60 |
+
embeddings = self.word_embeddings(input_ids, **lora_kwargs)
|
61 |
+
|
62 |
if self.max_position_embeddings > 0:
|
63 |
if position_ids is None:
|
64 |
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
|
|
68 |
if self.type_vocab_size > 0:
|
69 |
if token_type_ids is None:
|
70 |
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
71 |
+
if isinstance(task_type, tuple):
|
72 |
+
assert embeddings.shape[0] % 9 == 0
|
73 |
+
split = int(embeddings.shape[0] / 9)
|
74 |
+
emb1 = embeddings[:split, :, :]
|
75 |
+
emb2 = embeddings[split:, :, :]
|
76 |
+
token_type_embs1 = self.token_type_embeddings(token_type_ids, task_type=task_type[0])
|
77 |
+
token_type_embs2 = self.token_type_embeddings(token_type_ids, task_type=task_type[1])
|
78 |
+
emb1 = emb1 + token_type_embs1
|
79 |
+
emb2 = emb2 + token_type_embs2
|
80 |
+
embeddings = torch.cat((emb1, emb2), dim=0)
|
81 |
+
else:
|
82 |
+
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
83 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
|
84 |
+
embeddings = embeddings + token_type_embeddings
|
85 |
return embeddings
|
mha.py
CHANGED
@@ -643,15 +643,39 @@ class MHA(nn.Module):
|
|
643 |
inference_params.max_sequence_len if inference_params is not None else max_seqlen
|
644 |
)
|
645 |
batch, seqlen = x.shape[:2]
|
|
|
646 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
647 |
assert x_kv is None and mixer_subset is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
648 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
|
|
649 |
if not self.return_residual:
|
650 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
else:
|
652 |
-
if
|
653 |
-
|
654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
|
656 |
if self.dwconv:
|
657 |
qkv = rearrange(
|
@@ -739,5 +763,13 @@ class MHA(nn.Module):
|
|
739 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
740 |
|
741 |
lora_kwargs.pop('residual', None)
|
742 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
return out if not self.return_residual else (out, x)
|
|
|
643 |
inference_params.max_sequence_len if inference_params is not None else max_seqlen
|
644 |
)
|
645 |
batch, seqlen = x.shape[:2]
|
646 |
+
lora_kwargs = {}
|
647 |
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
648 |
assert x_kv is None and mixer_subset is None
|
649 |
+
|
650 |
+
split = None
|
651 |
+
if isinstance(task_type, tuple):
|
652 |
+
assert cu_seqlens.shape[0] % 9 == 1
|
653 |
+
split_index = int((cu_seqlens.shape[0] - 1) / 9)
|
654 |
+
split = cu_seqlens[split_index]
|
655 |
+
|
656 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
657 |
+
|
658 |
if not self.return_residual:
|
659 |
+
if isinstance(task_type, tuple):
|
660 |
+
tensor1 = x[:split, :]
|
661 |
+
tensor2 = x[split:, :]
|
662 |
+
qkv1 = self.Wqkv(tensor1, task_type=task_type[0])
|
663 |
+
qkv2 = self.Wqkv(tensor2, task_type=task_type[1])
|
664 |
+
qkv = torch.cat((qkv1, qkv2), dim=0)
|
665 |
+
else:
|
666 |
+
qkv = self.Wqkv(x, **lora_kwargs)
|
667 |
else:
|
668 |
+
if isinstance(task_type, tuple):
|
669 |
+
tensor1 = x[:split, :]
|
670 |
+
tensor2 = x[split:, :]
|
671 |
+
qkv1, tensor1 = self.Wqkv(tensor1, task_type=task_type[0], residual=True)
|
672 |
+
qkv2, tensor2 = self.Wqkv(tensor2, task_type=task_type[1], residual=True)
|
673 |
+
qkv = torch.cat((qkv1, qkv2), dim=0)
|
674 |
+
x = torch.cat((tensor1, tensor2), dim=0)
|
675 |
+
else:
|
676 |
+
if lora_kwargs:
|
677 |
+
lora_kwargs['residual'] = True
|
678 |
+
qkv, x = self.Wqkv(x, **lora_kwargs)
|
679 |
|
680 |
if self.dwconv:
|
681 |
qkv = rearrange(
|
|
|
763 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
764 |
|
765 |
lora_kwargs.pop('residual', None)
|
766 |
+
inp = rearrange(context, "... h d -> ... (h d)")
|
767 |
+
if isinstance(task_type, tuple):
|
768 |
+
tensor1 = inp[:split, :]
|
769 |
+
tensor2 = inp[split:, :]
|
770 |
+
out1 = self.out_proj(tensor1, task_type=task_type[0])
|
771 |
+
out2 = self.out_proj(tensor2, task_type=task_type[1])
|
772 |
+
out = torch.cat((out1, out2), dim=0)
|
773 |
+
else:
|
774 |
+
out = self.out_proj(inp, **lora_kwargs)
|
775 |
return out if not self.return_residual else (out, x)
|
mlp.py
CHANGED
@@ -47,11 +47,29 @@ class Mlp(nn.Module):
|
|
47 |
self.activation = activation
|
48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
49 |
|
50 |
-
def forward(self, x, task_type=None):
|
51 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
y = self.activation(y)
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return y if not self.return_residual else (y, x)
|
56 |
|
57 |
|
|
|
47 |
self.activation = activation
|
48 |
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
49 |
|
50 |
+
def forward(self, x, task_type=None, split=None):
|
51 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
52 |
+
if split:
|
53 |
+
assert isinstance(task_type, tuple)
|
54 |
+
tensor1 = x[:split, :]
|
55 |
+
tensor2 = x[split:, :]
|
56 |
+
y1 = self.fc1(tensor1, task_type=task_type[0])
|
57 |
+
y2 = self.fc1(tensor2, task_type=task_type[1])
|
58 |
+
y = torch.cat((y1, y2), dim=0)
|
59 |
+
else:
|
60 |
+
y = self.fc1(x, **lora_kwargs)
|
61 |
+
|
62 |
y = self.activation(y)
|
63 |
+
|
64 |
+
if split:
|
65 |
+
assert isinstance(task_type, tuple)
|
66 |
+
tensor1 = y[:split, :]
|
67 |
+
tensor2 = y[split:, :]
|
68 |
+
y1 = self.fc2(tensor1, task_type=task_type[0])
|
69 |
+
y2 = self.fc2(tensor2, task_type=task_type[1])
|
70 |
+
y = torch.cat((y1, y2), dim=0)
|
71 |
+
else:
|
72 |
+
y = self.fc2(y, **lora_kwargs)
|
73 |
return y if not self.return_residual else (y, x)
|
74 |
|
75 |
|
modeling_lora.py
CHANGED
@@ -227,7 +227,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
227 |
roberta: Optional[XLMRobertaModel] = None
|
228 |
):
|
229 |
super().__init__(config)
|
230 |
-
|
231 |
if roberta is None:
|
232 |
self.roberta = XLMRobertaModel(config)
|
233 |
else:
|
|
|
227 |
roberta: Optional[XLMRobertaModel] = None
|
228 |
):
|
229 |
super().__init__(config)
|
|
|
230 |
if roberta is None:
|
231 |
self.roberta = XLMRobertaModel(config)
|
232 |
else:
|
modeling_xlm_roberta.py
CHANGED
@@ -210,10 +210,12 @@ class XLMRobertaEncoder(nn.Module):
|
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
211 |
"""
|
212 |
if key_padding_mask is None or not self.use_flash_attn:
|
213 |
-
mixer_kwargs =
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
217 |
for layer in self.layers:
|
218 |
if self._grad_checkpointing:
|
219 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
@@ -314,7 +316,18 @@ class XLMRobertaPooler(nn.Module):
|
|
314 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
315 |
|
316 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
pooled_output = self.activation(pooled_output)
|
319 |
return pooled_output
|
320 |
|
|
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
211 |
"""
|
212 |
if key_padding_mask is None or not self.use_flash_attn:
|
213 |
+
mixer_kwargs = (
|
214 |
+
{"key_padding_mask": key_padding_mask.bool()}
|
215 |
+
if key_padding_mask is not None
|
216 |
+
else None
|
217 |
+
)
|
218 |
+
mixer_kwargs['task_type'] = task_type
|
219 |
for layer in self.layers:
|
220 |
if self._grad_checkpointing:
|
221 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
|
316 |
lora_kwargs = {'task_type': task_type} if task_type is not None else {}
|
317 |
|
318 |
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
319 |
+
|
320 |
+
if isinstance(task_type, tuple):
|
321 |
+
assert first_token_tensor.shape[0] % 9 == 0
|
322 |
+
split = int(first_token_tensor.shape[0] / 9)
|
323 |
+
tensor1 = first_token_tensor[:split, :]
|
324 |
+
tensor2 = first_token_tensor[split:, :]
|
325 |
+
pooled_out1 = self.dense(tensor1, task_type=task_type[0])
|
326 |
+
pooled_out2 = self.dense(tensor2, task_type=task_type[0])
|
327 |
+
pooled_output = torch.cat((pooled_out1, pooled_out2), dim=0)
|
328 |
+
else:
|
329 |
+
pooled_output = self.dense(first_token_tensor, **lora_kwargs)
|
330 |
+
|
331 |
pooled_output = self.activation(pooled_output)
|
332 |
return pooled_output
|
333 |
|