updt how flash_attn_triton is imported

#36
by vchiley - opened
Files changed (1) hide show
  1. attention.py +20 -9
attention.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
 
8
  from torch import nn
9
  from .norm import LPLayerNorm
10
 
@@ -87,9 +88,17 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
- from flash_attn import flash_attn_triton
91
  except:
92
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
 
 
 
 
 
 
93
  check_valid_inputs(query, key, value)
94
  if dropout_p:
95
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
@@ -108,7 +117,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
112
  output = attn_output.view(*attn_output.shape[:2], -1)
113
  return (output, None)
114
 
@@ -119,7 +128,7 @@ class MultiheadAttention(nn.Module):
119
  additive bias.
120
  """
121
 
122
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
123
  super().__init__()
124
  self.attn_impl = attn_impl
125
  self.clip_qkv = clip_qkv
@@ -141,10 +150,11 @@ class MultiheadAttention(nn.Module):
141
  self.attn_fn = flash_attn_fn
142
  elif self.attn_impl == 'triton':
143
  self.attn_fn = triton_flash_attn_fn
144
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
145
  elif self.attn_impl == 'torch':
146
  self.attn_fn = scaled_multihead_dot_product_attention
147
- if torch.cuda.is_available():
148
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
149
  else:
150
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -178,7 +188,7 @@ class MultiQueryAttention(nn.Module):
178
  additive bias.
179
  """
180
 
181
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
182
  super().__init__()
183
  self.attn_impl = attn_impl
184
  self.clip_qkv = clip_qkv
@@ -201,10 +211,11 @@ class MultiQueryAttention(nn.Module):
201
  self.attn_fn = flash_attn_fn
202
  elif self.attn_impl == 'triton':
203
  self.attn_fn = triton_flash_attn_fn
204
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
205
  elif self.attn_impl == 'torch':
206
  self.attn_fn = scaled_multihead_dot_product_attention
207
- if torch.cuda.is_available():
208
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
209
  else:
210
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
+ from packaging import version
9
  from torch import nn
10
  from .norm import LPLayerNorm
11
 
 
88
 
89
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
90
  try:
91
+ from .flash_attn_triton import flash_attn_func
92
  except:
93
+ _installed = False
94
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
95
+ _installed = True
96
+ try:
97
+ from flash_attn.flash_attn_triton import flash_attn_func
98
+ except:
99
+ _installed = False
100
+ if not _installed:
101
+ raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
 
117
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
118
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
  return (output, None)
123
 
 
128
  additive bias.
129
  """
130
 
131
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
132
  super().__init__()
133
  self.attn_impl = attn_impl
134
  self.clip_qkv = clip_qkv
 
150
  self.attn_fn = flash_attn_fn
151
  elif self.attn_impl == 'triton':
152
  self.attn_fn = triton_flash_attn_fn
153
+ if verbose:
154
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
155
  elif self.attn_impl == 'torch':
156
  self.attn_fn = scaled_multihead_dot_product_attention
157
+ if torch.cuda.is_available() and verbose:
158
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
159
  else:
160
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
188
  additive bias.
189
  """
190
 
191
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
192
  super().__init__()
193
  self.attn_impl = attn_impl
194
  self.clip_qkv = clip_qkv
 
211
  self.attn_fn = flash_attn_fn
212
  elif self.attn_impl == 'triton':
213
  self.attn_fn = triton_flash_attn_fn
214
+ if verbose:
215
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
216
  elif self.attn_impl == 'torch':
217
  self.attn_fn = scaled_multihead_dot_product_attention
218
+ if torch.cuda.is_available() and verbose:
219
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
220
  else:
221
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')