jupyterjazz committed · Commit 3703946 · 1 Parent(s): 851aaca
refactor: stuff
Signed-off-by: jupyterjazz <[email protected]>
Files changed:
- modeling_lora.py (+37 −26)
modeling_lora.py
CHANGED
@@ -14,6 +14,9 @@ from transformers import PretrainedConfig
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel
 
 
+LORA_NO_UPDATE = '__lora_no_update__'
+
+
 def initialized_weights(
     shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
@@ -214,7 +217,17 @@ class XLMRobertaLoRA(XLMRobertaModel):
     ):
         super().__init__(config)
 
-        self.
+        self._lora_adaptations = config.lora_adaptations
+        if (
+            not isinstance(self._lora_adaptations, list)
+            or len(self._lora_adaptations) < 1
+        ):
+            raise ValueError(
+                f'`lora_adaptations` must be a list and contain at least one element'
+            )
+        self._adaptation_map = {
+            name: idx for idx, name in enumerate(self._lora_adaptations)
+        }
         self._rank = config.lora_rank
         self._dropout_p = config.lora_dropout_p
         self._alpha = config.lora_alpha
@@ -294,14 +307,20 @@ class XLMRobertaLoRA(XLMRobertaModel):
         return self._task_idx
 
     @current_task.setter
-    def current_task(self,
+    def current_task(self, task_name: Union[None, str]):
         """Set the LoRA that is to be used.
         The LoRA is specified by `task_idx`, which may be an integer >= 0,
         indexing the available LoRAs. If it is None, no LoRA is used.
-        :param
+        :param task_name: Which LoRA to use
         :return:
         """
-
+        if task_name and task_name not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task_name}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, set `task` to `None` if you want to disable LoRA."
+            )
+        task_idx = self._adaptation_map[task_name] if task_name else None
         if self._task_idx != task_idx:
             # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
@@ -309,9 +328,9 @@ class XLMRobertaLoRA(XLMRobertaModel):
                 partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
             )
 
-    def forward(self, *args,
-        if
-            self.current_task =
+    def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
+        if task != LORA_NO_UPDATE:
+            self.current_task = task
         return super().forward(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -331,35 +350,27 @@ class XLMRobertaLoRA(XLMRobertaModel):
     def encode(
         self,
         *args,
-        task:
+        task: Union[str, None] = LORA_NO_UPDATE,
        **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
 
-        task(`str`, *optional*, defaults to
-            Specifies the task for which the encoding is intended. This
-
-
-
-            model
+        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
+            Specifies the task for which the encoding is intended. This parameter controls the
+            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
+            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
+            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
+            adapters are disabled, and the model reverts to its original, general-purpose weights.
+            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
         """
-
-
-        if task:
-            if task in self.config.lora_adaptations:
-                lora_adapter_num = self.config.lora_adaptations.index(task)
-            else:
-                raise ValueError(
-                    f"Unsupported task '{task}'. "
-                    f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                )
-        else:
+        if task != LORA_NO_UPDATE:
+            if not task:
                 warnings.warn(
                     f"Task-specific embeddings are disabled. To enable, specify the `task` "
                     f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
                     category=UserWarning,
                 )
-
+            self.current_task = task
 
         return super().encode(*args, **kwargs)
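For orientation, a minimal usage sketch of the interface after this refactor. It is not part of the commit: the `model` instance, the `sentences` list, and the adaptation name "retrieval" are placeholders for illustration, and it assumes the inherited `encode` accepts a list of sentences via `*args`; valid adaptation names come from `config.lora_adaptations`.

# Hypothetical usage sketch, not part of the commit. Assumes `model` is an
# already-instantiated XLMRobertaLoRA whose config.lora_adaptations includes
# the placeholder name "retrieval".
sentences = ["A minimal example sentence."]

# Activate a specific LoRA adapter for this call; the `current_task` setter
# rejects names that are not listed in config.lora_adaptations.
task_embeddings = model.encode(sentences, task="retrieval")

# Passing task=None disables all LoRA adapters (base weights only) and
# triggers the UserWarning about task-specific embeddings being disabled.
plain_embeddings = model.encode(sentences, task=None)

# The default, task=LORA_NO_UPDATE, leaves the currently selected adapter
# untouched; the same sentinel default applies to forward().
unchanged_embeddings = model.encode(sentences)

# The adapter can also be switched directly through the property setter.
model.current_task = "retrieval"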