import torch

# Import the compiled CUDA kernels if the extension is available; otherwise
# fall back to None so the failure only surfaces when the kernels are used.
try:
    from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
    from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
except ImportError as e:
    print(e)
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""
    # Dispatch tables mapping the device type to the compiled kernel
    dot = {
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Allocate the output tensor: one M-dimensional vector per query
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)

        # Dispatch to the kernel for this device type; it fills `product`
        # in place
        CausalDotProduct.dot[device.type](
            Q.data,
            K.data,
            V.data,
            product
        )

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Compute all three gradients in place with the device-specific kernel
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data,
            K.data,
            V.data,
            grad_out,
            grad_Q,
            grad_K,
            grad_V
        )

        return grad_Q, grad_K, grad_V


# Convenience alias so the Function can be called like a regular function
causal_dot_product = CausalDotProduct.apply
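

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module.  Based on the
# docstring of CausalDotProduct, the kernel is assumed to compute, for every
# query position i, the sum over positions j <= i of (q_i . k_j) * v_j.  The
# helper below is a naive O(L^2) PyTorch version of that assumed computation
# (the name is hypothetical), followed by a guarded smoke test whose shapes
# mirror the forward pass above: Q, K of shape (N, H, L, E) and V of shape
# (N, H, L, M).  The test only runs when the causal_attention_cuda extension
# was built and a GPU is available.
# ---------------------------------------------------------------------------
def _naive_causal_dot_product(Q, K, V):
    """Naive reference: out[n, h, i] = sum_{j <= i} (Q[n, h, i] . K[n, h, j]) * V[n, h, j]."""
    # Pairwise dot products between queries and keys: (N, H, L, L)
    scores = torch.einsum("nhie,nhje->nhij", Q, K)
    # Keep only positions j <= i so each query attends to previous values only
    L = Q.shape[2]
    causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=Q.device))
    scores = scores * causal_mask
    # Weighted sum of the values: (N, H, L, M)
    return torch.einsum("nhij,nhjm->nhim", scores, V)


if (__name__ == "__main__" and causal_dot_product_cuda is not None
        and torch.cuda.is_available()):
    N, H, L, E, M = 2, 4, 128, 32, 64
    Q = torch.rand(N, H, L, E, device="cuda", requires_grad=True)
    K = torch.rand(N, H, L, E, device="cuda", requires_grad=True)
    V = torch.rand(N, H, L, M, device="cuda", requires_grad=True)

    out = causal_dot_product(Q, K, V)  # (N, H, L, M)
    out.sum().backward()  # exercises CausalDotProduct.backward

    # Compare against the naive reference defined above
    with torch.no_grad():
        reference = _naive_causal_dot_product(Q, K, V)
        print("max abs difference:", (out - reference).abs().max().item())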