//
// CPU kernels for a causal dot product: every query attends only to the
// keys and values at its own or earlier positions. Exposed to Python as a
// PyTorch C++ extension.
//

#include <torch/extension.h>


// Accumulate the outer product a * b^T into `out`.
//
// a:   vector of size A
// b:   vector of size B
// out: row-major (A, B) matrix, updated in place (not zeroed first)
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *bi = b;
        for (int j=0; j<B; j++) {
            *out += (*a) * (*bi);
            out++;
            bi++;
        }
        a++;
    }
}

// Compute the vector-matrix product out = v^T * m, overwriting `out`.
//
// v:   vector of size A
// m:   row-major (A, B) matrix
// out: vector of size B
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
    // Clear the output buffer before accumulating.
    for (int i=0; i<B; i++) {
        out[i] = 0;
    }

    for (int i=0; i<A; i++) {
        float *oi = out;
        for (int j=0; j<B; j++) {
            *oi += (*v) * (*m);
            oi++;
            m++;
        }
        v++;
    }
}

// Compute out = m * v, i.e. out[i] is the dot product of v with row i of m.
//
// v:   vector of size B
// m:   row-major (A, B) matrix
// out: vector of size A, overwritten
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *vi = v;
        float s = 0;
        for (int j=0; j<B; j++) {
            s += (*vi) * (*m);
            vi++;
            m++;
        }
        *out = s;
        out++;
    }
}
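
// A tiny worked example of the three helpers (ours, for illustration). With
// a = [1, 2], b = [3, 4] and the row-major matrix m = [[3, 4], [6, 8]]:
//
//   vvt_dot(a, b, out, 2, 2)  accumulates a b^T = [[3, 4], [6, 8]] into out;
//   vm_dot(a, m, out, 2, 2)   writes a^T m = [15, 20];
//   vmt_dot(b, m, out, 2, 2)  writes m b   = [25, 50].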

// Compute the causal dot product
//
//     product_l = sum_{s <= l} (q_l . k_s) v_s
//
// by keeping a running (E, M) sum of the key-value outer products, so the
// whole sequence costs O(L E M) instead of the O(L^2 (E + M)) of explicitly
// masked attention.
void causal_dot_product(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    // Shapes: queries/keys are (N, H, L, E); values/product are (N, H, L, M).
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments.
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    // Every (sample, head) pair is independent, so parallelize over both.
    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Running sum of k_s v_s^T for all s <= l.
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();
            for (int l=0; l<L; l++) {
                // kv += k_l v_l^T
                vvt_dot(&ka[n][h][l][0], &va[n][h][l][0], kvp, E, M);
                // product_l = q_l^T kv
                vm_dot(&qa[n][h][l][0], kvp, &pa[n][h][l][0], E, M);
            }
        }
    }
}
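
// A naive O(L^2) reference for the function above (a sketch of ours, not
// part of the original module): the same result written as explicitly
// masked attention, product_l = sum_{s <= l} (q_l . k_s) v_s. Useful only
// for testing the linear-time accumulation.
void causal_dot_product_naive(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            for (int l=0; l<L; l++) {
                for (int m=0; m<M; m++) {
                    float s = 0;
                    // Attend only to positions at or before l.
                    for (int j=0; j<=l; j++) {
                        float qk = 0;
                        for (int e=0; e<E; e++) {
                            qk += qa[n][h][l][e] * ka[n][h][j][e];
                        }
                        s += qk * va[n][h][j][m];
                    }
                    pa[n][h][l][m] = s;
                }
            }
        }
    }
}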

// Compute the gradients of causal_dot_product with respect to the queries,
// keys and values. Writing g_l for grad_out_l:
//
//     grad_q_l = (sum_{s <= l} k_s v_s^T) g_l       (forward scan)
//     grad_k_l = (sum_{s >= l} q_s g_s^T) v_l       (backward scan)
//     grad_v_l = (sum_{s >= l} q_s g_s^T)^T k_l     (backward scan)
void causal_dot_backward(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    const torch::Tensor grad_out,
    torch::Tensor grad_queries,
    torch::Tensor grad_keys,
    torch::Tensor grad_values
) {
    // Shapes: queries/keys are (N, H, L, E); values/grad_out are (N, H, L, M).
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments.
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto ga = grad_out.accessor<float, 4>();
    auto gqa = grad_queries.accessor<float, 4>();
    auto gka = grad_keys.accessor<float, 4>();
    auto gva = grad_values.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();

            // Forward scan: kv = sum_{s <= l} k_s v_s^T gives grad_queries.
            for (int l=0; l<L; l++) {
                // kv += k_l v_l^T
                vvt_dot(&ka[n][h][l][0], &va[n][h][l][0], kvp, E, M);
                // grad_q_l = kv g_l
                vmt_dot(&ga[n][h][l][0], kvp, &gqa[n][h][l][0], E, M);
            }

            // Backward scan: kv = sum_{s >= l} q_s g_s^T gives grad_keys
            // and grad_values.
            kv.zero_();
            for (int l=L-1; l>=0; l--) {
                // kv += q_l g_l^T
                vvt_dot(&qa[n][h][l][0], &ga[n][h][l][0], kvp, E, M);
                // grad_k_l = kv v_l
                vmt_dot(&va[n][h][l][0], kvp, &gka[n][h][l][0], E, M);
                // grad_v_l = kv^T k_l
                vm_dot(&ka[n][h][l][0], kvp, &gva[n][h][l][0], E, M);
            }
        }
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "causal_dot_product",
        &causal_dot_product,
        "Compute the weighted sum of the values, attending only to previous "
        "positions."
    );
    m.def(
        "causal_dot_backward",
        &causal_dot_backward,
        "Compute the gradients of the queries, keys and values given the "
        "gradient of causal_dot_product."
    );
}
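
// A minimal self-test (ours; the CAUSAL_PRODUCT_TEST_MAIN guard and the
// naive reference above are assumptions, not part of the original
// extension). It compares the linear-time forward pass against the O(L^2)
// reference; build this file with the macro defined, using the extension's
// usual toolchain, to run it.
#ifdef CAUSAL_PRODUCT_TEST_MAIN
#include <iostream>

int main() {
    int N = 2, H = 3, L = 16, E = 8, M = 10;
    auto q = torch::randn({N, H, L, E});
    auto k = torch::randn({N, H, L, E});
    auto v = torch::randn({N, H, L, M});
    auto fast = torch::zeros({N, H, L, M});
    auto naive = torch::zeros({N, H, L, M});

    causal_dot_product(q, k, v, fast);
    causal_dot_product_naive(q, k, v, naive);

    // The two should agree up to floating point accumulation error.
    std::cout << "max abs difference: "
              << (fast - naive).abs().max().item<float>() << std::endl;
    return 0;
}
#endif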