feat: make num of loras part of the config
configuration_bert.py  +3 -1
modeling_lora.py       +3 -3
configuration_bert.py

@@ -86,6 +86,7 @@ class JinaBertConfig(PretrainedConfig):
         use_qk_norm=True,
         emb_pooler=None,
         classifier_dropout=None,
+        num_loras=5,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs
@@ -118,4 +119,5 @@ class JinaBertConfig(PretrainedConfig):
         self.use_flash_attn = use_flash_attn
         self.use_qk_norm = use_qk_norm
         self.emb_pooler = emb_pooler
-        self.classifier_dropout = classifier_dropout
+        self.classifier_dropout = classifier_dropout
+        self.num_loras = num_loras
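For context, a minimal sketch of how the new field behaves from the caller's side. This is an illustration, not part of the commit; the local import path is an assumption (in a published repo the class would typically be loaded via `AutoConfig` with `trust_remote_code=True`):

```python
from configuration_bert import JinaBertConfig  # local import path assumed

# num_loras is now a first-class config field, so it is serialized to
# config.json and restored by from_pretrained like any other setting.
config = JinaBertConfig(num_loras=3)
print(config.num_loras)            # 3

# Omitting it falls back to the new default of 5.
print(JinaBertConfig().num_loras)  # 5
```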
modeling_lora.py

@@ -201,14 +201,14 @@ class LoRAParametrization(nn.Module):
 
 
 class BertLoRA(BertPreTrainedModel):
-    def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
+    def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
         super().__init__(config)
         if bert is None:
             self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         else:
             self.bert = bert
-        self._num_adaptions = num_adaptions
-        self._register_lora(num_adaptions)
+        self._num_adaptions = config.num_loras
+        self._register_lora(self._num_adaptions)
         self.main_params_trainable = False
         self.current_task = 0
 
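On the model side, the adapter count now travels with the config rather than being passed to the constructor, so a reloaded checkpoint rebuilds the same number of LoRA adapters without extra arguments. A hedged sketch of the wiring (local import paths assumed; `_num_adaptions` is an internal attribute, shown only to illustrate where the value ends up):

```python
from configuration_bert import JinaBertConfig  # local import paths assumed
from modeling_lora import BertLoRA

config = JinaBertConfig(num_loras=3)

# Before this change the count was a constructor argument
# (num_adaptions); now it is read straight from the config.
model = BertLoRA(config)
assert model._num_adaptions == config.num_loras  # 3 adapters registered
```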