|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "cutlass_heuristic.h" |
|
#include "cutlass/gemm/gemm.h" |
|
#include <cuda_runtime_api.h> |
|
|
|
#include <vector> |
|
#include <stdexcept> |
|
|
|
namespace fastertransformer { |
|
|
|
struct TileShape { |
|
int m; |
|
int n; |
|
}; |
|
|
|
TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) |
|
{ |
|
switch (tile_config) { |
|
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: |
|
return TileShape{32, 128}; |
|
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: |
|
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: |
|
return TileShape{64, 128}; |
|
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: |
|
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: |
|
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: |
|
return TileShape{128, 128}; |
|
default: |
|
throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config"); |
|
} |
|
} |
|
|
|
bool is_valid_split_k_factor(const int64_t m, |
|
const int64_t n, |
|
const int64_t k, |
|
const TileShape tile_shape, |
|
const int split_k_factor, |
|
const size_t workspace_bytes, |
|
const bool is_weight_only) |
|
{ |
|
|
|
|
|
static constexpr int k_tile = 64; |
|
|
|
|
|
if (is_weight_only) { |
|
if ((k % k_tile) != 0) { |
|
return false; |
|
} |
|
|
|
if ((k % split_k_factor) != 0) { |
|
return false; |
|
} |
|
|
|
const int k_elements_per_split = k / split_k_factor; |
|
if ((k_elements_per_split % k_tile) != 0) { |
|
return false; |
|
} |
|
} |
|
|
|
|
|
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; |
|
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; |
|
const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; |
|
|
|
if (required_ws_bytes > workspace_bytes) { |
|
return false; |
|
} |
|
|
|
return true; |
|
} |
|
|
|
std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) |
|
{ |
|
|
|
std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; |
|
|
|
std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, |
|
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, |
|
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; |
|
|
|
std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, |
|
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, |
|
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; |
|
|
|
const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs; |
|
return simt_configs_only ? simt_configs : allowed_configs; |
|
} |
|
|
|
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) |
|
{ |
|
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only); |
|
|
|
std::vector<CutlassGemmConfig> candidate_configs; |
|
const int min_stages = 2; |
|
const int max_stages = sm >= 80 ? 4 : 2; |
|
|
|
for (const auto& tile_config : tiles) { |
|
for (int stages = min_stages; stages <= max_stages; ++stages) { |
|
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; |
|
candidate_configs.push_back(config); |
|
} |
|
} |
|
|
|
return candidate_configs; |
|
} |
|
|
|
CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs, |
|
const std::vector<int>& occupancies, |
|
const int64_t m, |
|
const int64_t n, |
|
const int64_t k, |
|
const int64_t num_experts, |
|
const int split_k_limit, |
|
const size_t workspace_bytes, |
|
const int multi_processor_count, |
|
const int is_weight_only) |
|
{ |
|
|
|
if (occupancies.size() != candidate_configs.size()) { |
|
throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and " |
|
"candidate configs vectors must have equal length."); |
|
} |
|
|
|
CutlassGemmConfig best_config; |
|
|
|
|
|
float config_score = 1.0f; |
|
int config_waves = INT_MAX; |
|
int current_m_tile = 0; |
|
|
|
const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; |
|
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { |
|
CutlassGemmConfig candidate_config = candidate_configs[ii]; |
|
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); |
|
int occupancy = occupancies[ii]; |
|
|
|
if (occupancy == 0) { |
|
continue; |
|
} |
|
|
|
|
|
if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile |
|
&& current_m_tile < tile_shape.m) { |
|
continue; |
|
} |
|
|
|
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; |
|
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; |
|
|
|
for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { |
|
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { |
|
const int ctas_per_wave = occupancy * multi_processor_count; |
|
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; |
|
|
|
const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; |
|
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); |
|
const float current_score = float(num_waves_total) - num_waves_fractional; |
|
|
|
const float score_slack = 0.1f; |
|
if (current_score < config_score |
|
|| ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { |
|
config_score = current_score; |
|
config_waves = num_waves_total; |
|
SplitKStyle split_style = |
|
split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; |
|
best_config = CutlassGemmConfig{ |
|
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; |
|
current_m_tile = tile_shape.m; |
|
} |
|
else if (current_score == config_score |
|
&& (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor |
|
|| current_m_tile < tile_shape.m)) { |
|
|
|
SplitKStyle split_style = |
|
split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; |
|
best_config = CutlassGemmConfig{ |
|
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; |
|
current_m_tile = tile_shape.m; |
|
config_waves = num_waves_total; |
|
} |
|
} |
|
} |
|
} |
|
|
|
if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { |
|
throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config."); |
|
} |
|
|
|
return best_config; |
|
} |
|
|
|
} |
|
|