Markus28 commited on
Commit
20706dd
·
1 Parent(s): b641603

fix: fix LoRA implementation

Browse files
Files changed (1) hide show
  1. modeling_lora.py +2 -1
modeling_lora.py CHANGED
@@ -210,6 +210,7 @@ class BertLoRA(BertPreTrainedModel):
210
  self._num_adaptions = config.num_loras
211
  self._register_lora(self._num_adaptions)
212
  self.main_params_trainable = False
 
213
  self.current_task = 0
214
 
215
  @property
@@ -265,7 +266,7 @@ class BertLoRA(BertPreTrainedModel):
265
  @current_task.setter
266
  def current_task(self, task_idx: Union[None, int]):
267
  assert task_idx is None or 0 <= task_idx < self._num_adaptions
268
- if self._task_idx != task_idx
269
  self._task_idx = task_idx
270
  self.apply(
271
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
 
210
  self._num_adaptions = config.num_loras
211
  self._register_lora(self._num_adaptions)
212
  self.main_params_trainable = False
213
+ self._task_idx = None
214
  self.current_task = 0
215
 
216
  @property
 
266
  @current_task.setter
267
  def current_task(self, task_idx: Union[None, int]):
268
  assert task_idx is None or 0 <= task_idx < self._num_adaptions
269
+ if self._task_idx != task_idx:
270
  self._task_idx = task_idx
271
  self.apply(
272
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)