#include "fpA_intB_gemm.h"
#include "fpA_intB_gemm/fpA_intB_gemm_template.h"

namespace fastertransformer
{

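  // Map a textual activation name to the ActivationType enum used by the
  // runner; unknown names map to InvalidType.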
  ActivationType get_activation(const std::string &activation_name)
  {
    if (activation_name == "identity")
      return ActivationType::Identity;
    if (activation_name == "relu")
      return ActivationType::Relu;
    if (activation_name == "silu")
      return ActivationType::Silu;
    if (activation_name == "gelu")
      return ActivationType::Gelu;
    // TODO: map additional activation types as they are needed.
    return ActivationType::InvalidType;
  }

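  // Plain weight-only GEMM: fp16 activations A times int8 weights B, with B
  // dequantized on the fly using the per-channel fp16 weight_scales; no bias
  // or activation epilogue.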
  void gemm_fp16_int(const half *A,
                     const uint8_t *B,
                     const half *weight_scales,
                     half *C,
                     int m, int n, int k,
                     char *workspace_ptr,
                     size_t workspace_bytes,
                     cudaStream_t stream)
  {
    CutlassFpAIntBGemmRunner<half, uint8_t> runner;
    runner.gemm(A, B, weight_scales,
                C, m, n, k, workspace_ptr, workspace_bytes, stream);
  }
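
  // Usage sketch (illustrative only): buffer names are hypothetical, and the
  // workspace-size query assumes CutlassFpAIntBGemmRunner exposes
  // getWorkspaceSize(m, n, k) as in upstream FasterTransformer.
  //
  //   CutlassFpAIntBGemmRunner<half, uint8_t> runner;
  //   size_t ws_bytes = runner.getWorkspaceSize(m, n, k);
  //   char *d_ws = nullptr;
  //   cudaMalloc(&d_ws, ws_bytes);
  //   gemm_fp16_int(d_A, d_B, d_scales, d_C, m, n, k, d_ws, ws_bytes, stream);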

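  // Bias/activation variant. Dispatches to the cheapest runner entry point:
  // plain gemm when there is neither bias nor activation, otherwise
  // gemm_bias_act with either the named activation or Identity.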
  template <typename WeightType>
  void gemm_fp16_int_bias_act(const half *A,
                              const WeightType *B,
                              const half *weight_scales,
                              const half *bias,
                              half *C,
                              std::optional<std::string> activation,
                              int m, int n, int k, int bias_stride, char *workspace_ptr,
                              size_t workspace_bytes, cudaStream_t stream)
  {
    CutlassFpAIntBGemmRunner<half, WeightType> runner;

    if (!activation && bias == nullptr)
    {
      runner.gemm(A, B, weight_scales,
                  C, m, n, k, workspace_ptr, workspace_bytes, stream);
    }
    else if (!activation)
    {
      runner.gemm_bias_act(A, B, weight_scales, bias,
                           C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
    }
    else
    {
      runner.gemm_bias_act(A, B, weight_scales, bias,
                           C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
    }
  }

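  // Fused variant that also consumes a residual tensor. The activation,
  // binary_op, and unary_op strings select the epilogue stages; their exact
  // composition is defined by the runner's gemm_bias_act_residual.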
  template <typename WeightType>
  void gemm_fp16_int_bias_act_residual(
      const half *A, const WeightType *B, const half *weight_scales,
      const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
      const std::string &unary_op, int m, int n,
      int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
  {
    CutlassFpAIntBGemmRunner<half, WeightType> runner;

    runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
                                  C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
  }

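  // Explicit instantiations for the two supported weight types: uint4b_t
  // (packed int4) and uint8_t (int8). These keep the template definitions in
  // this translation unit while letting other files link against them.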
  template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
                                                 const half *weight_scales, const half *bias,
                                                 half *C, std::optional<std::string> activation, int m,
                                                 int n, int k, int bias_stride, char *workspace_ptr,
                                                 size_t workspace_bytes, cudaStream_t stream);

  template void gemm_fp16_int_bias_act_residual<uint4b_t>(
      const half *A, const uint4b_t *B, const half *weight_scales,
      const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
      const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

  template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
                                                const half *weight_scales, const half *bias,
                                                half *C, std::optional<std::string> activation, int m,
                                                int n, int k, int bias_stride, char *workspace_ptr,
                                                size_t workspace_bytes, cudaStream_t stream);

  template void gemm_fp16_int_bias_act_residual<uint8_t>(
      const half *A, const uint8_t *B, const half *weight_scales,
      const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
      const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

} // namespace fastertransformer