---
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
tags:
- Mathematical Reasoning
language:
- en
datasets:
- adityasihag/math_QAaugP
---

**This repo contains LoRA adapter weights**; the base model must be loaded separately and the adapter merged on top (see Model Usage below).


### Model Description

- **Project GitHub Page:** https://github.com/adityasihag1996/math_QA.git
- **Developed by:** [Aditya Sihag](https://www.linkedin.com/in/aditya-sihag-ab29681a9/)
- **Model type:** fine-tuned using QLoRA on 1x RTX 4090
- **Finetuned from model:** mistralai/Mistral-7B-v0.1

## Results

<table>
    <thead>
        <tr>
            <th>Prompt Approach</th>
            <th>GSM8k</th>
            <th>MATH</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td>Zero-Shot CoT</td>
            <td><b>75.81</b></td>
            <td><b>-</b></td>
        </tr>
    </tbody>
</table>

## Training procedure

The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: float16
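
For reference, this corresponds to the following `transformers` `BitsAndBytesConfig` (a minimal sketch; the actual training script lives in the GitHub repo linked above):

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization, float16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
```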

`LoraConfig` params:
- r: 128
- lora_alpha: 256 (r * 2)
- lora_dropout: 0.05
- bias: "none"
- task_type: "CAUSAL_LM"
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
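
Expressed as a `peft` `LoraConfig`, this would look roughly like the sketch below (with `lora_alpha` written out as `r * 2 = 256`):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=128,
    lora_alpha=256,  # r * 2
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
)
```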

The hyperparameters for the LoRA fine-tuning are listed below:
- epochs: 3
- learning_rate: 5e-5
- batch_size: 256
- max_grad_norm: 1.0
- weight_decay: 0.001
- lr_scheduler_type: "cosine"
- warmup_ratio: 0.03
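
A minimal `transformers` `TrainingArguments` sketch with these values follows. The per-device batch size and gradient-accumulation split are assumptions: only the effective batch size of 256 is reported, and a single RTX 4090 cannot hold 256 sequences at once, so some accumulation must have been used.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="math_QA-Mistral-7B-QLoRA",  # hypothetical output path
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,   # assumption: effective batch of 256
    gradient_accumulation_steps=32,  # reached via 8 * 32 accumulation
    max_grad_norm=1.0,
    weight_decay=0.001,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
)
```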

## Dataset
The math_QA dataset is prepared as a combination of [MetaMathQA](https://huggingface.co./datasets/meta-math/MetaMathQA) and [MathInstruct](https://huggingface.co./datasets/TIGER-Lab/MathInstruct), along with some internal data.
Refer to [math_QAaugP](https://huggingface.co./datasets/adityasihag/math_QAaugP).
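
The dataset can be pulled directly with the `datasets` library (a sketch, assuming the default split layout on the Hub):

```python
from datasets import load_dataset

# Loads the combined MetaMathQA + MathInstruct + internal data mix
dataset = load_dataset("adityasihag/math_QAaugP")
print(dataset)
```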

## Model Usage
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model in float16 on GPU 0
model_path = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map={"": 0},
)

# Load the LoRA adapter and merge it into the base weights
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-Mistral-7B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

question = """Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x."""

# Follow the prompt template the model was trained with
sample_input = f"""Question: {question} \n Answer: """

sample_input_tokenised = tokenizer(sample_input, return_tensors="pt").to("cuda")

generated_ids = model.generate(
    **sample_input_tokenised,
    max_new_tokens=1024,
    do_sample=True,   # temperature only takes effect when sampling is enabled
    temperature=0.3,
)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output)
```

##### Sample Input:
```
Question: Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x. \n Answer: 
```

##### Model Output:
```
Given the linear equation 3(x+2)-x=x+9. 
First, distribute the 3 in the brackets to get 3x + 6 - x = x + 9. 
Simplify the equation to get 2x + 6 = x + 9. 
Next, transpose x from the right side to the left side and from the left side to the right side to get x = 9 - 6. 
Finally, solve for x to get x = 3.
```

##### Prompt Template:
```
Question: <question>
Answer: 
```

## Comparing math_QA models with other SFT LLM models
| Model               | GSM8k Pass@1 | MATH Pass@1 |
|---------------------|--------------|-------------|
| LLaMA-2-7B                | 14.6         | 2.5         |
| gemma-2b                  | 17.7         | -           |
| LLaMA-2-13B               | 28.7         | 3.9         |
| Mistral-7B                | 35.4         | -           |
| LLaMA-2-34B               | 42.2         | 6.24        |
| **math_QA-gemma-2B**      | **43.66**    | -           |
| gemma-7b                  | 46.4         | -           |
| WizardMath-7B             | 54.9         | 10.7        |
| WizardMath-13B            | 63.9         | 14.0        |
| MetaMath-7B               | 66.5         | 19.8        |
| MetaMath-13B              | 72.3         | 22.4        |
| **math_QA-Mistral-7B**    | **75.81**    | -           |
| Arithmo2-Mistral-7B       | 76.4         | 27.2        |
| MetaMath-Mistral-7B       | 77.7         | 28.2        |
| DeepSeekMath-Instruct-7B  | 82.9         | 46.8        |
| GPT-4                     | 92.0         | 52.9        |