jupyterjazz committed
Commit 4e13c90
Parent: acffa62

refactor: kwargs comprehension


Signed-off-by: jupyterjazz <[email protected]>

Files changed (4)
  1. embedding.py +1 -3
  2. mha.py +3 -6
  3. mlp.py +1 -3
  4. modeling_xlm_roberta.py +2 -6
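
The same change is applied in all four files: a three-line `if` block that conditionally fills `lora_kwargs` is collapsed into one conditional expression. A minimal sketch of the before/after pattern, runnable on its own (the `lora_linear` stub and the 'retrieval' task name are illustrative, not the repo's LoRA layer):

# Sketch of the pattern this commit consolidates. `lora_linear` is a stand-in;
# a real LoRA-parametrized layer would select a task-specific adapter.
def lora_linear(x, task=None):
    return f"plain({x})" if task is None else f"{task}-adapter({x})"

def forward_old(x, task=None):
    # pattern removed by this commit
    lora_kwargs = {}
    if task is not None:
        lora_kwargs['task'] = task
    return lora_linear(x, **lora_kwargs)

def forward_new(x, task=None):
    # pattern introduced by this commit
    lora_kwargs = {'task': task} if task is not None else {}
    return lora_linear(x, **lora_kwargs)

assert forward_old("x") == forward_new("x")
assert forward_old("x", "retrieval") == forward_new("x", "retrieval")

Both variants expand to the same call: with task=None the dict is empty and **lora_kwargs passes no extra keyword arguments.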
embedding.py CHANGED
@@ -47,9 +47,7 @@ class XLMRobertaEmbeddings(nn.Module):
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        lora_kwargs = {}
-        if task is not None:
-            lora_kwargs['task'] = task
+        lora_kwargs = {'task': task} if task is not None else {}
         embeddings = self.word_embeddings(input_ids, **lora_kwargs)
         if self.max_position_embeddings > 0:
             if position_ids is None:
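
Why the empty-dict fallback is safe at call sites like self.word_embeddings(input_ids, **lora_kwargs): unpacking an empty dict passes no keyword arguments, so a plain (non-LoRA) embedding still accepts the call when task is None. A small self-contained check, assuming only standard PyTorch:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)               # stand-in for word_embeddings
input_ids = torch.tensor([[1, 2, 3]])   # (batch, seqlen)

task = None
lora_kwargs = {'task': task} if task is not None else {}
out = emb(input_ids, **lora_kwargs)     # identical to emb(input_ids)
assert out.shape == (1, 3, 4)           # (batch, seqlen, embed_dim)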
mha.py CHANGED
@@ -645,14 +645,11 @@ class MHA(nn.Module):
645
  batch, seqlen = x.shape[:2]
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
648
- lora_kwargs = {}
649
- if task is not None:
650
- lora_kwargs['task'] = task
651
- lora_kwargs['residual'] = self.return_residual
652
-
653
  if not self.return_residual:
654
  qkv = self.Wqkv(x, **lora_kwargs)
655
  else:
 
656
  qkv, x = self.Wqkv(x, **lora_kwargs)
657
 
658
  if self.dwconv:
@@ -739,6 +736,6 @@ class MHA(nn.Module):
739
  context = self._update_kvcache_attention(q, kv, inference_params)
740
  else:
741
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
742
- lora_kwargs.pop('residual', None)
743
  out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
744
  return out if not self.return_residual else (out, x)
 
645
  batch, seqlen = x.shape[:2]
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
648
+ lora_kwargs = {'task': task} if task is not None else {}
 
 
 
 
649
  if not self.return_residual:
650
  qkv = self.Wqkv(x, **lora_kwargs)
651
  else:
652
+ lora_kwargs['residual'] = True
653
  qkv, x = self.Wqkv(x, **lora_kwargs)
654
 
655
  if self.dwconv:
 
736
  context = self._update_kvcache_attention(q, kv, inference_params)
737
  else:
738
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
739
+
740
  out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
741
  return out if not self.return_residual else (out, x)
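
In mha.py the dict also carries a residual flag: previously it was always set to self.return_residual and popped before out_proj; now it is added only in the return_residual branch, just before Wqkv is asked to return its input alongside the projection. A hypothetical stand-in layer (not the repo's parametrized linear) that mirrors this calling convention:

import torch
import torch.nn as nn

class TaskAwareLinear(nn.Linear):
    # Stand-in only: accepts the same optional kwargs the diff passes to Wqkv.
    # `task` is accepted but unused here; a real LoRA layer would apply a
    # task-specific adapter. `residual=True` also returns the layer's input.
    def forward(self, x, task=None, residual=False):
        y = super().forward(x)
        return (y, x) if residual else y

wqkv = TaskAwareLinear(8, 24)
x = torch.randn(2, 5, 8)

lora_kwargs = {'task': 'retrieval'}     # task name is illustrative
qkv = wqkv(x, **lora_kwargs)            # return_residual == False path

lora_kwargs['residual'] = True          # added in the else branch of the diff
qkv, x_res = wqkv(x, **lora_kwargs)     # return_residual == True path
assert torch.equal(x_res, x)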
mlp.py CHANGED
@@ -48,9 +48,7 @@ class Mlp(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x, task):
-        lora_kwargs = {}
-        if task is not None:
-            lora_kwargs['task'] = task
+        lora_kwargs = {'task': task} if task is not None else {}
         y = self.fc1(x, **lora_kwargs)
         y = self.activation(y)
         y = self.fc2(y, **lora_kwargs)
modeling_xlm_roberta.py CHANGED
@@ -313,9 +313,7 @@ class XLMRobertaPooler(nn.Module):
     def forward(self, hidden_states, pool=True, task=None):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
-        lora_kwargs = {}
-        if task is not None:
-            lora_kwargs['task'] = task
+        lora_kwargs = {'task': task} if task is not None else {}

         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
         pooled_output = self.dense(first_token_tensor, **lora_kwargs)
@@ -550,9 +548,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-        lora_kwargs = {}
-        if task is not None:
-            lora_kwargs['task'] = task
+        lora_kwargs = {'task': task} if task is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],