What is the compute needed for GRPO for 7B R1-Distill model?

#4
by AndrewSanders - opened

To anybody who has tried GRPO on any of the R1-Distill models: what is the minimum GPU compute needed to run the training?
Say, for R1-Distill-Qwen-7B?

I'm referring to this command from the GitHub repo README:

GRPO

accelerate launch --config_file configs/zero3.yaml src/open_r1/grpo.py \
    --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
    --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
    --dataset_name AI-MO/NuminaMath-TIR \
    --max_prompt_length 256 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --logging_steps 10 \
    --bf16
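For context, here is a rough back-of-envelope sketch (my own assumptions, not from the repo) of the model and optimizer state memory for a 7B model trained in bf16 with Adam under DeepSpeed ZeRO-3. It ignores activations, the generation/rollout memory GRPO needs for sampling completions, and framework overhead, so treat it as a lower bound:

```python
# Rough memory estimate for bf16 + Adam training under ZeRO-3.
# Assumption: 16 bytes per parameter of persistent state:
#   2 B  bf16 weights
#   2 B  bf16 gradients
#   12 B fp32 master weights + Adam momentum + Adam variance
# ZeRO-3 shards all of these evenly across GPUs.

def zero3_state_memory_gb(num_params: float, num_gpus: int) -> float:
    """Approximate per-GPU memory (GB) for sharded model/optimizer states.

    Excludes activations, KV cache for rollout generation, and overhead.
    """
    bytes_per_param = 2 + 2 + 12  # = 16
    total_gb = num_params * bytes_per_param / 1e9
    return total_gb / num_gpus

total = zero3_state_memory_gb(7e9, 1)      # ~112 GB of states in total
per_gpu = zero3_state_memory_gb(7e9, 8)    # ~14 GB per GPU on 8 GPUs
print(f"total states: {total:.0f} GB, per GPU on 8x: {per_gpu:.0f} GB")
```

So even before activations and rollout generation, a single 24 GB or 40 GB card cannot hold the full optimizer state for a 7B model; a multi-GPU node (or heavy CPU offload) seems required.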
