---
license: llama3.2
datasets:
- HuggingFaceH4/ultrafeedback_binarized
base_model:
- tanliboy/llama-3.2-3b-sft
pipeline_tag: text-generation
tags:
- trl
- llama
- dpo
- alignment
- transformers
- custom
- chat
---
# Llama-3.2-3B-DPO

## Model Details

- **Model type:** Llama-3.2-3B causal language model aligned with DPO
- **License:** llama3.2
- **Finetuned from model:** [tanliboy/llama-3.2-3b-sft](https://huggingface.co/tanliboy/llama-3.2-3b-sft)
- **Training data:** [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- **Training framework:** [trl](https://github.com/huggingface/trl)
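A minimal chat-style inference sketch with `transformers` is shown below. The repository id `AIR-hl/Llama-3.2-3B-DPO` and the sampling settings are illustrative placeholders, not values taken from this card.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id; replace with the actual model id of this card.
model_id = "AIR-hl/Llama-3.2-3B-DPO"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [{"role": "user", "content": "Explain DPO in one sentence."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Illustrative generation settings; tune max_new_tokens / temperature as needed.
output_ids = model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.7)
print(tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True))
```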
## Training Details

devices: 4 × NPU 910B-64GB \
precision: bf16 mixed-precision \
global_batch_size: 128 (4 devices × 8 per-device batch × 4 gradient-accumulation steps)
### Training Hyperparameters

`attn_implementation`: None \
`beta`: 0.01 \
`bf16`: True \
`learning_rate`: 8e-7 \
`lr_scheduler_type`: cosine \
`per_device_train_batch_size`: 8 \
`gradient_accumulation_steps`: 4 \
`torch_dtype`: bfloat16 \
`num_train_epochs`: 1 \
`max_prompt_length`: 512 \
`max_length`: 1024 \
`warmup_ratio`: 0.05
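For reference, here is a sketch of how these values map onto `trl`'s `DPOConfig`. The `output_dir` path is a placeholder, and `torch_dtype` / `attn_implementation` are model-loading options passed through `ModelConfig` in the script below rather than through `DPOConfig`.

```python
from trl import DPOConfig

# Illustrative mapping of the hyperparameters above onto DPOConfig;
# output_dir is a placeholder, not a path used for this model.
training_args = DPOConfig(
    output_dir="llama-3.2-3b-dpo",
    beta=0.01,
    bf16=True,
    learning_rate=8e-7,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    max_prompt_length=512,
    max_length=1024,
    warmup_ratio=0.05,
)
```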
### Results

`init_train_loss`: 0.6924 \
`final_train_loss`: 0.5792 \
`accuracy`: 0.7188 \
`reward_margin`: 0.5234
### Training script

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
    # Parse script, DPO training, and model arguments from the CLI / config file.
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Resolve the requested dtype ("auto" / None are passed through unchanged).
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Policy model to be optimized.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    # Without PEFT, DPO needs an explicit frozen reference model;
    # with PEFT, the trainer derives the reference from the base weights.
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        # Work around a DDP limitation with boolean buffers.
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # Load the preference dataset and keep only the columns DPOTrainer expects.
    dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
    dataset = dataset.select_columns(["chosen", "prompt", "rejected"])

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
```
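Since the script reads its configuration through `TrlParser`, the hyperparameters listed above can be supplied as command-line flags (for example `--beta 0.01 --learning_rate 8e-7`) or via a YAML config when the script is run under a multi-device launcher such as `accelerate launch`.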