mariordoniez committed
Commit: fc863b3
Parent(s): d92db69
Upload 9 files

Browse files:
- README.md +146 -1
- adapter_config.json +20 -0
- adapter_model.bin +3 -0
- config.json +26 -0
- generation_config.json +4 -0
- modeling_mixformer_sequential.py +855 -0
- pytorch_model.bin +3 -0
- training_args.bin +3 -0
README.md
CHANGED
@@ -1,3 +1,148 @@
---
license: other
base_model: microsoft/phi-1_5
tags:
- generated_from_trainer
- sales
model-index:
- name: salesGPT_v2
  results: []
datasets:
- goendalf666/sales-conversations-2
- goendalf666/sales-conversations-instruction-ext
- goendalf666/sales-conversations-instruction-base
- goendalf666/sales-textbook_for_convincing_and_selling
language:
- en
pipeline_tag: text-generation
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# salesGPT_v2

**Model Card for salesGPT_v2**

### Model Description
salesGPT_v2, derived from microsoft/phi-1_5, is specialized in simulating sales conversations: it works out customer requirements, manages objections, and suggests suitable products or services. It was fine-tuned on a variety of sales-related datasets and proves reasonably proficient at initiating conversations, asking pertinent questions, and sustaining interactive dialogue with users.

### Related Resources

GitHub: https://github.com/tom813/salesGPT_foundation
salesGPT_v1: https://huggingface.co/goendalf666/salesGPT_v1

![image/png](https://cdn-uploads.huggingface.co/production/uploads/63797fcb2cb50dda39d8aec6/re7MmsaYNzTYVH2jEXDDu.png)

### Intended Uses & Limitations
**Intended Uses:**
- Simulating sales conversations for training or evaluation purposes.
- Providing guidelines or suggested dialogues for sales representatives.

**Limitations:**
- The model may ask questions repetitively in certain scenarios.
- It may struggle with customers who lack specific preferences or knowledge about products.
- Objection handling leans on objective criteria rather than persuasion techniques.
- It has difficulty making appropriate suggestions for customers without specific needs.
- It is of limited effectiveness in conversations involving budgets or other financial sensitivities.

### Training and Evaluation Data
**Training Data** (a minimal loading sketch follows this section):
1. **Textbook v1 Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling)
   - Content: A sales textbook generated via API calls from structural points and detailed subpoints.

2. **Sales Conversation Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations)
   - Content: Sales conversations generated from the chapters of the textbook.

3. **Sales Conversations Instruction Base Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-base)
   - Content: Extended sales conversations with structured dialogues.

4. **Sales Conversations Instruction Extension Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-ext)
   - Content: Updates based on real conversations with the model, targeting cases where it failed to convince.

**Evaluation Data:**
- More information is needed on how and where the model was evaluated. If it was assessed on a separate test set, access to and details about that dataset would be needed.
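
All of the datasets above are hosted on the Hugging Face Hub and can be pulled with the `datasets` library. The snippet below is only a minimal sketch; it assumes the usual default `train` split, and the column layout differs per dataset.

```python
# Minimal sketch: load one of the training datasets listed above.
# Assumes a default "train" split; inspect the returned object for its columns.
from datasets import load_dataset

conversations = load_dataset("goendalf666/sales-conversations-2", split="train")
print(conversations)
```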

### Training Procedure
salesGPT_v2 was fine-tuned in three phases using the LoRA approach with rank 64 (see the configuration sketch after this section):
1. Training on the textbook for 20k steps.
2. Training on the sales conversations for 40k steps, resulting in salesGPT_v1.
3. Training on the sales-conversations instructions for 40k steps, resulting in salesGPT_v2.

Hyperparameters used during training:
- Learning rate: 0.0002
- Train batch size: 2
- Eval batch size: 8
- Seed: 42
- Optimizer: Adam
- LR scheduler type: cosine
- Training steps: 40,000
- [Additional parameters...](https://huggingface.co/goendalf666/salesGPT_v2)

Model size details:
- Trainable params: 18,874,368
- All params: 1,437,145,088
- Trainable %: 1.313323766514519
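
The training script itself is not part of this commit, so the following is only a sketch of the PEFT setup these numbers imply, using the rank, alpha, dropout, and target modules from `adapter_config.json` and the hyperparameters listed above. The output path and the data/Trainer wiring are assumptions.

```python
# Sketch of the LoRA configuration implied by adapter_config.json and the
# hyperparameters above; dataset preparation and the Trainer loop are omitted.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5", trust_remote_code=True, torch_dtype=torch.float32
)

lora_config = LoraConfig(
    r=64,                                 # rank 64, as stated above
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["Wqkv", "out_proj"],  # per adapter_config.json
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()        # ~18.9M trainable of ~1.44B total

training_args = TrainingArguments(
    output_dir="salesGPT_v2",             # hypothetical output directory
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=8,
    lr_scheduler_type="cosine",
    max_steps=40_000,
    seed=42,
)
```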

### Results
While salesGPT_v2 showcases notable capabilities, such as initiating a structured dialogue and asking relevant questions, there are some key observations and areas for improvement:
- The model tends to keep asking questions when a user provides only vague information.
- In some cases it is not assertive enough to suggest a product when a customer lacks preferences.
- It tends to lean on objective criteria rather than convincing strategies when handling objections.
- It needs better handling of conversations about budgetary or financial concerns.

### Example Usage & Conversations
#### Example 1:
**Customer**: I am interested in buying a new laptop.
**salesGPT_v2**: I understand. Could you please tell me more about your specific requirements and budget?

#### Example 2:
**Customer**: Oh I don't know how I should explain this to my wife. I said we can't go to vacation because we don't have any money.
**salesGPT_v2**: I see. Can you please provide more details about your financial situation and the specific expenses you are concerned about?

These examples underline the model's ability to engage in sales dialogues and also hint at its limitations, such as possibly prying too far into a customer's personal financial situation.

### Future Directions & Improvement
- Enhance objection handling by focusing more on persuasive techniques and emotional intelligence.
- Improve suggestion-making, especially when customers are indecisive.
- Navigate budgetary and financial topics without seeming insensitive or intrusive.
- Strike a balance between being informative and being too technical in product descriptions.
- Possibly implement more ethics- and privacy-guided conversation guidelines, especially when discussing customers' financial capacity.

### Ethical Considerations
The model's tendency to repeatedly ask for specific information, especially personal financial details, raises privacy and data-sensitivity concerns. Care must be taken to ensure the model respects user privacy and does not persistently probe for personal or sensitive information.

### Conclusion
salesGPT_v2 offers a foundation for simulating sales conversations, with room for refinement in handling objections, making product suggestions, and treating financial discussions delicately. Future versions might seek a better balance between being convincingly persuasive and remaining ethically and emotionally intelligent within dialogues.

### Inference

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    "goendalf666/salesGPT_v2", trust_remote_code=True, torch_dtype=torch.float32, device_map={"": 0}
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)

# `conversation_text` holds the dialogue so far (see the prompt sketch below)
inputs = tokenizer(conversation_text, return_tensors="pt", return_attention_mask=False)
inputs = inputs.to(device)

# Generate response
outputs = model.generate(**inputs, max_length=512)
response_text = tokenizer.batch_decode(outputs)[0]
```

Or use the provided inference script: https://github.com/tom813/salesGPT_foundation/blob/main/inference.py
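
The model card does not document the exact prompt format, so the snippet below is only an assumed sketch of how `conversation_text` might be assembled from the dialogue so far; the speaker labels and the trailing cue are assumptions, not a published specification.

```python
# Hypothetical construction of `conversation_text` for the snippet above;
# the "Customer:"/"Salesman:" labels are an assumed convention, not documented.
turns = [
    ("Customer", "I am interested in buying a new laptop."),
]
conversation_text = "\n".join(f"{speaker}: {text}" for speaker, text in turns)
conversation_text += "\nSalesman:"  # cue the model to produce the next sales turn
```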

### Framework versions

- Transformers 4.32.1
- PyTorch 2.1.0.dev20230829+cu121
- Datasets 2.14.5
- Tokenizers 0.13.3
adapter_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "auto_mapping": null,
  "base_model_name_or_path": "microsoft/phi-1_5",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 64,
  "target_modules": [
    "Wqkv",
    "out_proj"
  ],
  "task_type": "CAUSAL_LM"
}
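
This is a standard PEFT LoRA adapter configuration. Since the commit also ships a merged pytorch_model.bin, loading the repository directly (as in the Inference section of the README) is the simplest route; the sketch below shows the alternative of applying the adapter in adapter_model.bin on top of the base model, with the repository path left as a placeholder.

```python
# Alternative loading path: attach the LoRA adapter to the base phi-1_5 model.
# "path/to/this/repo" is a placeholder for a local clone or the Hub id of this repository.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "path/to/this/repo")  # reads adapter_config.json + adapter_model.bin
```
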
adapter_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e896326bf0827004d90f3ddd361d14ec98a8bf8e62aa1b490f90eab86cc9e10
size 75531342
config.json
ADDED
@@ -0,0 +1,26 @@
{
  "_name_or_path": "mariordoniez/phi",
  "activation_function": "gelu_new",
  "architectures": [
    "MixFormerSequentialForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "mariordoniez/phi--configuration_mixformer_sequential.MixFormerSequentialConfig",
    "AutoModelForCausalLM": "mariordoniez/phi--modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
  },
  "embd_pdrop": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mixformer-sequential",
  "n_embd": 2048,
  "n_head": 32,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary_dim": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.32.1",
  "vocab_size": 51200
}
generation_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "_from_model_config": true,
  "transformers_version": "4.32.1"
}
modeling_mixformer_sequential.py
ADDED
@@ -0,0 +1,855 @@
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
#
|
4 |
+
# BSD 3-Clause License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2022, Tri Dao, [email protected].
|
7 |
+
# All rights reserved.
|
8 |
+
#
|
9 |
+
# Redistribution and use in source and binary forms, with or without
|
10 |
+
# modification, are permitted provided that the following conditions are met:
|
11 |
+
#
|
12 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
13 |
+
# list of conditions and the following disclaimer.
|
14 |
+
#
|
15 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
16 |
+
# this list of conditions and the following disclaimer in the documentation
|
17 |
+
# and/or other materials provided with the distribution.
|
18 |
+
#
|
19 |
+
# * Neither the name of the copyright holder nor the names of its
|
20 |
+
# contributors may be used to endorse or promote products derived from
|
21 |
+
# this software without specific prior written permission.
|
22 |
+
#
|
23 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
24 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
25 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
26 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
27 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
28 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
29 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
30 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
31 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
32 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
33 |
+
|
34 |
+
from __future__ import annotations
|
35 |
+
|
36 |
+
import math
|
37 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
38 |
+
from dataclasses import dataclass, field
|
39 |
+
|
40 |
+
import torch
|
41 |
+
import torch.nn as nn
|
42 |
+
|
43 |
+
from einops import rearrange, repeat
|
44 |
+
from transformers.activations import ACT2FN
|
45 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
46 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
47 |
+
|
48 |
+
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
49 |
+
|
50 |
+
|
51 |
+
try:
|
52 |
+
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
53 |
+
from flash_attn.ops.fused_dense import FusedDense
|
54 |
+
except ImportError:
|
55 |
+
FlashRotaryEmbedding = None
|
56 |
+
FusedDense = None
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class InferenceParams:
|
61 |
+
"""Inference parameters passed to model to efficiently calculate
|
62 |
+
and store context during inference.
|
63 |
+
|
64 |
+
Reference:
|
65 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
max_seqlen: Maximum sequence length.
|
69 |
+
max_batch_size: Maximum batch size.
|
70 |
+
seqlen_offset: Sequence length offset.
|
71 |
+
batch_size_offset: Batch size offset.
|
72 |
+
key_value_memory_dict: Key value memory dictionary.
|
73 |
+
lengths_per_sample: Lengths per sample.
|
74 |
+
|
75 |
+
"""
|
76 |
+
|
77 |
+
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
|
78 |
+
|
79 |
+
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
80 |
+
|
81 |
+
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
82 |
+
|
83 |
+
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
84 |
+
|
85 |
+
key_value_memory_dict: Dict[str, Any] = field(
|
86 |
+
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
87 |
+
)
|
88 |
+
|
89 |
+
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
90 |
+
|
91 |
+
|
92 |
+
class Embedding(nn.Module):
|
93 |
+
"""Token embedding with dropout."""
|
94 |
+
|
95 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
96 |
+
super().__init__()
|
97 |
+
|
98 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
99 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
100 |
+
|
101 |
+
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
102 |
+
input_shape = input_ids.size()
|
103 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
104 |
+
|
105 |
+
hidden_states = self.wte(input_ids)
|
106 |
+
hidden_states = self.drop(hidden_states)
|
107 |
+
|
108 |
+
return hidden_states
|
109 |
+
|
110 |
+
|
111 |
+
def _apply_rotary_emb(
|
112 |
+
x: torch.FloatTensor,
|
113 |
+
cos: torch.FloatTensor,
|
114 |
+
sin: torch.FloatTensor,
|
115 |
+
) -> torch.FloatTensor:
|
116 |
+
_, seqlen, _, head_dim = x.shape
|
117 |
+
rotary_seqlen, rotary_dim = cos.shape
|
118 |
+
rotary_dim *= 2
|
119 |
+
|
120 |
+
assert rotary_dim <= head_dim
|
121 |
+
assert seqlen <= rotary_seqlen
|
122 |
+
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
123 |
+
|
124 |
+
x_rot = x[:, :, :, :rotary_dim]
|
125 |
+
x_pass = x[:, :, :, rotary_dim:]
|
126 |
+
|
127 |
+
x1, x2 = x_rot.chunk(2, dim=-1)
|
128 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
129 |
+
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
|
130 |
+
|
131 |
+
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
|
132 |
+
|
133 |
+
return torch.cat([x_rot, x_pass], axis=-1)
|
134 |
+
|
135 |
+
|
136 |
+
def _apply_rotary_emb_kv(
|
137 |
+
kv: torch.FloatTensor,
|
138 |
+
cos: torch.FloatTensor,
|
139 |
+
sin: torch.FloatTensor,
|
140 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
141 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
142 |
+
) -> torch.FloatTensor:
|
143 |
+
_, seqlen, two, _, head_dim = kv.shape
|
144 |
+
assert two == 2
|
145 |
+
|
146 |
+
rotary_seqlen, rotary_dim = cos.shape
|
147 |
+
rotary_dim *= 2
|
148 |
+
assert rotary_dim <= head_dim
|
149 |
+
assert seqlen <= rotary_seqlen
|
150 |
+
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
151 |
+
|
152 |
+
k_rot = kv[:, :, 0, :, :rotary_dim]
|
153 |
+
k_pass = kv[:, :, 0, :, rotary_dim:]
|
154 |
+
|
155 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
156 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
157 |
+
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
|
158 |
+
|
159 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
|
160 |
+
|
161 |
+
return torch.cat(
|
162 |
+
[
|
163 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
164 |
+
kv[:, :, 1:2, :, :],
|
165 |
+
],
|
166 |
+
axis=2,
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
def _apply_rotary_emb_qkv(
|
171 |
+
qkv: torch.FloatTensor,
|
172 |
+
cos: torch.FloatTensor,
|
173 |
+
sin: torch.FloatTensor,
|
174 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
175 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
176 |
+
) -> torch.FloatTensor:
|
177 |
+
_, seqlen, three, _, head_dim = qkv.shape
|
178 |
+
assert three == 3
|
179 |
+
|
180 |
+
rotary_seqlen, rotary_dim = cos.shape
|
181 |
+
rotary_dim *= 2
|
182 |
+
assert rotary_dim <= head_dim
|
183 |
+
assert seqlen <= rotary_seqlen
|
184 |
+
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
185 |
+
|
186 |
+
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
187 |
+
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
188 |
+
|
189 |
+
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
190 |
+
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
191 |
+
|
192 |
+
q1, q2 = q_rot.chunk(2, dim=-1)
|
193 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
194 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
195 |
+
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
196 |
+
|
197 |
+
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
198 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
199 |
+
|
200 |
+
return torch.cat(
|
201 |
+
[
|
202 |
+
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
203 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
204 |
+
qkv[:, :, 2:3, :, :],
|
205 |
+
],
|
206 |
+
axis=2,
|
207 |
+
)
|
208 |
+
|
209 |
+
|
210 |
+
class RotaryEmbedding(nn.Module):
|
211 |
+
"""Rotary positional embedding (RoPE).
|
212 |
+
|
213 |
+
Reference:
|
214 |
+
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
215 |
+
https://arxiv.org/pdf/2104.09864.pdf.
|
216 |
+
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
dim: int,
|
222 |
+
base: int = 10000,
|
223 |
+
scale_base: Optional[float] = None,
|
224 |
+
pos_idx_in_fp32: bool = True,
|
225 |
+
device: Optional[str] = None,
|
226 |
+
**kwargs,
|
227 |
+
) -> None:
|
228 |
+
super().__init__()
|
229 |
+
|
230 |
+
if scale_base is not None:
|
231 |
+
raise NotImplementedError
|
232 |
+
|
233 |
+
self.dim = dim
|
234 |
+
self.base = float(base)
|
235 |
+
self.scale_base = scale_base
|
236 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
237 |
+
self.device = device
|
238 |
+
|
239 |
+
# Generate and save the inverse frequency buffer (non-trainable)
|
240 |
+
inv_freq = self._compute_inv_freq(device)
|
241 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
242 |
+
|
243 |
+
# Generate and save the scale buffer (non-trainable)
|
244 |
+
scale = (
|
245 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
246 |
+
if scale_base is not None
|
247 |
+
else None
|
248 |
+
)
|
249 |
+
self.register_buffer("scale", scale, persistent=False)
|
250 |
+
|
251 |
+
self._seq_len_cached = 0
|
252 |
+
self._cos_cached = None
|
253 |
+
self._sin_cached = None
|
254 |
+
self._cos_k_cached = None
|
255 |
+
self._sin_k_cached = None
|
256 |
+
|
257 |
+
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
258 |
+
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
259 |
+
|
260 |
+
def _update_cos_sin_cache(
|
261 |
+
self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
|
262 |
+
) -> None:
|
263 |
+
# Reset the tables if the sequence length has changed, if we are on a
|
264 |
+
# new device or if we are switching from inference mode to training
|
265 |
+
if (
|
266 |
+
seqlen > self._seq_len_cached
|
267 |
+
or self._cos_cached is None
|
268 |
+
or self._cos_cached.device != device
|
269 |
+
or self._cos_cached.dtype != dtype
|
270 |
+
or (self.training and self._cos_cached.is_inference())
|
271 |
+
):
|
272 |
+
self._seq_len_cached = seqlen
|
273 |
+
|
274 |
+
# fp32 is preferred since the output of `torch.arange` can be quite large
|
275 |
+
# and bf16 would lose a lot of precision
|
276 |
+
if self.pos_idx_in_fp32:
|
277 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
278 |
+
if self.inv_freq.dtype != torch.float32:
|
279 |
+
inv_freq = self._compute_inv_freq(device=device)
|
280 |
+
else:
|
281 |
+
inv_freq = self.inv_freq
|
282 |
+
else:
|
283 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
284 |
+
inv_freq = self.inv_freq
|
285 |
+
|
286 |
+
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
287 |
+
freqs = torch.outer(t, inv_freq)
|
288 |
+
if self.scale is None:
|
289 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
290 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
291 |
+
else:
|
292 |
+
power = (
|
293 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
294 |
+
) / self.scale_base
|
295 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
296 |
+
|
297 |
+
# Force the scale multiplication to happen in fp32
|
298 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
299 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
300 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
301 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
302 |
+
|
303 |
+
def forward(
|
304 |
+
self,
|
305 |
+
qkv: torch.Tensor,
|
306 |
+
kv: Optional[torch.Tensor] = None,
|
307 |
+
seqlen_offset: int = 0,
|
308 |
+
max_seqlen: Optional[int] = None,
|
309 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
310 |
+
seqlen = qkv.shape[1]
|
311 |
+
|
312 |
+
if max_seqlen is not None:
|
313 |
+
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
314 |
+
else:
|
315 |
+
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
316 |
+
|
317 |
+
if kv is None:
|
318 |
+
return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
319 |
+
else:
|
320 |
+
q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
321 |
+
kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
322 |
+
|
323 |
+
return q, kv
|
324 |
+
|
325 |
+
|
326 |
+
class MLP(nn.Module):
|
327 |
+
"""Multi-Layer Perceptron.
|
328 |
+
|
329 |
+
Reference:
|
330 |
+
Attention Is All You Need.
|
331 |
+
https://arxiv.org/pdf/1706.03762.pdf.
|
332 |
+
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
|
336 |
+
super().__init__()
|
337 |
+
|
338 |
+
act_fn = config.activation_function if act_fn is None else act_fn
|
339 |
+
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
340 |
+
|
341 |
+
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
342 |
+
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
343 |
+
|
344 |
+
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
345 |
+
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
346 |
+
self.act = ACT2FN[act_fn]
|
347 |
+
|
348 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
349 |
+
hidden_states = self.fc1(hidden_states)
|
350 |
+
hidden_states = self.act(hidden_states)
|
351 |
+
hidden_states = self.fc2(hidden_states)
|
352 |
+
|
353 |
+
return hidden_states
|
354 |
+
|
355 |
+
|
356 |
+
class SelfAttention(nn.Module):
|
357 |
+
"""Self-attention layer (compatible with PyTorch).
|
358 |
+
|
359 |
+
Reference:
|
360 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
361 |
+
|
362 |
+
"""
|
363 |
+
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
causal: bool = True,
|
367 |
+
softmax_scale: Optional[float] = None,
|
368 |
+
attention_dropout: float = 0.0,
|
369 |
+
) -> None:
|
370 |
+
super().__init__()
|
371 |
+
|
372 |
+
self.causal = causal
|
373 |
+
self.softmax_scale = softmax_scale
|
374 |
+
self.drop = nn.Dropout(attention_dropout)
|
375 |
+
|
376 |
+
def forward(
|
377 |
+
self,
|
378 |
+
qkv: torch.FloatTensor,
|
379 |
+
causal: bool = None,
|
380 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
381 |
+
**kwargs,
|
382 |
+
) -> torch.FloatTensor:
|
383 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
384 |
+
q, k, v = qkv.unbind(dim=2)
|
385 |
+
|
386 |
+
causal = self.causal if causal is None else causal
|
387 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
388 |
+
|
389 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
390 |
+
|
391 |
+
if attention_mask is not None:
|
392 |
+
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
|
393 |
+
padding_mask.masked_fill_(attention_mask, 0.0)
|
394 |
+
|
395 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
396 |
+
|
397 |
+
if causal:
|
398 |
+
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
399 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
400 |
+
|
401 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
402 |
+
attention = self.drop(attention)
|
403 |
+
|
404 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
405 |
+
|
406 |
+
return output
|
407 |
+
|
408 |
+
|
409 |
+
class CrossAttention(nn.Module):
|
410 |
+
"""Cross-attention layer (compatible with PyTorch).
|
411 |
+
|
412 |
+
Reference:
|
413 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
414 |
+
|
415 |
+
"""
|
416 |
+
|
417 |
+
def __init__(
|
418 |
+
self,
|
419 |
+
causal: bool = True,
|
420 |
+
softmax_scale: Optional[float] = None,
|
421 |
+
attention_dropout: float = 0.0,
|
422 |
+
) -> None:
|
423 |
+
super().__init__()
|
424 |
+
|
425 |
+
self.causal = causal
|
426 |
+
self.softmax_scale = softmax_scale
|
427 |
+
self.drop = nn.Dropout(attention_dropout)
|
428 |
+
|
429 |
+
def forward(
|
430 |
+
self,
|
431 |
+
q: torch.FloatTensor,
|
432 |
+
kv: torch.FloatTensor,
|
433 |
+
causal: bool = None,
|
434 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
435 |
+
**kwargs,
|
436 |
+
) -> torch.FloatTensor:
|
437 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
438 |
+
seqlen_k = kv.shape[1]
|
439 |
+
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
440 |
+
|
441 |
+
if kv.shape[3] != q.shape[2]:
|
442 |
+
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
443 |
+
k, v = kv.unbind(dim=2)
|
444 |
+
|
445 |
+
causal = self.causal if causal is None else causal
|
446 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
447 |
+
|
448 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
449 |
+
|
450 |
+
if attention_mask is not None:
|
451 |
+
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
|
452 |
+
padding_mask.masked_fill_(attention_mask, 0.0)
|
453 |
+
|
454 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
455 |
+
|
456 |
+
if causal:
|
457 |
+
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
|
458 |
+
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
|
459 |
+
causal_mask = cols > rows + seqlen_k - seqlen_q
|
460 |
+
|
461 |
+
scores = scores.masked_fill(causal_mask, -10000.0)
|
462 |
+
|
463 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
464 |
+
attention = self.drop(attention)
|
465 |
+
|
466 |
+
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
467 |
+
|
468 |
+
return output
|
469 |
+
|
470 |
+
|
471 |
+
def _find_mha_dims(
|
472 |
+
config: PretrainedConfig,
|
473 |
+
n_head: Optional[int] = None,
|
474 |
+
n_head_kv: Optional[int] = None,
|
475 |
+
head_dim: Optional[int] = None,
|
476 |
+
) -> Tuple[int, int]:
|
477 |
+
assert all(
|
478 |
+
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
479 |
+
), "`config` must have `n_embd` and `n_head` attributes."
|
480 |
+
|
481 |
+
if head_dim is None:
|
482 |
+
assert (
|
483 |
+
config.n_embd % config.n_head == 0
|
484 |
+
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
|
485 |
+
|
486 |
+
if n_head is None and head_dim is None:
|
487 |
+
head_dim = config.n_embd // config.n_head
|
488 |
+
n_head = config.n_head
|
489 |
+
elif n_head is None or head_dim is None:
|
490 |
+
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
491 |
+
|
492 |
+
if n_head_kv is None:
|
493 |
+
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
494 |
+
assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
|
495 |
+
|
496 |
+
return n_head, n_head_kv, head_dim
|
497 |
+
|
498 |
+
|
499 |
+
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
500 |
+
num_heads, head_dim = kv.shape[-2:]
|
501 |
+
|
502 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
503 |
+
kv_cache = torch.empty(
|
504 |
+
inference_params.max_batch_size,
|
505 |
+
inference_params.max_seqlen,
|
506 |
+
2,
|
507 |
+
num_heads,
|
508 |
+
head_dim,
|
509 |
+
dtype=kv.dtype,
|
510 |
+
device=kv.device,
|
511 |
+
)
|
512 |
+
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
513 |
+
else:
|
514 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
515 |
+
|
516 |
+
batch_start = inference_params.batch_size_offset
|
517 |
+
batch_end = batch_start + kv.shape[0]
|
518 |
+
assert batch_end <= kv_cache.shape[0]
|
519 |
+
|
520 |
+
sequence_start = inference_params.seqlen_offset
|
521 |
+
sequence_end = sequence_start + kv.shape[1]
|
522 |
+
assert sequence_end <= kv_cache.shape[1]
|
523 |
+
|
524 |
+
assert kv_cache is not None
|
525 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
526 |
+
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
527 |
+
|
528 |
+
return kv
|
529 |
+
|
530 |
+
|
531 |
+
class MHA(nn.Module):
|
532 |
+
"""Multi-head attention layer."""
|
533 |
+
|
534 |
+
def __init__(
|
535 |
+
self,
|
536 |
+
config: PretrainedConfig,
|
537 |
+
dtype: Optional[torch.dtype] = None,
|
538 |
+
device: Optional[str] = None,
|
539 |
+
rotary_dim: Optional[int] = None,
|
540 |
+
rotary_emb_scale_base: Optional[float] = None,
|
541 |
+
n_head: Optional[int] = None,
|
542 |
+
n_head_kv: Optional[int] = None,
|
543 |
+
head_dim: Optional[int] = None,
|
544 |
+
bias: bool = True,
|
545 |
+
causal: bool = True,
|
546 |
+
softmax_scale: Optional[float] = None,
|
547 |
+
layer_idx: Optional[int] = None,
|
548 |
+
return_residual: bool = False,
|
549 |
+
checkpointing: bool = False,
|
550 |
+
) -> None:
|
551 |
+
super().__init__()
|
552 |
+
|
553 |
+
# Rotary embedding
|
554 |
+
self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
555 |
+
if self.rotary_emb_dim > 0:
|
556 |
+
rotary_kwargs = {"device": device}
|
557 |
+
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
558 |
+
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
559 |
+
|
560 |
+
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
561 |
+
if rotary_cls is None:
|
562 |
+
rotary_cls = RotaryEmbedding
|
563 |
+
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
564 |
+
|
565 |
+
# MLP
|
566 |
+
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
|
567 |
+
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
568 |
+
hidden_size = config.n_embd
|
569 |
+
|
570 |
+
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
571 |
+
if linear_cls is None:
|
572 |
+
linear_cls = nn.Linear
|
573 |
+
|
574 |
+
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
575 |
+
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
576 |
+
|
577 |
+
# Attention
|
578 |
+
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
579 |
+
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
580 |
+
|
581 |
+
self.layer_idx = layer_idx
|
582 |
+
self.return_residual = return_residual
|
583 |
+
self.checkpointing = checkpointing
|
584 |
+
|
585 |
+
def _forward_self_attn(
|
586 |
+
self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
|
587 |
+
) -> torch.FloatTensor:
|
588 |
+
qkv = self.Wqkv(x)
|
589 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
590 |
+
|
591 |
+
if self.rotary_emb_dim > 0:
|
592 |
+
qkv = self.rotary_emb(qkv)
|
593 |
+
|
594 |
+
if self.checkpointing:
|
595 |
+
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
|
596 |
+
|
597 |
+
return self.inner_attn(qkv, attention_mask=attention_mask)
|
598 |
+
|
599 |
+
def _forward_cross_attn(
|
600 |
+
self,
|
601 |
+
x: torch.FloatTensor,
|
602 |
+
past_key_values: Optional[InferenceParams],
|
603 |
+
attention_mask: Optional[torch.BoolTensor],
|
604 |
+
) -> torch.FloatTensor:
|
605 |
+
qkv = self.Wqkv(x)
|
606 |
+
|
607 |
+
q = qkv[..., : self.n_head * self.head_dim]
|
608 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
609 |
+
|
610 |
+
kv = qkv[..., self.n_head * self.head_dim :]
|
611 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
612 |
+
|
613 |
+
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
|
614 |
+
causal = None if seqlen_offset == 0 else False
|
615 |
+
if self.rotary_emb_dim > 0:
|
616 |
+
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
|
617 |
+
|
618 |
+
if past_key_values is not None:
|
619 |
+
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
620 |
+
|
621 |
+
if self.checkpointing:
|
622 |
+
return torch.utils.checkpoint.checkpoint(
|
623 |
+
self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
|
624 |
+
)
|
625 |
+
|
626 |
+
return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
|
627 |
+
|
628 |
+
def forward(
|
629 |
+
self,
|
630 |
+
x: torch.FloatTensor,
|
631 |
+
past_key_values: Optional[InferenceParams] = None,
|
632 |
+
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
633 |
+
**kwargs,
|
634 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
635 |
+
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
636 |
+
attention_mask = attention_mask.bool()
|
637 |
+
else:
|
638 |
+
attention_mask = None
|
639 |
+
|
640 |
+
# MHA
|
641 |
+
if self.n_head == self.n_head_kv:
|
642 |
+
if past_key_values is None:
|
643 |
+
# If `past_key_values` are not supplied, we run self-attention
|
644 |
+
attn_output = self._forward_self_attn(x, attention_mask)
|
645 |
+
else:
|
646 |
+
# If `past_key_values` are supplied, it means that we might have cached values and
|
647 |
+
# could take advantage of cross-attention
|
648 |
+
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
649 |
+
# MQA / GQA
|
650 |
+
else:
|
651 |
+
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
|
652 |
+
# because `q` and `kv` lengths might be different
|
653 |
+
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
654 |
+
|
655 |
+
output = rearrange(attn_output, "... h d -> ... (h d)")
|
656 |
+
output = self.out_proj(output)
|
657 |
+
|
658 |
+
return output if not self.return_residual else (output, x)
|
659 |
+
|
660 |
+
|
661 |
+
class ParallelBlock(nn.Module):
|
662 |
+
"""Parallel block.
|
663 |
+
|
664 |
+
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
665 |
+
|
666 |
+
"""
|
667 |
+
|
668 |
+
def __init__(
|
669 |
+
self,
|
670 |
+
config: PretrainedConfig,
|
671 |
+
block_idx: Optional[int] = None,
|
672 |
+
) -> None:
|
673 |
+
super().__init__()
|
674 |
+
|
675 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
676 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
677 |
+
self.block_idx = block_idx
|
678 |
+
|
679 |
+
self.mixer = MHA(config, layer_idx=block_idx)
|
680 |
+
self.mlp = MLP(config)
|
681 |
+
|
682 |
+
def forward(
|
683 |
+
self,
|
684 |
+
hidden_states: torch.FloatTensor,
|
685 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
686 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
687 |
+
**kwargs,
|
688 |
+
) -> torch.FloatTensor:
|
689 |
+
residual = hidden_states
|
690 |
+
hidden_states = self.ln(hidden_states)
|
691 |
+
|
692 |
+
attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
|
693 |
+
if isinstance(attn_outputs, tuple):
|
694 |
+
attn_outputs = attn_outputs[0]
|
695 |
+
|
696 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
697 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
698 |
+
|
699 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
700 |
+
|
701 |
+
return hidden_states
|
702 |
+
|
703 |
+
|
704 |
+
class CausalLMHead(nn.Module):
|
705 |
+
"""Causal Language Modeling head.
|
706 |
+
|
707 |
+
Reference:
|
708 |
+
Improving Language Understanding by Generative Pre-Training.
|
709 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
710 |
+
|
711 |
+
"""
|
712 |
+
|
713 |
+
def __init__(self, config: PretrainedConfig) -> None:
|
714 |
+
super().__init__()
|
715 |
+
|
716 |
+
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
717 |
+
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
718 |
+
|
719 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
720 |
+
hidden_states = self.ln(hidden_states)
|
721 |
+
logits = self.linear(hidden_states).to(torch.float32)
|
722 |
+
|
723 |
+
return logits
|
724 |
+
|
725 |
+
|
726 |
+
class CausalLMLoss(nn.Module):
|
727 |
+
"""Causal Language Modeling loss.
|
728 |
+
|
729 |
+
Reference:
|
730 |
+
Improving Language Understanding by Generative Pre-Training.
|
731 |
+
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
732 |
+
|
733 |
+
"""
|
734 |
+
|
735 |
+
def __init__(self, shift_labels: bool = True) -> None:
|
736 |
+
super().__init__()
|
737 |
+
|
738 |
+
self.shift_labels = shift_labels
|
739 |
+
self.loss_fct = nn.CrossEntropyLoss()
|
740 |
+
|
741 |
+
def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
|
742 |
+
if self.shift_labels:
|
743 |
+
logits = logits[..., :-1, :].contiguous()
|
744 |
+
labels = labels[..., 1:].contiguous()
|
745 |
+
|
746 |
+
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
747 |
+
|
748 |
+
return loss
|
749 |
+
|
750 |
+
|
751 |
+
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
752 |
+
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
|
753 |
+
|
754 |
+
config_class = MixFormerSequentialConfig
|
755 |
+
base_model_prefix = "transformer"
|
756 |
+
supports_gradient_checkpointing = True
|
757 |
+
|
758 |
+
def __init__(self, *inputs, **kwargs) -> None:
|
759 |
+
super().__init__(*inputs, **kwargs)
|
760 |
+
|
761 |
+
def _init_weights(self, module: nn.Module) -> None:
|
762 |
+
if isinstance(module, (nn.Linear,)):
|
763 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
764 |
+
if module.bias is not None:
|
765 |
+
module.bias.data.zero_()
|
766 |
+
elif isinstance(module, nn.Embedding):
|
767 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
768 |
+
if module.padding_idx is not None:
|
769 |
+
module.weight.data[module.padding_idx].zero_()
|
770 |
+
elif isinstance(module, nn.LayerNorm):
|
771 |
+
if module.bias is not None:
|
772 |
+
module.bias.data.zero_()
|
773 |
+
module.weight.data.fill_(1.0)
|
774 |
+
|
775 |
+
def prepare_inputs_for_generation(
|
776 |
+
self,
|
777 |
+
input_ids: torch.LongTensor,
|
778 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
779 |
+
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
780 |
+
**kwargs,
|
781 |
+
) -> Dict[str, Any]:
|
782 |
+
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
783 |
+
past_key_values = InferenceParams(
|
784 |
+
max_seqlen=self.config.n_positions,
|
785 |
+
max_batch_size=input_ids.shape[0],
|
786 |
+
seqlen_offset=0,
|
787 |
+
batch_size_offset=0,
|
788 |
+
key_value_memory_dict={},
|
789 |
+
lengths_per_sample=None,
|
790 |
+
)
|
791 |
+
else:
|
792 |
+
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
793 |
+
past_key_values.seqlen_offset = len(input_ids[0]) - 1
|
794 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
795 |
+
|
796 |
+
return {
|
797 |
+
"input_ids": input_ids,
|
798 |
+
"past_key_values": past_key_values,
|
799 |
+
"attention_mask": attention_mask,
|
800 |
+
}
|
801 |
+
|
802 |
+
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
|
803 |
+
if isinstance(module, MixFormerSequentialPreTrainedModel):
|
804 |
+
module.gradient_checkpointing = value
|
805 |
+
|
806 |
+
|
807 |
+
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
808 |
+
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
809 |
+
|
810 |
+
_keys_to_ignore_on_load_missing = [""]
|
811 |
+
_keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
812 |
+
_no_split_modules = ["ParallelBlock"]
|
813 |
+
|
814 |
+
def __init__(self, config: MixFormerSequentialConfig) -> None:
|
815 |
+
super().__init__(config)
|
816 |
+
|
817 |
+
modules = [Embedding(config)]
|
818 |
+
modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
|
819 |
+
modules.append(CausalLMHead(config))
|
820 |
+
|
821 |
+
self.layers = nn.Sequential(*modules)
|
822 |
+
self.loss = CausalLMLoss()
|
823 |
+
|
824 |
+
self.post_init()
|
825 |
+
|
826 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
827 |
+
return self.layers[0].wte
|
828 |
+
|
829 |
+
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
830 |
+
self.layers[0].wte = new_embeddings
|
831 |
+
|
832 |
+
def get_output_embeddings(self) -> nn.Linear:
|
833 |
+
return self.layers[-1].linear
|
834 |
+
|
835 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
836 |
+
self.layers[-1].linear = new_embeddings
|
837 |
+
|
838 |
+
def forward(
|
839 |
+
self,
|
840 |
+
input_ids: torch.LongTensor,
|
841 |
+
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
842 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
843 |
+
labels: Optional[torch.LongTensor] = None,
|
844 |
+
**kwargs,
|
845 |
+
) -> CausalLMOutputWithPast:
|
846 |
+
hidden_layer = self.layers[0](input_ids)
|
847 |
+
for module in self.layers[1:-1]:
|
848 |
+
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
|
849 |
+
lm_logits = self.layers[-1](hidden_layer)
|
850 |
+
|
851 |
+
loss = None
|
852 |
+
if labels is not None:
|
853 |
+
loss = self.loss(lm_logits, labels)
|
854 |
+
|
855 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d121d287c708fc6d08043ed171921e4b9fb68d00f452c1d23ea1c55292bd1d5c
size 5673168010
training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4acc51c52c33dccf606b129498bc828aef164175ad72fb3176ccedff193d49b0
size 4536