fix: fix LoRA implementation
Browse files- modeling_lora.py +2 -1
modeling_lora.py
CHANGED
@@ -210,6 +210,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
210 |
self._num_adaptions = config.num_loras
|
211 |
self._register_lora(self._num_adaptions)
|
212 |
self.main_params_trainable = False
|
|
|
213 |
self.current_task = 0
|
214 |
|
215 |
@property
|
@@ -265,7 +266,7 @@ class BertLoRA(BertPreTrainedModel):
|
|
265 |
@current_task.setter
|
266 |
def current_task(self, task_idx: Union[None, int]):
|
267 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
268 |
-
if self._task_idx != task_idx
|
269 |
self._task_idx = task_idx
|
270 |
self.apply(
|
271 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
|
|
210 |
self._num_adaptions = config.num_loras
|
211 |
self._register_lora(self._num_adaptions)
|
212 |
self.main_params_trainable = False
|
213 |
+
self._task_idx = None
|
214 |
self.current_task = 0
|
215 |
|
216 |
@property
|
|
|
266 |
@current_task.setter
|
267 |
def current_task(self, task_idx: Union[None, int]):
|
268 |
assert task_idx is None or 0 <= task_idx < self._num_adaptions
|
269 |
+
if self._task_idx != task_idx:
|
270 |
self._task_idx = task_idx
|
271 |
self.apply(
|
272 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|