import gradio as gr import pandas as pd # col=['Layer number', 'Hidden size', 'FFN Hidden size', 'Sequence length', 'Head number', 'Group number', # 'dp', 'tp', 'pp', 'cp', 'GPU numbers', 'Batch size', 'FP8', 'Model parameters', 'Model_states', 'Activation', 'Total'] col=['L', 'H', 'FFN', 'S', 'A', 'G', 'DP', 'TP', 'PP', 'CP', 'GPUs', 'B', 'FP8', 'Model parameters (B)', 'Model states (GB)', 'Activation (GB)', 'Total (GB)'] abbr = """
> **Abbreviations of symbols:** |Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name|Abbr|Full name| |---|---|---|---|---|---|---|---|---|---|---|---| |L|Layer number|H|Hidden size|FFN|FFN Hidden size|S|Sequence length|A|Head number|G|Group number|
""" def Get_GigaByte(memory): return memory / 1024**3 def Get_BillionParameter(parameter): return parameter / 1000**3 # model states: def Compute_Parameters_input(seq_length, hidden_size, vocab_size, act_func, tp): num_parameters_word_embedding = hidden_size * vocab_size / tp # position embedding if act_func == "LLaMA": num_parameters_position_embedding = 0 else: num_parameters_position_embedding = seq_length * hidden_size / tp return num_parameters_word_embedding + num_parameters_position_embedding def Compute_Parameters_output(hidden_size, vocab_size, is_tie_word_embedding, act_func, tp): # layernorm: h/2h if act_func == "LLaMA": num_parameters_output_layernorm = hidden_size # RMSNorm else: num_parameters_output_layernorm = 2 * hidden_size # LayerNorm if is_tie_word_embedding == "True": num_parameters_output_embedding = 0 # due to sharedWordEmbedding else: num_parameters_output_embedding = hidden_size * vocab_size / tp return num_parameters_output_layernorm + num_parameters_output_embedding def Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, act_func, tp): # attention: # layernorm: h/2h if act_func == "LLaMA": num_parameters_attention = hidden_size # RMSNorm else: num_parameters_attention = 2 * hidden_size # LayerNorm # QKV weight: 3h*h/tp, bias: 3h/tp # output linear weight: h*h/tp, bias: h num_parameters_attention_Q_weight = hidden_size * hidden_size / tp num_parameters_attention_KV_weight = 2 * kv_hidden_size * hidden_size / tp num_parameters_attention_Linear_weight = hidden_size * hidden_size / tp num_parameters_attention += num_parameters_attention_Q_weight + num_parameters_attention_KV_weight + num_parameters_attention_Linear_weight if is_bias == "True": num_parameters_attention += (hidden_size + 2 * kv_hidden_size) / tp + hidden_size return num_parameters_attention def Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp): # MLP: # layernorm: h/2h if act_func == "LLaMA": num_parameters_mlp = hidden_size # RMSNorm else: num_parameters_mlp = 2 * hidden_size # LayerNorm # mlp1 weight: h*ffn/tp, bias: ffn/tp # mlp2 weight: ffn*h/tp, bias: h if act_func == "LLaMA": num_parameters_mlp += hidden_size * ffn_size * 3 / tp if is_bias == "True": num_parameters_mlp += ffn_size * 2 / tp + hidden_size else: num_parameters_mlp += hidden_size * ffn_size * 2 / tp if is_bias == "True": num_parameters_mlp += ffn_size / tp + hidden_size return num_parameters_mlp def Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, tp, pp): if is_group_query == "False": group_query_num = head_num kv_hidden_size = hidden_size / head_num * group_query_num # input part num_parameters_input = Compute_Parameters_input(seq_length, hidden_size, vocab_size, act_func, tp) # middle layers part num_parameters_attention = Compute_Parameters_attention(hidden_size, kv_hidden_size, is_bias, act_func, tp) num_parameters_mlp = Compute_Parameters_mlp(hidden_size, ffn_size, is_bias, act_func, tp) num_parameters_in_single_layer = num_parameters_attention + num_parameters_mlp num_parameters_in_total_layers = num_parameters_in_single_layer * layer_num / pp # output part parameters_output = Compute_Parameters_output(hidden_size, vocab_size, is_tie_word_embedding, act_func, tp) if pp == 1: num_parameters_total = ( num_parameters_input + num_parameters_in_total_layers + parameters_output # num_parameters_output_layernorm ) else: num_parameters_total = ( num_parameters_input + num_parameters_in_total_layers ) return num_parameters_total def Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init): weight_memory = 0 if precision == "FP32": weight_memory = 4 * numParametersTotal else: weight_memory = 2 * numParametersTotal if is_fp8 == "True" and is_fp8_init == "False": weight_memory += 2 * numParametersTotal return weight_memory def Compute_Gradient(numParametersTotal, g_ty): if g_ty == "FP32": gradient_memory = 4 * numParametersTotal elif g_ty =="BF16": gradient_memory = 2 * numParametersTotal return gradient_memory def Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp): if o_ty == "FP32": optimizer_memory = 4 * 2 * numParametersTotal elif o_ty =="BF16": optimizer_memory = 2 * 2 * numParametersTotal if is_dist_opt == "True": optimizer_memory = optimizer_memory / (dp * cp) # for SGD, we have no optimizer states if opt_func == "SGD": optimizer_memory = 0 return optimizer_memory def Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp): if precision == "BF16": master_weight_memory = 4 * numParametersTotal else: master_weight_memory = 0 if is_dist_opt == "True": master_weight_memory = master_weight_memory / (dp * cp) return master_weight_memory def Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty): numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, tp, pp) weight_memory = Compute_Weight(numParametersTotal, precision, is_fp8, is_fp8_init) gradient_memory = Compute_Gradient(numParametersTotal, g_ty) optimizer_memory = Compute_Optimizer_states(numParametersTotal, opt_func, o_ty, is_dist_opt, dp, cp) master_weight_memory = Compute_Master_weight(numParametersTotal, precision, is_dist_opt, dp, cp) return numParametersTotal, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, \ weight_memory + gradient_memory + optimizer_memory + master_weight_memory # activation memory: def compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp): # LN 2bsh activation_mem_attn_ln = seq_length * b * hidden_size * training_dtype if is_sp == "False": activation_mem_attn_ln *= tp # attention input X, qkv 2bsh/1bsh activation_mem_attn_qkv = seq_length * b * hidden_size * gemm_dtype if is_sp == "False": activation_mem_attn_qkv *= tp # attention q 2bsh activation_mem_attn_q = seq_length * b * hidden_size * training_dtype # attention k and v 4bsh activation_mem_attn_kv = seq_length * b * kv_hidden_size * training_dtype * 2 # attention proj input 2bsh/1bsh activation_mem_attn_proj = seq_length * b * hidden_size * gemm_dtype # dropout bsh activation_mem_attn_dropout = seq_length * b * hidden_size if is_sp == "False": activation_mem_attn_dropout *= tp # bf16: 2+2+2+4+2+1=13bsh # fp8: 2+1+2+4+1+1=11bsh activation_memory_attn = ( activation_mem_attn_ln + activation_mem_attn_qkv + activation_mem_attn_q + activation_mem_attn_kv + activation_mem_attn_proj + activation_mem_attn_dropout ) return activation_memory_attn def compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp): # LN 2bsh activation_mem_mlp_ln = seq_length * b * hidden_size * training_dtype if is_sp == "False": activation_mem_mlp_ln *= tp # FC1 2bsh/1bsh activation_mem_mlp_fc1 = seq_length * b * hidden_size * gemm_dtype if is_sp == "False": activation_mem_mlp_fc1 *= tp # Act 8bsh if act_func == "LLaMA": activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype * 2 else: activation_mem_mlp_act = seq_length * b * ffn_size * training_dtype # FC2 8bsh/4bsh activation_mem_mlp_fc2 = seq_length * b * ffn_size * gemm_dtype # dropout bsh activation_mem_mlp_dropout = seq_length * b * hidden_size if is_sp == "False": activation_mem_mlp_dropout *= tp # bf16: 2+2+8+8+1=21 # fp8: 2+1+8+4+1=16 activation_memory_mlp = ( activation_mem_mlp_ln + activation_mem_mlp_fc1 + activation_mem_mlp_act + activation_mem_mlp_fc2 + activation_mem_mlp_dropout ) return activation_memory_mlp def compute_activation_memory_input(seq_length, b, hidden_size, pp): # embedding + Dropout return 8 * seq_length * b * pp + seq_length * b * hidden_size * pp def compute_activation_memory_output(seq_length, b, hidden_size, vocab_size): # Inputs to output layer and CE loss(bf16, fp32 * 2). return 2 * seq_length * b * hidden_size + (2 + 4 + 4) * seq_length * b * vocab_size def compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches): # Multiply by interleaved PP memory factor. if vp > 0: interleaved_schedule_memory_penalty = 1 + (pp - 1) / (pp * vp) activation_memory *= interleaved_schedule_memory_penalty # If using non-interleaved schedule, number of microbatches in pipeline can be less than pp_size, # so discount accordingly. if vp == 0 and pp > 1: if num_microbatches > 1: activation_memory *= min(1, num_microbatches / pp) return activation_memory def compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp): # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf. # We are trying to compute the maximum activation footprint, so all calculations in this function # are for the first pipeline stage. # activation dataType for Training if precision == "FP32": training_dtype = 4 else: training_dtype = 2 # activation dataType for GEMM if precision == "FP32": gemm_dtype = 4 elif is_fp8 == "False": gemm_dtype = 2 else: gemm_dtype = 1 # kv_hidden_size if is_group_query == "False": group_query_num = head_num kv_hidden_size = hidden_size / head_num * group_query_num activation_memory_attn = compute_activation_memory_attention(training_dtype, gemm_dtype, seq_length, b, hidden_size, kv_hidden_size, is_sp, tp) activation_memory_mlp = compute_activation_memory_mlp(training_dtype, gemm_dtype, seq_length, b, hidden_size, ffn_size, act_func, is_sp, tp) activation_memory = activation_memory_attn + activation_memory_mlp activation_memory *= layer_num # Now add activation memory required for input embeddings, last LayerNorm and output layer. # Input to embedding (pp_size microbatches in flight). activation_memory_input = compute_activation_memory_input(seq_length, b, hidden_size, pp) activation_memory += activation_memory_input # get num_microbatches num_microbatches = b_global / b / dp / cp activation_memory = compute_activation_memory_pp(activation_memory, vp, pp, num_microbatches) if pp == 1: # Inputs to output layer and CE loss(fp32). activation_memory_output = compute_activation_memory_output(seq_length, b, hidden_size, vocab_size) activation_memory += activation_memory_output elif pp > 1: # Sendrecv memory activation_memory += seq_length * b * hidden_size * 2 # Activation memory is partitioned by TP size due to tensor and sequence model parallelism. return activation_memory / tp / cp # compute_btn.click.function def Compute_ALL_Model_memory(vocab_size, layer_num, hidden_size, ffn_size, seq_length, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count): # data type trans if is_group_query == "True": group_query_num = int(group_query_num) else: group_query_num = head_num # check input [result, Error_message] = check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global) if result == False: return Error_message, record_df, count # get model states numParameters, weight_memory, gradient_memory, optimizer_memory, master_weight_memory, model_states_memory = Compute_Model_states(seq_length, vocab_size, layer_num, hidden_size, ffn_size, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_dist_opt, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty) # get activation memory activation_memory = compute_activation_memory(vocab_size, seq_length, layer_num, b, b_global, head_num, hidden_size, ffn_size, act_func, precision, is_fp8, is_sp, is_group_query, group_query_num, tp, pp, dp, cp, vp) # get model parameters numParametersTotal = Compute_Parameters(seq_length, vocab_size, layer_num, hidden_size, ffn_size, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, head_num, 1, 1) # get gpu number gpu_num = dp * tp * pp * cp # get B/GB numParametersTotal = round(Get_BillionParameter(numParametersTotal), 3) numParameters = round(Get_BillionParameter(numParameters), 3) model_states_memory = round(Get_GigaByte(model_states_memory), 3) activation_memory = round(Get_GigaByte(activation_memory), 3) other_memory = 5 Total = round(model_states_memory + activation_memory + other_memory, 3) # record new_row = pd.DataFrame([[layer_num, hidden_size, ffn_size, seq_length, head_num, group_query_num, dp, tp, pp, cp, gpu_num, b, is_fp8, numParametersTotal, model_states_memory, activation_memory, Total]], columns=col) if count == 1: record_df = new_row else: record_df = record_df._append(new_row, ignore_index=True) count = count + 1 # return str(gpu_num), str(model_states) + " GB", str(activation) + " GB", str(total) + " GB", table_data return f""" GPU numbers = {str(gpu_num)}, \n Model parameters = {str(numParametersTotal)} B, \n Model parameters on each device = {str(numParameters)} B, \n Model_states = Weight + Gradient + Optimizer = {str(model_states_memory)} GB, \n Activation = {str(activation_memory)} GB, \n Other memory = 5 GB, \n Total memory consumption = {str(Total)} GB \n """, record_df, count def generate_csv(record_df): # 将 DataFrame 保存为 CSV 文件 csv_filename = "data.csv" record_df.to_csv(csv_filename, index=False) # 返回 CSV 文件路径 return csv_filename # formula string formula = r""" > **Note**🔑: In this formula, we assume LLM training with FP8 training. > 1. LlaMA-family Model. > 2. Interleaved pipeline. > 3. bias = False. > 4. SP = True.
$$ {Total\ Model\ parameters} = HV + (4H^2 + 3H \times FFN + 2H) \times L + H $$ ***
$$ {Model\ states} = (6 + \frac{12}{dp \times cp}) \times (\frac{(\frac{4H^2 + 3H \times FFN}{tp} + 2H) \times L}{pp} + \frac{HV}{tp}) $$ $$ {Activation} = (1 + \frac{pp-1}{pp \times vp}) \times \frac{(8BS + BSH) \times pp + (15BSH + 5BS \times FFN) \times L}{tp \times cp} $$ *** $$ \\begin{gather} {GPU\ numbers} = tp \times pp \times dp \times cp\\\\ {Total\ memory\ consumption} = {Model\ states} + Activation \\end{gather} $$ """ def check_tp(tp, head_num): if head_num % tp == 0: return True else: return False def check_pp(pp, layer_num): if layer_num % pp == 0: return True else: return False def check_cp(cp, seq_length): if seq_length % cp == 0: return True else: return False def check_hidden(hidden_size, head_num): if hidden_size % head_num == 0: return True else: return False def check_b_global(b_global, b, dp, cp): if b_global % (b * dp * cp) == 0: return True else: return False def check_num_microbatch(layer_num, vp, pp, num_microbatches): if vp > 0: if layer_num % (pp * vp) == 0: return True else: return False if vp == 0 and pp > 1: if num_microbatches > 1: if num_microbatches % pp == 0: return True else: return False return True def check_input(dp, tp, pp, cp, hidden_size, head_num, layer_num, seq_length, vp, b, b_global): result = True Error_message = "" if check_tp(tp, head_num) == False: result = False Error_message += "Error message: Please reset Tensor parallelism or head_num, make head_num % tp = 0. \n" if check_pp(pp, layer_num) == False: result = False Error_message += "Error message: Please reset Pipeline parallelism or layer_num, make layer_num % pp = 0. \n" if check_cp(cp, seq_length) == False: result = False Error_message += "Error message: Please reset Context parallelism or seq_length, make seq_length % cp = 0. \n" if check_hidden(hidden_size, head_num) == False: result = False Error_message += "Error message: Please reset hidden_size or head_num, make hidden_size % head_num = 0. \n" if check_b_global(b_global, b, dp, cp) == False: result = False Error_message += "Error message: Please reset b_global or batch_size, make b_global % (batch_size * dp * cp) = 0. \n" if check_num_microbatch(layer_num, vp, pp, b_global / b / dp / cp) == False: result = False Error_message += "Error message: Please reset b_global or batch_size or layer_num or Virtual Pipeline Size, make layer_num % (pp * vp) = 0, num_microbatches % pp = 0. \n" return result, Error_message with gr.Blocks() as demo: with gr.Row(): # Text gr.Markdown( """

