---
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
tags:
- Mathematical Reasoning
language:
- en
datasets:
- adityasihag/math_QAaugP
---
**This repo contains LoRA adapter weights**.
### Model Description
- **Project GitHub Page:** https://github.com/adityasihag1996/math_QA.git
- **Developed by:** [Aditya Sihag](https://www.linkedin.com/in/aditya-sihag-ab29681a9/)
- **Model type:** fine-tuned using QLoRA on 1x RTX 4090
- **Finetuned from model:** mistralai/Mistral-7B-v0.1
## Results
<table>
<thead>
<tr>
<th>Prompt Approach</th>
<th>GSM8k</th>
<th>MATH</th>
</tr>
</thead>
<tbody>
<tr>
<td>Zero-Shot CoT</td>
<td><b>75.81</b></td>
<td><b>-</b></td>
</tr>
</tbody>
</table>
## Training procedure
The following `bitsandbytes` quantization config was used during training (a code equivalent follows the list):
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: float16
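These settings map one-to-one onto `transformers.BitsAndBytesConfig`. A minimal sketch, assuming the standard `transformers` + `bitsandbytes` integration:
```
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with nested (double) quantization and fp16 compute,
# matching the training config listed above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
```
The config would be passed as `quantization_config=bnb_config` to `AutoModelForCausalLM.from_pretrained` when loading the base model for training.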
`LoraConfig` params (see the sketch after the list):
- r: 128
- lora_alpha: 256 (lora_r * 2)
- lora_dropout: 0.05
- bias: "none"
- task_type: "CAUSAL_LM"
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
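These translate directly into a `peft.LoraConfig`; a minimal sketch:
```
from peft import LoraConfig

# LoRA on all attention and MLP projection layers of Mistral-7B,
# with alpha set to twice the rank.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,  # lora_r * 2
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
```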
The hyperparameters for the LoRA fine-tuning are listed below, with a `TrainingArguments` sketch after the list:
- epochs: 3
- learning_rate: 5e-5
- batch_size: 256
- max_grad_norm: 1.0
- weight_decay: 0.001
- lr_scheduler_type: "cosine"
- warmup_ratio: 0.03
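On a single RTX 4090, an effective batch size of 256 is presumably reached via gradient accumulation; the per-device batch size and accumulation steps below are assumptions, since the exact split is not documented. A sketch using `transformers.TrainingArguments`:
```
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="math_QA-Mistral-7B-QLoRA",  # hypothetical output path
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=4,   # assumed; 4 * 64 accumulation = 256 effective
    gradient_accumulation_steps=64,  # assumed split
    max_grad_norm=1.0,
    weight_decay=0.001,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
)
```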
## Dataset
The math_QA dataset is a combination of [MetaMathQA](https://huggingface.co./datasets/meta-math/MetaMathQA), [MathInstruct](https://huggingface.co./datasets/TIGER-Lab/MathInstruct), and some internal data.
See [math_QAaugP](https://huggingface.co./datasets/adityasihag/math_QAaugP) for details.
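The dataset can be loaded straight from the Hub with the `datasets` library:
```
from datasets import load_dataset

# Download and inspect the training data.
dataset = load_dataset("adityasihag/math_QAaugP")
print(dataset)
```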
## Model Usage
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model in fp16 on GPU 0.
model_path = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map={"": 0},
)

# Load the LoRA adapter and merge it into the base weights.
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-Mistral-7B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Format the question with the "Question: ... Answer:" prompt template.
question = """Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x."""
sample_input = f"""Question: {question} \n Answer: """
sample_input_tokenised = tokenizer(sample_input, return_tensors="pt").to("cuda")

# do_sample=True is needed for the temperature setting to take effect.
generated_ids = model.generate(
    **sample_input_tokenised,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.3,
)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(output)
```
##### Sample Input:
```
Question: Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x. \n Answer:
```
##### Model Output:
```
Given the linear equation 3(x+2)-x=x+9.
First, distribute the 3 in the brackets to get 3x + 6 - x = x + 9.
Simplify the equation to get 2x + 6 = x + 9.
Next, transpose x from the right side to the left side and from the left side to the right side to get x = 9 - 6.
Finally, solve for x to get x = 3.
```
#### Prompt Template:
```
Question: <question>
Answer:
```
## Comparing math_QA models with other SFT LLMs
| Model                     | GSM8k Pass@1 | MATH Pass@1 |
|---------------------------|--------------|-------------|
| LLaMA-2-7B                | 14.6         | 2.5         |
| gemma-2b                  | 17.7         | -           |
| LLaMA-2-13B               | 28.7         | 3.9         |
| Mistral-7B                | 35.4         | -           |
| LLaMA-2-34B               | 42.2         | 6.24        |
| **math_QA-gemma-2B**      | **43.66**    | -           |
| gemma-7b                  | 46.4         | -           |
| WizardMath-7B             | 54.9         | 10.7        |
| WizardMath-13B            | 63.9         | 14.0        |
| MetaMath-7B               | 66.5         | 19.8        |
| MetaMath-13B              | 72.3         | 22.4        |
| **math_QA-Mistral-7B**    | **75.81**    | -           |
| Arithmo2-Mistral-7B       | 76.4         | 27.2        |
| MetaMath-Mistral-7B       | 77.7         | 28.2        |
| DeepSeekMath-Instruct-7B  | 82.9         | 46.8        |
| GPT4                      | 92.0         | 52.9        |