Markus28 commited on
Commit
cdf5490
·
1 Parent(s): c0b46cc

feat: make main parameters trainable

Browse files
Files changed (1) hide show
  1. modeling_lora.py +14 -3
modeling_lora.py CHANGED
@@ -207,11 +207,21 @@ class BertLoRA(BertPreTrainedModel):
207
  self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
208
  else:
209
  self.bert = bert
 
210
  self._register_lora(num_adaptions)
 
 
 
 
 
 
 
 
 
 
211
  for name, param in super().named_parameters():
212
  if "lora" not in name:
213
- param.requires_grad_(False)
214
- self.current_task = 0
215
 
216
  @classmethod
217
  def from_bert(cls, *args, num_adaptions=1, **kwargs):
@@ -254,6 +264,7 @@ class BertLoRA(BertPreTrainedModel):
254
 
255
  @current_task.setter
256
  def current_task(self, task_idx: Union[None, int]):
 
257
  self._task_idx = task_idx
258
  self.apply(
259
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
@@ -274,5 +285,5 @@ class BertLoRA(BertPreTrainedModel):
274
  for name, param in super().named_parameters(
275
  prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
276
  ):
277
- if "lora" in name:
278
  yield name, param
 
207
  self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
208
  else:
209
  self.bert = bert
210
+ self._num_adaptions = num_adaptions
211
  self._register_lora(num_adaptions)
212
+ self.main_params_trainable = False
213
+ self.current_task = 0
214
+
215
+ @property
216
+ def main_params_trainable(self):
217
+ return self._main_params_trainable
218
+
219
+ @main_params_trainable.setter
220
+ def main_params_trainable(self, val):
221
+ self._main_params_trainable = val
222
  for name, param in super().named_parameters():
223
  if "lora" not in name:
224
+ param.requires_grad_(val)
 
225
 
226
  @classmethod
227
  def from_bert(cls, *args, num_adaptions=1, **kwargs):
 
264
 
265
  @current_task.setter
266
  def current_task(self, task_idx: Union[None, int]):
267
+ assert task_idx is None or 0 <= task_idx < self._num_adaptions
268
  self._task_idx = task_idx
269
  self.apply(
270
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
 
285
  for name, param in super().named_parameters(
286
  prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
287
  ):
288
+ if "lora" in name or self.main_params_trainable:
289
  yield name, param