support use_flash_attn in from_pretrained

#2
Files changed (2) hide show
  1. configuration_clip.py +4 -0
  2. modeling_clip.py +5 -0
configuration_clip.py CHANGED
@@ -155,6 +155,8 @@ class JinaCLIPConfig(PretrainedConfig):
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
 
 
158
  **kwargs,
159
  ):
160
  # If `_config_dict` exist, we use them for the backward compatibility.
@@ -163,6 +165,8 @@ class JinaCLIPConfig(PretrainedConfig):
163
 
164
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
165
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
 
 
166
 
167
  super().__init__(**kwargs)
168
 
 
155
  add_projections: bool = False,
156
  projection_dim: int = 768,
157
  logit_scale_init_value: float = 2.6592,
158
+ use_text_flash_attn: Optional[bool] = None,
159
+ use_vision_xformers: Optional[bool] = None,
160
  **kwargs,
161
  ):
162
  # If `_config_dict` exist, we use them for the backward compatibility.
 
165
 
166
  text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
167
  vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
168
+ self.use_text_flash_attn = use_text_flash_attn
169
+ self.use_vision_xformers = use_vision_xformers
170
 
171
  super().__init__(**kwargs)
172
 
modeling_clip.py CHANGED
@@ -210,6 +210,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
210
  text_config = config.text_config
211
  vision_config = config.vision_config
212
 
 
 
 
 
 
213
  self.add_projections = config.add_projections
214
  self.projection_dim = config.projection_dim
215
  self.text_embed_dim = text_config.embed_dim
 
210
  text_config = config.text_config
211
  vision_config = config.vision_config
212
 
213
+ if config.use_text_flash_attn is not None:
214
+ text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
215
+ if config.use_vision_xformers is not None:
216
+ vision_config.x_attention = config.use_vision_xformers
217
+
218
  self.add_projections = config.add_projections
219
  self.projection_dim = config.projection_dim
220
  self.text_embed_dim = text_config.embed_dim