GPU memory calculator 🌀

Here's a GPU memory calculator, it helps you to compute memory comsumption in LLM training.

Note: Flash-attention is enabled by default.

""" ) with gr.Row(): with gr.Column(): # Input 1.[Model Parameters] gr.Markdown( """

Model Parameters:

""" ) with gr.Accordion("Model Parameters"): # with gr.Row(): act_func = gr.Radio(["LLaMA", "GPT"], value="LLaMA", label="Model type", info="eg. LLaMa: SwiGLU, RoPE, RMSNorm") #, info="Action Function in MLP, whether to use GLU (Gated Linear Unit). [e.g \"True\" for LlaMA, \"False\" for GPT.]") with gr.Row(): vocab_size = gr.Number(label="Vocab size (V)", value=32000) layer_num = gr.Number(label="Layer number (L)", value=32) with gr.Row(): hidden_size = gr.Number(label="Hidden size (H)", value=4096) ffn_size = gr.Number(label="FFN Hidden size (FFN)", value=11008) with gr.Row(): sequence_len = gr.Number(label="Sequence length (S)", value=2048) head_num = gr.Number(label="Number of Attention Heads (A)", value=32) with gr.Row(): is_group_query = gr.Radio(["True", "False"], value="False", label="Use Group Query Attention") group_query_num = gr.Textbox(label="Number of Query Groups (G)", max_lines=1, value=None, interactive=False) with gr.Row(): is_bias = gr.Radio(["True", "False"], value="False", label="Use Bias") is_tie_word_embedding = gr.Radio(["True", "False"], value="False", label="Tie word embeddings") # change editable function def toggle_textbox_editable(radio_value): # 根据 radio_value 的值来决定 textbox 是否可编辑 if radio_value == "True": return gr.update(interactive=True, value="96") else: return gr.update(interactive=False, value="") # 将 radio 组件的变化连接到函数 is_group_query.change(toggle_textbox_editable, inputs=is_group_query, outputs=group_query_num) with gr.Column(): # Input 2.[Parallelism] gr.Markdown( """

