Commit bbc6d7c
Parent(s): 7d2a362

feat: disable flash attn if not supported CUDA version or device capability

modeling_clip.py CHANGED (+19 -0)
@@ -144,6 +144,25 @@ def _resolve_attention_libs(config: JinaCLIPConfig):
                     'for installation instructions, disabling'
                 )
                 return False
+            major, minor, *_ = torch.version.cuda.split('.')
+            major, minor = int(major), int(minor)
+            if major < 11 or (major == 11 and minor < 7):
+                warnings.warn(
+                    'Flash attention requires CUDA>=11.7. Found version '
+                    f'{major}.{minor}, disabling'
+                )
+                return False
+            capability = torch.cuda.get_device_capability()
+            major, *_ = capability
+            major = int(major)
+            if major < 8:
+                device_name = torch.cuda.get_device_properties(0).name
+                warnings.warn(
+                    'Flash attention requires device capability>=8.0 (NVIDIA Ampere, '
+                    f'Hopper or ADA). Found device {device_name} with capability '
+                    f'{capability}, disabling'
+                )
+                return False
             return True
         return False
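For context, the gating logic added here can be read as a standalone check. The sketch below is only an illustration, not part of modeling_clip.py: the helper name flash_attention_supported is hypothetical, and it simply restates the two conditions this commit introduces (CUDA >= 11.7 and compute capability >= 8.0) using standard torch APIs.

import warnings

import torch


def flash_attention_supported() -> bool:
    # Hypothetical helper; mirrors the checks added in this commit.
    # Flash attention only runs on CUDA devices.
    if not torch.cuda.is_available():
        return False

    # torch.version.cuda is None on CPU-only builds, otherwise a string such as '12.1'.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        return False
    major, minor, *_ = (int(part) for part in cuda_version.split('.'))
    if major < 11 or (major == 11 and minor < 7):
        warnings.warn(f'Flash attention requires CUDA>=11.7, found {major}.{minor}')
        return False

    # Compute capability 8.0+ corresponds to Ampere and newer architectures.
    capability_major, _ = torch.cuda.get_device_capability()
    if capability_major < 8:
        device_name = torch.cuda.get_device_properties(0).name
        warnings.warn(
            f'Flash attention requires compute capability>=8.0, found {device_name}'
        )
        return False

    return True

In the committed diff these checks sit inside _resolve_attention_libs, after the existing branch that warns and disables flash attention when the package is not installed.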