// // Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ // Written by Angelos Katharopoulos , // Apoorv Vyas // #include /** * Compute a*b^T and save it into out. * * a \in R^A * b \in R^B */ inline void vvt_dot(float *a, float *b, float *out, int A, int B) { for (int i=0; i(); auto ka = keys.accessor(); auto va = values.accessor(); auto pa = product.accessor(); #pragma omp parallel for collapse(2) for (int n=0; n(); for (int l=0; l(); auto ka = keys.accessor(); auto va = values.accessor(); auto ga = grad_out.accessor(); auto gqa = grad_queries.accessor(); auto gka = grad_keys.accessor(); auto gva = grad_values.accessor(); #pragma omp parallel for collapse(2) for (int n=0; n(); // Compute the gradient wrt the queries for (int l=0; l=0; l--) { vvt_dot( &qa[n][h][l][0], &ga[n][h][l][0], kvp, E, M ); vmt_dot( &va[n][h][l][0], kvp, &gka[n][h][l][0], E, M ); vm_dot( &ka[n][h][l][0], kvp, &gva[n][h][l][0], E, M ); } } } } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "causal_dot_product", &causal_dot_product, "Compute the weighted sum of values but attending only to previous " "values." ); m.def( "causal_dot_backward", &causal_dot_backward, "Compute the gradient of queries, keys and values given the gradient " "of causal_dot_product." ); }