Markus28 committed
Commit dae5c58 · 1 Parent(s): 702e6c9

feat: added docstrings

Files changed (1)
  1. modeling_lora.py +30 -2
modeling_lora.py CHANGED
@@ -65,6 +65,8 @@ class LoRAParametrization(nn.Module):
         fan_in_fan_out = layer_type == "embedding"
         self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
 
+        # For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
+        # TODO: Ensure that the initialization here is correct
         if layer_type == "linear":
             self.lora_A = nn.Parameter(
                 initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
 
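Side note on the TODO in the hunk above: the linked reference implementation (microsoft/LoRA) draws A from a Kaiming-uniform distribution and sets B to zeros, so the low-rank update B @ A is exactly zero before training. Below is a minimal standalone sketch of that scheme; the shapes (rank, fan_in, fan_out) are placeholders for illustration, not this module's actual state.

import math
import torch
import torch.nn as nn

rank, fan_in, fan_out = 4, 768, 768  # illustrative sizes only

# A: Kaiming-uniform draw, B: zeros, so the LoRA delta B @ A starts at zero.
lora_A = nn.Parameter(torch.empty(rank, fan_in))
lora_B = nn.Parameter(torch.zeros(fan_out, rank))
nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))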
@@ -225,7 +227,15 @@ class BertLoRA(BertPreTrainedModel):
         return self._main_params_trainable
 
     @main_params_trainable.setter
-    def main_params_trainable(self, val):
+    def main_params_trainable(self, val: bool):
+        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
+
+        This method sets the `requires_grad_` attribute of the main weights
+        and controls which parameters are returned in `self.parameters()`.
+
+        :param val: Whether or not to make the parameters trainable.
+        :return: None
+        """
         self._main_params_trainable = val
         for name, param in super().named_parameters():
             if "lora" not in name:
 
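The setter above walks the named parameters and, in the lines just past the end of this hunk, presumably toggles requires_grad on everything that is not a LoRA weight. A standalone sketch of that pattern for an arbitrary nn.Module, assuming that is indeed what the cut-off loop body does:

import torch.nn as nn

def set_main_params_trainable(module: nn.Module, val: bool) -> None:
    # Toggle requires_grad on every parameter whose name does not contain "lora".
    for name, param in module.named_parameters():
        if "lora" not in name:
            param.requires_grad_(val)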
@@ -259,7 +269,13 @@ class BertLoRA(BertPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        # TODO: choose between from_bert and super().from_pretrained
+        """
+        TODO: choose between from_bert and super().from_pretrained
+
+        We want to be able to load both a pretrained BertModel, and a trained
+        BertLoRA via this method. To this end, we need to check which of these
+        models we are expected to load.
+        """
         return cls.from_bert(pretrained_model_name_or_path)
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
 
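One way the TODO in this hunk could be resolved is to inspect the checkpoint's config before deciding between `cls.from_bert` and `super().from_pretrained`. The helper below is only a sketch of that idea; the class-name check and the reliance on `config.architectures` are assumptions, not part of this commit.

from transformers import AutoConfig

def _looks_like_bert_lora_checkpoint(name_or_path: str) -> bool:
    # A checkpoint saved from BertLoRA would normally list that class in
    # `config.architectures`; a plain BERT checkpoint would not. Whether this
    # distinction holds for every checkpoint of interest is an assumption.
    config = AutoConfig.from_pretrained(name_or_path)
    return "BertLoRA" in (config.architectures or [])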
@@ -275,14 +291,26 @@ class BertLoRA(BertPreTrainedModel):
 
     @property
     def current_task(self):
+        """ Which LoRA is currently selected
+        :return: Integer or None (when LoRA is disabled)
+        """
         return self._task_idx
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        """Set the LoRA that is to be used.
+
+        The LoRA is specified by `task_idx`, which may be an integer >= 0,
+        indexing the available LoRAs. If it is None, no LoRA is used.
+
+        :param task_idx: Which LoRA to use
+        :return:
+        """
         if self._is_merged:
             raise Exception('LoRA has been merged, cannot select new task')
         assert task_idx is None or 0 <= task_idx < self._num_adaptions
         if self._task_idx != task_idx:
+            # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
             self.apply(
                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
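To illustrate how the `current_task` property documented above is meant to be used, here is a hypothetical session. The checkpoint name is only illustrative, and it is assumed that loading registers at least one LoRA adaptation; neither detail is taken from this commit.

from modeling_lora import BertLoRA  # assumes modeling_lora.py is importable

# from_pretrained currently delegates to from_bert, as shown in the hunk above.
model = BertLoRA.from_pretrained("bert-base-uncased")

model.current_task = 0     # route the forward pass through the first LoRA
model.current_task = None  # disable LoRA and use only the main BERT weights

# Selecting an index >= model._num_adaptions trips the assert in the setter,
# and selecting any task after the LoRA weights have been merged raises the
# "LoRA has been merged" exception shown above.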