Axolotl fine-tuning Jamba-1.5-Mini

#7
by coranholmes - opened

Could you please also provide an example YAML for using Axolotl to SFT Jamba-1.5-Mini, like the one provided for Large? Thank you so much.

AI21 org

Hi!
For QLoRA + FSDP, you can use the same guide as in the Large model card; just change the model name and adjust the batch size according to your hardware - it requires ~70GB of memory in total.

base_model: ai21labs/AI21-Jamba-1.5-Mini
tokenizer_type: AutoTokenizer

load_in_4bit: true
strict: false
use_tensorboard: true
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    chat_template: jamba
    drop_system_message: true
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-mini-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
lora_target_linear: false

gradient_accumulation_steps: 4 # change according to your hardware
micro_batch_size: 4 # change according to your hardware 
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
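
Once the config above is saved to a YAML file (the filename below is just an example), training can be launched with Axolotl's standard accelerate entry point:

accelerate launch -m axolotl.cli.train jamba-mini-qlora-fsdp.yaml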

Thanks for your reply.

I tried to train Jamba-1.5-Mini using Axolotl but ran into some problems; I have posted an issue here. It would be of great help if you could offer some advice. Thank you in advance.

AI21 org

Hi,
can you specify the transformers version you are using?
You can also try the specific commit mentioned in the Large model card under the QLoRA + FSDP fine-tuning section:

pip install git+https://github.com/xgal/transformers@897f80665c37c531b7803f92655dbc9b3a593fe7

This should be fixed in transformers versions >= 4.44.2.
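
For reference, you can check the installed transformers version and upgrade to a release containing the fix with standard pip/Python commands (nothing specific to this thread):

python -c "import transformers; print(transformers.__version__)"
pip install "transformers>=4.44.2"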