Parallelism config:

""" ) with gr.Accordion("Parallelism config"): # with gr.Row(): dp = gr.Number(label="Data parallelism (dp)", value=2) tp = gr.Number(label="Tensor parallelism (tp)", value=2) pp = gr.Number(label="Pipeline parallelism (pp)", value=2) cp = gr.Number(label="Context parallelism (cp)", value=1) # with gr.Row(): is_sp = gr.Radio(["True", "False"], value="True", label="Sequence parallelism") vp = gr.Number(label="Virtual Pipeline Size (vp)") is_dist_opt = gr.Radio(["True", "False"], value="True", label="Use Distributed Optimizer(Zero1)") with gr.Column(): # Input 3.[Training Settings] gr.Markdown( """

Training Config:

""" ) with gr.Accordion("Training Config"): # with gr.Row(): b = gr.Number(label="Micro Batch size (B)", value=4) b_global = gr.Number(label="Global Batch size", value=64) precision = gr.Dropdown(["FP32", "BF16"], value="BF16", label="Training precision") with gr.Row(): is_fp8 = gr.Radio(["True", "False"], value="True", label="FP8 Training") is_fp8_init = gr.Radio(["True", "False"], value="True", label="FP8 Initialization(will reduce memory)") g_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Gradients Dtype") with gr.Row(): opt_func = gr.Radio(["Adam", "SGD"], value="Adam", label="Optimizer function") o_ty = gr.Dropdown(["FP32", "BF16"], value="FP32", label="Optimizer State Dtype") compute_btn = gr.Button("Compute") with gr.Tab("Output"): with gr.Column(): # gr.Markdown( # """ #

Output Data:

# """ # ) output_text = gr.Textbox( label="Compute result", interactive=False, ) with gr.Tab("Formula"): formula = formula gr.Markdown( formula , latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }] ) # gr.Markdown(abbr) record_df = gr.Dataframe( label="Record Table", headers=col, interactive=False ) download_btn = gr.Button("Download") count = gr.Number(label="Row count", value=1, visible=False) compute_btn.click( fn=Compute_ALL_Model_memory, inputs=[vocab_size, layer_num, hidden_size, ffn_size, sequence_len, head_num, is_group_query, group_query_num, is_bias, is_tie_word_embedding, act_func, dp, tp, pp, cp, is_sp, vp, is_dist_opt, b, b_global, precision, is_fp8, is_fp8_init, g_ty, opt_func, o_ty, record_df, count], outputs=[output_text, record_df, count] ) output_file=gr.File(label="When you click the download button, the downloaded form will be displayed here.") # download func download_btn.click( fn=generate_csv, inputs=record_df, outputs=output_file ) if __name__ == "__main__": demo.launch(share=False, allowed_paths=["/"])