feat: added docstrings
modeling_lora.py  +30 -2
@@ -65,6 +65,8 @@ class LoRAParametrization(nn.Module):
         fan_in_fan_out = layer_type == "embedding"
         self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
 
+        # For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
+        # TODO: Ensure that the initialization here is correct
         if layer_type == "linear":
             self.lora_A = nn.Parameter(
                 initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
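Note on the TODO above: in the reference implementation that the comment links to (microsoft/LoRA), lora_A is drawn with a Kaiming-uniform initializer and lora_B is zero-initialized, so the low-rank update B @ A is exactly zero before training. Below is a minimal stand-alone sketch of that scheme in plain PyTorch; the shapes are illustrative and it does not use this repository's initialized_weights helper.

import math
import torch
import torch.nn as nn

rank, fan_in, fan_out = 4, 768, 768   # illustrative shapes, mirroring the names in the hunk above

# Reference-style LoRA init: A ~ Kaiming-uniform, B = 0, so the adapter starts as a no-op.
lora_A = nn.Parameter(torch.empty(rank, fan_in))
lora_B = nn.Parameter(torch.zeros(fan_out, rank))
nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))

delta_w = lora_B @ lora_A             # (fan_out, fan_in); all zeros before any training step
assert torch.all(delta_w == 0)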
@@ -225,7 +227,15 @@ class BertLoRA(BertPreTrainedModel):
         return self._main_params_trainable
 
     @main_params_trainable.setter
-    def main_params_trainable(self, val):
+    def main_params_trainable(self, val: bool):
+        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
+
+        This method sets the `requires_grad_` attribute of the main weights
+        and controls which parameters are returned in `self.parameters()`.
+
+        :param val: Whether or not to make the parameters trainable.
+        :return: None
+        """
         self._main_params_trainable = val
         for name, param in super().named_parameters():
             if "lora" not in name:
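As the new docstring states, this setter both flips requires_grad on the non-LoRA weights and filters what self.parameters() yields. A hedged usage sketch follows; the import path and checkpoint name are assumptions, not taken from this commit.

import torch
from modeling_lora import BertLoRA   # assumption: BertLoRA is defined in this repository's modeling_lora.py

model = BertLoRA.from_pretrained("bert-base-uncased")   # placeholder checkpoint name

# Freeze the main BERT weights so that only the LoRA parameters are trained.
model.main_params_trainable = False

# Per the docstring, parameters() now only yields LoRA weights, so the optimizer sees nothing else.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)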
@@ -259,7 +269,13 @@ class BertLoRA(BertPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-
+        """
+        TODO: choose between from_bert and super().from_pretrained
+
+        We want to be able to load both a pretrained BertModel, and a trained
+        BertLoRA via this method. To this end, we need to check which of these
+        models we are expected to load.
+        """
         return cls.from_bert(pretrained_model_name_or_path)
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
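The TODO in this docstring is about telling a plain BertModel checkpoint apart from one saved by BertLoRA itself, and then choosing between cls.from_bert and super().from_pretrained. Purely as an illustration of that idea (not the resolution adopted in this commit), the architectures field that Hugging Face writes into config.json could serve as the discriminator:

from transformers import AutoConfig

def looks_like_lora_checkpoint(path_or_name: str) -> bool:
    """Illustrative heuristic: a checkpoint saved from BertLoRA records that class name in config.json."""
    config = AutoConfig.from_pretrained(path_or_name)
    return "BertLoRA" in (config.architectures or [])

Inside from_pretrained, a True result would suggest delegating to super().from_pretrained(...) because the LoRA weights are already present, while a False result would fall back to cls.from_bert(...) as the method currently does.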
@@ -275,14 +291,26 @@ class BertLoRA(BertPreTrainedModel):
 
     @property
     def current_task(self):
+        """ Which LoRA is currently selected
+        :return: Integer or None (when LoRA is disabled)
+        """
         return self._task_idx
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        """Set the LoRA that is to be used.
+
+        The LoRA is specified by `task_idx`, which may be an integer >= 0,
+        indexing the available LoRAs. If it is None, no LoRA is used.
+
+        :param task_idx: Which LoRA to use
+        :return:
+        """
         if self._is_merged:
             raise Exception('LoRA has been merged, cannot select new task')
         assert task_idx is None or 0 <= task_idx < self._num_adaptions
         if self._task_idx != task_idx:
+            # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
             self.apply(
                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
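Finally, the current_task setter is the runtime switch between the registered LoRAs (or back to plain BERT with None). A small usage sketch under stated assumptions: the checkpoint name and token ids are placeholders, the model is assumed to have been registered with num_adaptions >= 2 so that task index 1 exists, and BertLoRA is assumed to forward like the underlying BertModel.

import torch
from modeling_lora import BertLoRA   # assumption: BertLoRA lives in this repository's modeling_lora.py

model = BertLoRA.from_pretrained("bert-base-uncased")             # placeholder checkpoint
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])    # toy token ids

model.current_task = 0            # route the forward pass through the first LoRA
out_task0 = model(input_ids=input_ids)

model.current_task = 1            # switch adapters; must satisfy 0 <= task_idx < num_adaptions
out_task1 = model(input_ids=input_ids)

model.current_task = None         # disable LoRA and fall back to the unmodified BERT weights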