Markus28 commited on
Commit
0ff7c3d
·
1 Parent(s): faa9951

feat: use property in LoRA parametrization

Browse files
Files changed (1) hide show
  1. modeling_lora.py +8 -3
modeling_lora.py CHANGED
@@ -116,8 +116,13 @@ class LoRAParametrization(nn.Module):
116
  def forward(self, X):
117
  return self.forward_fn(X)
118
 
119
- def select_task(self, task=None):
120
- self.current_task = task
 
 
 
 
 
121
  if task is None:
122
  self.forward_fn = lambda x: x
123
  else:
@@ -192,7 +197,7 @@ class LoRAParametrization(nn.Module):
192
  @classmethod
193
  def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
194
  if isinstance(layer, LoRAParametrization):
195
- layer.select_task(task_idx)
196
 
197
 
198
  class BertLoRA(BertPreTrainedModel):
 
116
  def forward(self, X):
117
  return self.forward_fn(X)
118
 
119
+ @property
120
+ def current_task(self):
121
+ return self._current_task
122
+
123
+ @current_task.setter
124
+ def current_task(self, task: Union[None, int]):
125
+ self._current_task = task
126
  if task is None:
127
  self.forward_fn = lambda x: x
128
  else:
 
197
  @classmethod
198
  def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
199
  if isinstance(layer, LoRAParametrization):
200
+ layer.current_task = task_idx
201
 
202
 
203
  class BertLoRA(BertPreTrainedModel):