residual_rms / ext-torch /torch_binding.h
drbh
feat: impl residual rms kernel repo
79aac9d
raw
history blame
329 Bytes
#pragma once
#include <optional>
#include <torch/library.h>
#include <vector>
// Declaration only; the definition lives outside this header.
// Takes `x` by non-const reference and returns void, so the callee is allowed
// to modify the tensor handle it is given.
// NOTE(review): the name suggests an in-place increment of the tensor's
// elements — confirm against the implementation.
void increment(torch::Tensor &x);
// Declaration only; the definition lives outside this header.
// Returns void and takes all four tensors by non-const reference, so results
// can only be delivered through the tensor arguments.
//
// Parameters (types are from the declaration; meanings are NOT established
// here and must be confirmed against the implementation):
//   input, residual, weight, output - tensor arguments; presumably `output`
//       receives the result of a residual + RMS-normalisation step computed
//       from `input`, `residual` and `weight` (TODO confirm, including
//       whether `residual` is also written).
//   epsilon     - double; presumably the numerical-stability term of the
//                 RMS computation (TODO confirm).
//   scale       - double; purpose not visible from this header.
//   mode        - int64_t selector; valid values not visible from this header.
//   num_threads - int64_t; presumably a thread-count/launch hint (TODO confirm).
void residual_rms(torch::Tensor &input, torch::Tensor &residual,
torch::Tensor &weight, torch::Tensor &output, double epsilon,
double scale, int64_t mode, int64_t num_threads);