michael-guenther
committed on
Commit
•
ab448a5
1
Parent(s):
d7c984c
change use_flash_attn and add x_attention attribute
Browse files
Files changed: configuration_clip.py (+8 -4)
configuration_clip.py
CHANGED
@@ -155,7 +155,8 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
155 |
add_projections: bool = False,
|
156 |
projection_dim: int = 768,
|
157 |
logit_scale_init_value: float = 2.6592,
|
158 |
-
|
|
|
159 |
**kwargs,
|
160 |
):
|
161 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
@@ -164,7 +165,8 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
164 |
|
165 |
text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
|
166 |
vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
|
167 |
-
self.
|
|
|
168 |
|
169 |
super().__init__(**kwargs)
|
170 |
|
@@ -261,8 +263,10 @@ class JinaCLIPConfig(PretrainedConfig):
|
|
261 |
'with default values.'
|
262 |
)
|
263 |
|
264 |
-
if
|
265 |
-
text_config.hf_model_config_kwargs.use_flash_attn =
|
|
|
|
|
266 |
|
267 |
self.text_config = JinaCLIPTextConfig(**text_config)
|
268 |
self.vision_config = JinaCLIPVisionConfig(**vision_config)
|
|
|
155 |
add_projections: bool = False,
|
156 |
projection_dim: int = 768,
|
157 |
logit_scale_init_value: float = 2.6592,
|
158 |
+
use_text_flash_attn: Optional[bool] = None,
|
159 |
+
use_vision_xformers: Optional[bool] = None,
|
160 |
**kwargs,
|
161 |
):
|
162 |
# If `_config_dict` exist, we use them for the backward compatibility.
|
|
|
165 |
|
166 |
text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
|
167 |
vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
|
168 |
+
self.use_text_flash_attn = use_text_flash_attn
|
169 |
+
self.use_vision_xformers = use_vision_xformers
|
170 |
|
171 |
super().__init__(**kwargs)
|
172 |
|
|
|
263 |
'with default values.'
|
264 |
)
|
265 |
|
266 |
+
if use_text_flash_attn:
|
267 |
+
text_config.hf_model_config_kwargs.use_flash_attn = use_text_flash_attn
|
268 |
+
if use_vision_xformers:
|
269 |
+
vision_config.x_attention = use_vision_xformers
|
270 |
|
271 |
self.text_config = JinaCLIPTextConfig(**text_config)
|
272 |
self.vision_config = JinaCLIPVisionConfig(**vision_config)
|