gmastrapas committed on
Commit bbc6d7c · Parent: 7d2a362

feat: disable flash attn on unsupported CUDA version or device capability

Files changed (1)
  1. modeling_clip.py +19 -0
modeling_clip.py CHANGED
@@ -144,6 +144,25 @@ def _resolve_attention_libs(config: JinaCLIPConfig):
                     'for installation instructions, disabling'
                 )
                 return False
+            major, minor, *_ = torch.version.cuda.split('.')
+            major, minor = int(major), int(minor)
+            if major < 11 or (major == 11 and minor < 7):
+                warnings.warn(
+                    'Flash attention requires CUDA>=11.7. Found version '
+                    f'{major}.{minor}, disabling'
+                )
+                return False
+            capability = torch.cuda.get_device_capability()
+            major, *_ = capability
+            major = int(major)
+            if major < 8:
+                device_name = torch.cuda.get_device_properties(0).name
+                warnings.warn(
+                    'Flash attention requires device capability>=8.0 (NVIDIA Ampere, '
+                    f'Hopper or ADA). Found device {device_name} with capability '
+                    f'{capability}, disabling'
+                )
+                return False
             return True
         return False
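
For reference, a minimal standalone sketch of the same gating logic added in this commit. The function name flash_attn_supported is hypothetical, and the sketch assumes CUDA availability has already been verified before it is called (torch.version.cuda is None on CPU-only builds, so splitting it would fail otherwise).

import warnings

import torch


def flash_attn_supported() -> bool:
    """Return True only if both the CUDA toolkit and the GPU can run flash attention."""
    # torch.version.cuda is a string like '11.8'; it is None on CPU-only builds,
    # so callers should check torch.cuda.is_available() before calling this.
    major, minor, *_ = (int(v) for v in torch.version.cuda.split('.'))
    if (major, minor) < (11, 7):
        warnings.warn(f'Flash attention requires CUDA>=11.7, found {major}.{minor}')
        return False
    # Compute capability 8.0+ corresponds to Ampere and newer architectures.
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        device_name = torch.cuda.get_device_properties(0).name
        warnings.warn(
            'Flash attention requires device capability>=8.0, found '
            f'{device_name} with capability {capability}'
        )
        return False
    return True

The tuple comparison (major, minor) < (11, 7) is equivalent to the explicit major < 11 or (major == 11 and minor < 7) check in the diff above.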