Commit
·
77d4d73
1
Parent(s):
fc863b3
Upload 14 files
Browse files- README.md +127 -132
- Research License.docx +0 -0
- added_tokens.json +40 -0
- config.json +4 -4
- configuration_mixformer_sequential.py +53 -0
- merges.txt +0 -0
- modeling_mixformer_sequential.py +222 -303
- pytorch_model.bin +2 -2
- special_tokens_map.json +5 -0
- tokenizer.json +0 -0
- tokenizer_config.json +9 -0
- vocab.json +0 -0
README.md
CHANGED
@@ -1,148 +1,143 @@
|
|
1 |
---
|
2 |
license: other
|
3 |
-
|
4 |
-
|
5 |
-
- generated_from_trainer
|
6 |
-
- sales
|
7 |
-
model-index:
|
8 |
-
- name: salesGPT_v2
|
9 |
-
results: []
|
10 |
-
datasets:
|
11 |
-
- goendalf666/sales-conversations-2
|
12 |
-
- goendalf666/sales-conversations-instruction-ext
|
13 |
-
- goendalf666/sales-conversations-instruction-base
|
14 |
-
- goendalf666/sales-textbook_for_convincing_and_selling
|
15 |
language:
|
16 |
- en
|
17 |
pipeline_tag: text-generation
|
18 |
---
|
|
|
19 |
|
20 |
-
|
21 |
-
should probably proofread and complete it, then remove this comment. -->
|
22 |
-
|
23 |
-
# salesGPT_v2
|
24 |
-
|
25 |
-
**Model Card for salesGPT_v2**
|
26 |
-
|
27 |
-
### Model Description
|
28 |
-
salesGPT_v2, derived from microsoft/phi-1_5, is specialized in simulating sales conversations, wherein it understands customer requirements, manages objections, and suggests suitable products or services. It was fine-tuned on a variety of sales-related datasets and seems proficient in initiating conversations, asking pertinent questions, and sustaining interactive dialogues with users.
|
29 |
-
|
30 |
-
### Related Ressources
|
31 |
-
|
32 |
-
Github: https://github.com/tom813/salesGPT_foundation
|
33 |
-
salesGPT_v1: https://huggingface.co/goendalf666/salesGPT_v1
|
34 |
-
|
35 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63797fcb2cb50dda39d8aec6/re7MmsaYNzTYVH2jEXDDu.png)
|
36 |
-
|
37 |
-
### Intended Uses & Limitations
|
38 |
-
**Intended Uses:**
|
39 |
-
- Simulating sales conversations for training or evaluation purposes.
|
40 |
-
- Providing guidelines or suggested dialogues for sales representatives.
|
41 |
-
|
42 |
-
**Limitations:**
|
43 |
-
- The model might repetitively ask questions in certain scenarios.
|
44 |
-
- May struggle with handling customers who lack specific preferences or knowledge about products.
|
45 |
-
- The objection handling could be more focused on convincing techniques rather than objective criteria.
|
46 |
-
- Challenges in providing appropriate suggestions for customers without specific needs.
|
47 |
-
- Limited effectiveness in handling financial and budgetary conversations or sensitivities.
|
48 |
-
|
49 |
-
### Training and Evaluation Data
|
50 |
-
**Training Data:**
|
51 |
-
1. **Textbook v1 Dataset**
|
52 |
-
- URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling)
|
53 |
-
- Content: Textbook content for sales, derived from structural points and detailed subpoints created through API calls.
|
54 |
-
|
55 |
-
2. **Sales Conversation Dataset**
|
56 |
-
- URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations)
|
57 |
-
- Content: Sales conversations, generated based on the chapters of the textbook.
|
58 |
-
|
59 |
-
3. **Sales Conversations Instruction Base Dataset**
|
60 |
-
- URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-base)
|
61 |
-
- Content: Extended sales conversations with structured dialogues.
|
62 |
-
|
63 |
-
4. **Sales Conversations Instruction Extension Dataset**
|
64 |
-
- URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-ext)
|
65 |
-
- Content: Updates based on real conversations with the model to improve its proficiency in unconvincing cases.
|
66 |
-
|
67 |
-
**Evaluation Data:**
|
68 |
-
- More information is needed regarding how and where the model was evaluated. If it was assessed on a separate test set, providing access and details to that dataset would be crucial.
|
69 |
-
|
70 |
-
### Training Procedure
|
71 |
-
Fine-tuning of salesGPT_v2 was executed in three phases using the LoRa approach with Rank 64:
|
72 |
-
1. Training on a textbook for 20k steps.
|
73 |
-
2. Training on sales conversations for 40k steps, resulting in salesGPT_v1.
|
74 |
-
3. Training on sales conversations instruction for 40k steps, evolving into salesGPT_v2.
|
75 |
-
|
76 |
-
Hyperparameters used during training:
|
77 |
-
- Learning rate: 0.0002
|
78 |
-
- Train batch size: 2
|
79 |
-
- Eval batch size: 8
|
80 |
-
- Seed: 42
|
81 |
-
- Optimizer: Adam
|
82 |
-
- LR scheduler type: Cosine
|
83 |
-
- Training steps: 40,000
|
84 |
-
- [Additional parameters...](https://huggingface.co/goendalf666/salesGPT_v2)
|
85 |
-
|
86 |
-
Model size details:
|
87 |
-
- Trainable params: 18,874,368
|
88 |
-
- All params: 1,437,145,088
|
89 |
-
- Trainable%: 1.313323766514519
|
90 |
-
|
91 |
-
### Results
|
92 |
-
While salesGPT_v2 showcases notable capabilities like initiating a structured dialogue and asking relevant questions, there are some key areas of improvement and observations:
|
93 |
-
- The model tends to keep asking questions if a user provides vague information.
|
94 |
-
- In some cases, it's not assertive enough to suggest a product when a customer lacks preferences.
|
95 |
-
- It tends to leverage objective criteria over convincing strategies during objections.
|
96 |
-
- Needs better handling of conversations regarding budgetary or financial concerns.
|
97 |
-
|
98 |
-
### Example Usage & Conversations
|
99 |
-
#### Example 1:
|
100 |
-
**Customer**: I am interested in buying a new laptop.
|
101 |
-
**salesGPT_v2**: I understand. Could you please tell me more about your specific requirements and budget?
|
102 |
-
|
103 |
-
#### Example 2:
|
104 |
-
**Customer**: Oh I don't know how I should explain this to my wife. I said we can't go to vacation because we don't have any money.
|
105 |
-
**salesGPT_v2**: I see. Can you please provide more details about your financial situation and the specific expenses you are concerned about?
|
106 |
-
|
107 |
-
These examples underline the model's capability to engage in sales dialogues and also hint at its limitations, such as possibly prying too much into a customer's personal financial situation.
|
108 |
-
|
109 |
-
### Future Directions & Improvement
|
110 |
-
- Enhancement in handling objections by focusing more on persuasive techniques and emotional intelligence.
|
111 |
-
- Improving suggestion-making capability especially when customers are indecisive.
|
112 |
-
- Better navigation through the conversation that involves budgetary and financial aspects without seeming insensitive or intrusive.
|
113 |
-
- Striking a balance between being informative and being too technical in its product descriptions.
|
114 |
-
- Possible implementation of more ethical and privacy-guided conversation guidelines, especially in discussing customers' financial capacities.
|
115 |
-
|
116 |
-
### Ethical Considerations
|
117 |
-
The model’s tendency to repeatedly ask for specific information, especially related to personal financial details, raises ethical concerns regarding privacy and data sensitivity. Care must be taken to ensure the model respects user privacy and does not persistently probe for personal or sensitive information.
|
118 |
-
|
119 |
-
### Conclusion
|
120 |
-
salesGPT_v2 offers a foundation for simulating sales conversations with potential for future refinement in handling objections, making product suggestions, and managing conversations delicately around financial discussions. Future versions might seek to refine its balance between being convincingly persuasive and remaining ethically and emotionally intelligent within dialogues.
|
121 |
-
|
122 |
-
### Inference
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
```
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
-
|
128 |
-
cuda = "cuda:0" if torch.cuda.is_available() else ""
|
129 |
-
model = AutoModelForCausalLM.from_pretrained("goendalf666/salesGPT_v2", trust_remote_code=True, torch_dtype=torch.float32, device_map={"":0})
|
130 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, device_map={"":0})
|
131 |
|
132 |
-
|
133 |
-
inputs.to(cuda)
|
134 |
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
```
|
139 |
-
Or
|
140 |
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
-
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: other
|
3 |
+
license_name: microsoft-research-license
|
4 |
+
license_link: https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
language:
|
6 |
- en
|
7 |
pipeline_tag: text-generation
|
8 |
---
|
9 |
+
## Model Summary
|
10 |
|
11 |
+
The language model phi-1.5 is a Transformer with **1.3 billion** parameters. It was trained using the same data sources as [phi-1](https://huggingface.co/microsoft/phi-1), augmented with a new data source that consists of various NLP synthetic texts. When assessed against benchmarks testing common sense, language understanding, and logical reasoning, phi-1.5 demonstrates a nearly state-of-the-art performance among models with less than 10 billion parameters.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
We **did not** fine-tune phi-1.5 either for **instruction following or through reinforcement learning from human feedback**. The intention behind crafting this open-source model is to provide the research community with a non-restricted small model to explore vital safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more.
|
14 |
+
|
15 |
+
For a safer model release, we exclude generic web-crawl data sources such as common-crawl from the training. This strategy prevents direct exposure to potentially harmful online content, enhancing the model's safety without RLHF. However, the model is still vulnerable to generating harmful content. We hope the model can help the research community to further study the safety of language models.
|
16 |
+
|
17 |
+
phi-1.5 can write poems, draft emails, create stories, summarize texts, write Python code (such as downloading a Hugging Face transformer model), etc.
|
18 |
+
|
19 |
+
## Intended Uses
|
20 |
+
Given the nature of the training data, phi-1.5 is best suited for prompts using the QA format, the chat format, and the code format. Note that phi-1.5, being a base model, often produces irrelevant text following the main answer. In the following example, we've truncated the answer for illustrative purposes only.
|
21 |
+
|
22 |
+
#### QA format:
|
23 |
+
|
24 |
+
```markdown
|
25 |
+
Write a detailed analogy between mathematics and a lighthouse.
|
26 |
+
|
27 |
+
Answer: Mathematics is like a lighthouse, guiding us through the vast ocean of numbers and calculations. Just as a lighthouse illuminates the darkness, mathematics provides us with a clear path to navigate through complex problems. It helps us make sense of the world around us, just like a lighthouse helps ships find their way home.
|
28 |
```
|
29 |
+
where the model generates the text after "Answer:".
|
30 |
+
|
31 |
+
#### Chat format:
|
32 |
+
|
33 |
+
```markdown
|
34 |
+
Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?
|
35 |
+
|
36 |
+
Bob: Have you tried using a timer? It can help you stay on track and avoid distractions.
|
37 |
+
|
38 |
+
Alice: That's a good idea. I'll give it a try.
|
39 |
|
40 |
+
Charlie: Another thing that can help is to break up your study sessions into smaller chunks. It's easier to concentrate on one thing at a time.
|
|
|
|
|
|
|
41 |
|
42 |
+
Alice: That makes sense. I'll try that too.
|
|
|
43 |
|
44 |
+
Bob: And don't forget to take breaks! It's important to give your brain a rest so you can come back to your studies with a fresh perspective.
|
45 |
+
|
46 |
+
Alice: Thanks for the advice, guys. I feel more motivated now.
|
47 |
+
|
48 |
+
Charlie: No problem, Alice. We're all in this together.
|
49 |
+
|
50 |
+
Bob: Yeah, and remember that it's okay to ask for help if you need it. We're here to support each other.
|
51 |
+
```
|
52 |
+
where the model generates the text after the first "Bob:".
|
53 |
+
|
54 |
+
#### Code format:
|
55 |
+
```python
|
56 |
+
def print_prime(n):
|
57 |
+
"""
|
58 |
+
Print all primes between 1 and n
|
59 |
+
"""
|
60 |
+
primes = []
|
61 |
+
for num in range(2, n+1):
|
62 |
+
is_prime = True
|
63 |
+
for i in range(2, int(math.sqrt(num))+1):
|
64 |
+
if num % i == 0:
|
65 |
+
is_prime = False
|
66 |
+
break
|
67 |
+
if is_prime:
|
68 |
+
primes.append(num)
|
69 |
+
print(primes)
|
70 |
+
```
|
71 |
+
where the model generates the text after the comments.
|
72 |
+
|
73 |
+
**Notes**
|
74 |
+
* phi-1.5 is intended for research purposes. The model-generated text/code should be treated as a starting point rather than a definitive solution for potential use cases. Users should be cautious when employing these models in their applications.
|
75 |
+
* Direct adoption for production tasks is out of the scope of this research project. As a result, phi-1.5 has not been tested to ensure that it performs adequately for any production-level application. Please refer to the limitation sections of this document for more details.
|
76 |
+
|
77 |
+
## Limitations of phi-1.5
|
78 |
+
|
79 |
+
* Generate Inaccurate Code and Facts: The model often produces incorrect code snippets and statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate solutions.
|
80 |
+
* Limited Scope for code: If the model generates Python scripts that utilize uncommon packages or scripts in other languages, we strongly recommend users manually verify all API uses.
|
81 |
+
* Unreliable Responses to Instruction: The model has not undergone instruction fine-tuning. As a result, it may struggle or fail to adhere to intricate or nuanced instructions provided by users.
|
82 |
+
* Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other language outside of English might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.
|
83 |
+
* Potential Societal Biases: Regardless of the safe data used for its training, the model is not entirely free from societal biases. There's a possibility it may generate content that mirrors these societal biases, particularly if prompted or instructed to do so. We urge users to be aware of this and to exercise caution and critical thinking when interpreting model outputs.
|
84 |
+
* Toxicity: Despite that the model is trained with carefully selected data, the model can still produce harmful content if explicitly prompted or instructed to do so. We chose to release the model for research purposes only -- We hope to help the open-source community develop the most effective ways to reduce the toxicity of a model directly after pretraining.
|
85 |
+
|
86 |
+
## Training
|
87 |
+
|
88 |
+
### Model
|
89 |
+
* Architecture: a Transformer-based model with next-word prediction objective
|
90 |
+
* Dataset size: 30B tokens
|
91 |
+
* Training tokens: 150B tokens
|
92 |
+
* Precision: fp16
|
93 |
+
* GPUs: 32xA100-40G
|
94 |
+
* Training time: 8 days
|
95 |
+
|
96 |
+
### Software
|
97 |
+
* [PyTorch](https://github.com/pytorch/pytorch)
|
98 |
+
* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
|
99 |
+
* [flash-attention](https://github.com/HazyResearch/flash-attention)
|
100 |
+
|
101 |
+
### License
|
102 |
+
The model is licensed under the [Research License](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx).
|
103 |
+
|
104 |
+
### Sample Code
|
105 |
+
```python
|
106 |
+
import torch
|
107 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
108 |
+
|
109 |
+
torch.set_default_device("cuda")
|
110 |
+
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
|
111 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
|
112 |
+
inputs = tokenizer('''```python
|
113 |
+
def print_prime(n):
|
114 |
+
"""
|
115 |
+
Print all primes between 1 and n
|
116 |
+
"""''', return_tensors="pt", return_attention_mask=False)
|
117 |
+
|
118 |
+
outputs = model.generate(**inputs, max_length=200)
|
119 |
+
text = tokenizer.batch_decode(outputs)[0]
|
120 |
+
print(text)
|
121 |
```
|
|
|
122 |
|
123 |
+
If you need to use the model in a lower precision (e.g., FP16), please wrap the model's forward pass with `torch.autocast()`, as follows:
|
124 |
+
```python
|
125 |
+
with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
|
126 |
+
outputs = model.generate(**inputs, max_length=200)
|
127 |
+
```
|
128 |
+
|
129 |
+
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
|
130 |
+
Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
|
131 |
+
|
132 |
+
### Citation
|
133 |
|
134 |
+
You can find the paper at https://arxiv.org/abs/2309.05463
|
135 |
|
136 |
+
```bib
|
137 |
+
@article{textbooks2,
|
138 |
+
title={Textbooks Are All You Need II: \textbf{phi-1.5} technical report},
|
139 |
+
author={Li, Yuanzhi and Bubeck, S{\'e}bastien and Eldan, Ronen and Del Giorno, Allie and Gunasekar, Suriya and Lee, Yin Tat},
|
140 |
+
journal={arXiv preprint arXiv:2309.05463},
|
141 |
+
year={2023}
|
142 |
+
}
|
143 |
+
```
|
Research License.docx
ADDED
Binary file (38.9 kB). View file
|
|
added_tokens.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"\t\t": 50294,
|
3 |
+
"\t\t\t": 50293,
|
4 |
+
"\t\t\t\t": 50292,
|
5 |
+
"\t\t\t\t\t": 50291,
|
6 |
+
"\t\t\t\t\t\t": 50290,
|
7 |
+
"\t\t\t\t\t\t\t": 50289,
|
8 |
+
"\t\t\t\t\t\t\t\t": 50288,
|
9 |
+
"\t\t\t\t\t\t\t\t\t": 50287,
|
10 |
+
" ": 50286,
|
11 |
+
" ": 50285,
|
12 |
+
" ": 50284,
|
13 |
+
" ": 50283,
|
14 |
+
" ": 50282,
|
15 |
+
" ": 50281,
|
16 |
+
" ": 50280,
|
17 |
+
" ": 50279,
|
18 |
+
" ": 50278,
|
19 |
+
" ": 50277,
|
20 |
+
" ": 50276,
|
21 |
+
" ": 50275,
|
22 |
+
" ": 50274,
|
23 |
+
" ": 50273,
|
24 |
+
" ": 50272,
|
25 |
+
" ": 50271,
|
26 |
+
" ": 50270,
|
27 |
+
" ": 50269,
|
28 |
+
" ": 50268,
|
29 |
+
" ": 50267,
|
30 |
+
" ": 50266,
|
31 |
+
" ": 50265,
|
32 |
+
" ": 50264,
|
33 |
+
" ": 50263,
|
34 |
+
" ": 50262,
|
35 |
+
" ": 50261,
|
36 |
+
" ": 50260,
|
37 |
+
" ": 50259,
|
38 |
+
" ": 50258,
|
39 |
+
" ": 50257
|
40 |
+
}
|
config.json
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"activation_function": "gelu_new",
|
4 |
"architectures": [
|
5 |
"MixFormerSequentialForCausalLM"
|
6 |
],
|
7 |
"auto_map": {
|
8 |
-
"AutoConfig": "
|
9 |
-
"AutoModelForCausalLM": "
|
10 |
},
|
11 |
"embd_pdrop": 0.0,
|
12 |
"initializer_range": 0.02,
|
@@ -20,7 +20,7 @@
|
|
20 |
"resid_pdrop": 0.0,
|
21 |
"rotary_dim": 32,
|
22 |
"tie_word_embeddings": false,
|
23 |
-
"torch_dtype": "
|
24 |
"transformers_version": "4.32.1",
|
25 |
"vocab_size": 51200
|
26 |
}
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "phi-1.5-half",
|
3 |
"activation_function": "gelu_new",
|
4 |
"architectures": [
|
5 |
"MixFormerSequentialForCausalLM"
|
6 |
],
|
7 |
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
|
9 |
+
"AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
|
10 |
},
|
11 |
"embd_pdrop": 0.0,
|
12 |
"initializer_range": 0.02,
|
|
|
20 |
"resid_pdrop": 0.0,
|
21 |
"rotary_dim": 32,
|
22 |
"tie_word_embeddings": false,
|
23 |
+
"torch_dtype": "float16",
|
24 |
"transformers_version": "4.32.1",
|
25 |
"vocab_size": 51200
|
26 |
}
|
configuration_mixformer_sequential.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT license.
|
3 |
+
|
4 |
+
import math
|
5 |
+
from typing import Any, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from transformers import PretrainedConfig
|
8 |
+
|
9 |
+
|
10 |
+
class MixFormerSequentialConfig(PretrainedConfig):
|
11 |
+
"""MixFormer (sequential for DeepSpeed) configuration."""
|
12 |
+
|
13 |
+
model_type = "mixformer-sequential"
|
14 |
+
|
15 |
+
attribute_map = {
|
16 |
+
"max_position_embeddings": "n_positions",
|
17 |
+
"hidden_size": "n_embd",
|
18 |
+
"num_attention_heads": "n_head",
|
19 |
+
"num_hidden_layers": "n_layer",
|
20 |
+
}
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
vocab_size: Optional[int] = 50304,
|
25 |
+
n_positions: Optional[int] = 2048,
|
26 |
+
n_embd: Optional[int] = 1024,
|
27 |
+
n_layer: Optional[int] = 20,
|
28 |
+
n_inner: Optional[int] = None,
|
29 |
+
n_head: Optional[int] = 16,
|
30 |
+
rotary_dim: Optional[int] = 32,
|
31 |
+
activation_function: Optional[str] = "gelu_new",
|
32 |
+
embd_pdrop: Optional[float] = 0.0,
|
33 |
+
resid_pdrop: Optional[float] = 0.0,
|
34 |
+
layer_norm_epsilon: Optional[float] = 1e-5,
|
35 |
+
initializer_range: Optional[float] = 0.02,
|
36 |
+
tie_word_embeddings: Optional[bool] = False,
|
37 |
+
pad_vocab_size_multiple: Optional[int] = 64,
|
38 |
+
**kwargs
|
39 |
+
) -> None:
|
40 |
+
self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
41 |
+
self.n_positions = n_positions
|
42 |
+
self.n_embd = n_embd
|
43 |
+
self.n_layer = n_layer
|
44 |
+
self.n_inner = n_inner
|
45 |
+
self.n_head = n_head
|
46 |
+
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
47 |
+
self.activation_function = activation_function
|
48 |
+
self.embd_pdrop = embd_pdrop
|
49 |
+
self.resid_pdrop = resid_pdrop
|
50 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
51 |
+
self.initializer_range = initializer_range
|
52 |
+
|
53 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_mixformer_sequential.py
CHANGED
@@ -34,28 +34,20 @@
|
|
34 |
from __future__ import annotations
|
35 |
|
36 |
import math
|
|
|
37 |
from typing import Any, Dict, Optional, Tuple, Union
|
38 |
from dataclasses import dataclass, field
|
39 |
|
40 |
import torch
|
41 |
import torch.nn as nn
|
42 |
|
43 |
-
from einops import rearrange
|
44 |
from transformers.activations import ACT2FN
|
45 |
from transformers import PretrainedConfig, PreTrainedModel
|
46 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
47 |
|
48 |
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
49 |
|
50 |
-
|
51 |
-
try:
|
52 |
-
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
53 |
-
from flash_attn.ops.fused_dense import FusedDense
|
54 |
-
except:
|
55 |
-
FlashRotaryEmbedding = None
|
56 |
-
FusedDense = None
|
57 |
-
|
58 |
-
|
59 |
@dataclass
|
60 |
class InferenceParams:
|
61 |
"""Inference parameters passed to model to efficiently calculate
|
@@ -65,20 +57,21 @@ class InferenceParams:
|
|
65 |
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
66 |
|
67 |
Args:
|
68 |
-
|
69 |
max_batch_size: Maximum batch size.
|
70 |
-
|
71 |
batch_size_offset: Batch size offset.
|
72 |
key_value_memory_dict: Key value memory dictionary.
|
|
|
73 |
lengths_per_sample: Lengths per sample.
|
74 |
|
75 |
"""
|
76 |
|
77 |
-
|
78 |
|
79 |
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
80 |
|
81 |
-
|
82 |
|
83 |
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
84 |
|
@@ -86,6 +79,8 @@ class InferenceParams:
|
|
86 |
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
87 |
)
|
88 |
|
|
|
|
|
89 |
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
90 |
|
91 |
|
@@ -108,112 +103,12 @@ class Embedding(nn.Module):
|
|
108 |
return hidden_states
|
109 |
|
110 |
|
111 |
-
def _apply_rotary_emb(
|
112 |
-
x: torch.FloatTensor,
|
113 |
-
cos: torch.FloatTensor,
|
114 |
-
sin: torch.FloatTensor,
|
115 |
-
) -> torch.FloatTensor:
|
116 |
-
_, seqlen, _, head_dim = x.shape
|
117 |
-
rotary_seqlen, rotary_dim = cos.shape
|
118 |
-
rotary_dim *= 2
|
119 |
-
|
120 |
-
assert rotary_dim <= head_dim
|
121 |
-
assert seqlen <= rotary_seqlen
|
122 |
-
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
123 |
-
|
124 |
-
x_rot = x[:, :, :, :rotary_dim]
|
125 |
-
x_pass = x[:, :, :, rotary_dim:]
|
126 |
-
|
127 |
-
x1, x2 = x_rot.chunk(2, dim=-1)
|
128 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
129 |
-
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
|
130 |
-
|
131 |
-
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
|
132 |
-
|
133 |
-
return torch.cat([x_rot, x_pass], axis=-1)
|
134 |
-
|
135 |
-
|
136 |
-
def _apply_rotary_emb_kv(
|
137 |
-
kv: torch.FloatTensor,
|
138 |
-
cos: torch.FloatTensor,
|
139 |
-
sin: torch.FloatTensor,
|
140 |
-
cos_k: Optional[torch.FloatTensor] = None,
|
141 |
-
sin_k: Optional[torch.FloatTensor] = None,
|
142 |
-
) -> torch.FloatTensor:
|
143 |
-
_, seqlen, two, _, head_dim = kv.shape
|
144 |
-
assert two == 2
|
145 |
-
|
146 |
-
rotary_seqlen, rotary_dim = cos.shape
|
147 |
-
rotary_dim *= 2
|
148 |
-
assert rotary_dim <= head_dim
|
149 |
-
assert seqlen <= rotary_seqlen
|
150 |
-
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
151 |
-
|
152 |
-
k_rot = kv[:, :, 0, :, :rotary_dim]
|
153 |
-
k_pass = kv[:, :, 0, :, rotary_dim:]
|
154 |
-
|
155 |
-
k1, k2 = k_rot.chunk(2, dim=-1)
|
156 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
157 |
-
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
|
158 |
-
|
159 |
-
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
|
160 |
-
|
161 |
-
return torch.cat(
|
162 |
-
[
|
163 |
-
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
164 |
-
kv[:, :, 1:2, :, :],
|
165 |
-
],
|
166 |
-
axis=2,
|
167 |
-
)
|
168 |
-
|
169 |
-
|
170 |
-
def _apply_rotary_emb_qkv(
|
171 |
-
qkv: torch.FloatTensor,
|
172 |
-
cos: torch.FloatTensor,
|
173 |
-
sin: torch.FloatTensor,
|
174 |
-
cos_k: Optional[torch.FloatTensor] = None,
|
175 |
-
sin_k: Optional[torch.FloatTensor] = None,
|
176 |
-
) -> torch.FloatTensor:
|
177 |
-
_, seqlen, three, _, head_dim = qkv.shape
|
178 |
-
assert three == 3
|
179 |
-
|
180 |
-
rotary_seqlen, rotary_dim = cos.shape
|
181 |
-
rotary_dim *= 2
|
182 |
-
assert rotary_dim <= head_dim
|
183 |
-
assert seqlen <= rotary_seqlen
|
184 |
-
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
185 |
-
|
186 |
-
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
187 |
-
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
188 |
-
|
189 |
-
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
190 |
-
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
191 |
-
|
192 |
-
q1, q2 = q_rot.chunk(2, dim=-1)
|
193 |
-
k1, k2 = k_rot.chunk(2, dim=-1)
|
194 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
195 |
-
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
196 |
-
|
197 |
-
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
198 |
-
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
199 |
-
|
200 |
-
return torch.cat(
|
201 |
-
[
|
202 |
-
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
203 |
-
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
204 |
-
qkv[:, :, 2:3, :, :],
|
205 |
-
],
|
206 |
-
axis=2,
|
207 |
-
)
|
208 |
-
|
209 |
-
|
210 |
class RotaryEmbedding(nn.Module):
|
211 |
-
"""Rotary
|
212 |
|
213 |
Reference:
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
"""
|
218 |
|
219 |
def __init__(
|
@@ -221,7 +116,6 @@ class RotaryEmbedding(nn.Module):
|
|
221 |
dim: int,
|
222 |
base: int = 10000,
|
223 |
scale_base: Optional[float] = None,
|
224 |
-
pos_idx_in_fp32: bool = True,
|
225 |
device: Optional[str] = None,
|
226 |
**kwargs,
|
227 |
) -> None:
|
@@ -230,23 +124,21 @@ class RotaryEmbedding(nn.Module):
|
|
230 |
if scale_base is not None:
|
231 |
raise NotImplementedError
|
232 |
|
|
|
233 |
self.dim = dim
|
234 |
-
self.base =
|
235 |
self.scale_base = scale_base
|
236 |
-
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
237 |
self.device = device
|
238 |
|
239 |
-
|
240 |
-
|
241 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
242 |
|
243 |
-
# Generate and save the scale buffer (non-trainable)
|
244 |
scale = (
|
245 |
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
246 |
if scale_base is not None
|
247 |
else None
|
248 |
)
|
249 |
-
self.register_buffer("scale", scale
|
250 |
|
251 |
self._seq_len_cached = 0
|
252 |
self._cos_cached = None
|
@@ -254,73 +146,91 @@ class RotaryEmbedding(nn.Module):
|
|
254 |
self._cos_k_cached = None
|
255 |
self._sin_k_cached = None
|
256 |
|
257 |
-
def
|
258 |
-
|
|
|
|
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
seqlen > self._seq_len_cached
|
267 |
-
or self._cos_cached is None
|
268 |
-
or self._cos_cached.device != device
|
269 |
-
or self._cos_cached.dtype != dtype
|
270 |
-
or (self.training and self._cos_cached.is_inference())
|
271 |
-
):
|
272 |
-
self._seq_len_cached = seqlen
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
278 |
-
if self.inv_freq.dtype != torch.float32:
|
279 |
-
inv_freq = self._compute_inv_freq(device=device)
|
280 |
-
else:
|
281 |
-
inv_freq = self.inv_freq
|
282 |
-
else:
|
283 |
-
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
284 |
-
inv_freq = self.inv_freq
|
285 |
|
286 |
-
#
|
287 |
-
freqs = torch.
|
|
|
288 |
if self.scale is None:
|
289 |
-
self._cos_cached = torch.cos(freqs).to(dtype)
|
290 |
-
self._sin_cached = torch.sin(freqs).to(dtype)
|
291 |
else:
|
292 |
power = (
|
293 |
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
294 |
) / self.scale_base
|
295 |
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
296 |
|
297 |
-
#
|
298 |
-
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
299 |
-
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
300 |
-
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
301 |
-
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
302 |
|
303 |
-
def
|
304 |
self,
|
305 |
-
qkv: torch.
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
else
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
|
323 |
-
|
|
|
|
|
|
|
324 |
|
325 |
|
326 |
class MLP(nn.Module):
|
@@ -380,22 +290,21 @@ class SelfAttention(nn.Module):
|
|
380 |
attention_mask: Optional[torch.BoolTensor] = None,
|
381 |
**kwargs,
|
382 |
) -> torch.FloatTensor:
|
383 |
-
|
|
|
384 |
q, k, v = qkv.unbind(dim=2)
|
385 |
|
386 |
-
causal = self.causal if causal is None else causal
|
387 |
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
388 |
-
|
389 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
390 |
|
391 |
if attention_mask is not None:
|
392 |
-
padding_mask = torch.full((batch_size,
|
393 |
padding_mask.masked_fill_(attention_mask, 0.0)
|
394 |
|
395 |
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
396 |
|
397 |
if causal:
|
398 |
-
causal_mask = torch.triu(torch.full((
|
399 |
scores = scores + causal_mask.to(dtype=scores.dtype)
|
400 |
|
401 |
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
@@ -434,31 +343,25 @@ class CrossAttention(nn.Module):
|
|
434 |
attention_mask: Optional[torch.BoolTensor] = None,
|
435 |
**kwargs,
|
436 |
) -> torch.FloatTensor:
|
437 |
-
|
438 |
-
|
439 |
-
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
440 |
|
441 |
-
|
442 |
-
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
443 |
k, v = kv.unbind(dim=2)
|
444 |
|
445 |
-
causal = self.causal if causal is None else causal
|
446 |
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
447 |
-
|
448 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
449 |
|
450 |
if attention_mask is not None:
|
451 |
-
padding_mask = torch.full((batch_size,
|
452 |
padding_mask.masked_fill_(attention_mask, 0.0)
|
453 |
|
454 |
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
455 |
|
456 |
if causal:
|
457 |
-
|
458 |
-
|
459 |
-
causal_mask = cols > rows + seqlen_k - seqlen_q
|
460 |
-
|
461 |
-
scores = scores.masked_fill(causal_mask, -10000.0)
|
462 |
|
463 |
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
464 |
attention = self.drop(attention)
|
@@ -468,12 +371,21 @@ class CrossAttention(nn.Module):
|
|
468 |
return output
|
469 |
|
470 |
|
471 |
-
def
|
472 |
-
config: PretrainedConfig,
|
473 |
-
n_head: Optional[int] = None,
|
474 |
-
n_head_kv: Optional[int] = None,
|
475 |
-
head_dim: Optional[int] = None,
|
476 |
) -> Tuple[int, int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
assert all(
|
478 |
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
479 |
), "`config` must have `n_embd` and `n_head` attributes."
|
@@ -489,20 +401,31 @@ def _find_mha_dims(
|
|
489 |
elif n_head is None or head_dim is None:
|
490 |
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
491 |
|
492 |
-
|
493 |
-
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
494 |
-
assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
|
495 |
|
496 |
-
return n_head, n_head_kv, head_dim
|
497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
|
499 |
-
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
500 |
num_heads, head_dim = kv.shape[-2:]
|
501 |
|
502 |
if layer_idx not in inference_params.key_value_memory_dict:
|
503 |
kv_cache = torch.empty(
|
504 |
inference_params.max_batch_size,
|
505 |
-
inference_params.
|
506 |
2,
|
507 |
num_heads,
|
508 |
head_dim,
|
@@ -511,19 +434,43 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
|
|
511 |
)
|
512 |
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
513 |
else:
|
514 |
-
|
|
|
|
|
|
|
|
|
515 |
|
516 |
batch_start = inference_params.batch_size_offset
|
517 |
batch_end = batch_start + kv.shape[0]
|
518 |
-
assert batch_end <= kv_cache.shape[0]
|
519 |
|
520 |
-
sequence_start = inference_params.
|
521 |
sequence_end = sequence_start + kv.shape[1]
|
522 |
-
assert sequence_end <= kv_cache.shape[1]
|
523 |
|
524 |
-
|
525 |
-
|
526 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
|
528 |
return kv
|
529 |
|
@@ -539,11 +486,11 @@ class MHA(nn.Module):
|
|
539 |
rotary_dim: Optional[int] = None,
|
540 |
rotary_emb_scale_base: Optional[float] = None,
|
541 |
n_head: Optional[int] = None,
|
542 |
-
n_head_kv: Optional[int] = None,
|
543 |
head_dim: Optional[int] = None,
|
544 |
bias: bool = True,
|
545 |
causal: bool = True,
|
546 |
softmax_scale: Optional[float] = None,
|
|
|
547 |
layer_idx: Optional[int] = None,
|
548 |
return_residual: bool = False,
|
549 |
checkpointing: bool = False,
|
@@ -556,101 +503,58 @@ class MHA(nn.Module):
|
|
556 |
rotary_kwargs = {"device": device}
|
557 |
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
558 |
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
559 |
-
|
560 |
-
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
561 |
-
if rotary_cls is None:
|
562 |
-
rotary_cls = RotaryEmbedding
|
563 |
-
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
564 |
|
565 |
# MLP
|
566 |
-
self.n_head, self.
|
567 |
-
op_size = self.
|
568 |
hidden_size = config.n_embd
|
569 |
|
570 |
-
|
571 |
-
|
572 |
-
linear_cls = nn.Linear
|
573 |
-
|
574 |
-
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
575 |
-
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
576 |
|
577 |
# Attention
|
578 |
-
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=
|
579 |
-
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=
|
580 |
|
581 |
self.layer_idx = layer_idx
|
582 |
self.return_residual = return_residual
|
583 |
self.checkpointing = checkpointing
|
584 |
|
585 |
-
def
|
586 |
-
self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
|
587 |
-
) -> torch.FloatTensor:
|
588 |
-
qkv = self.Wqkv(x)
|
589 |
-
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
590 |
-
|
591 |
-
if self.rotary_emb_dim > 0:
|
592 |
-
qkv = self.rotary_emb(qkv)
|
593 |
-
|
594 |
-
if self.checkpointing:
|
595 |
-
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
|
596 |
-
|
597 |
-
return self.inner_attn(qkv, attention_mask=attention_mask)
|
598 |
-
|
599 |
-
def _forward_cross_attn(
|
600 |
self,
|
601 |
x: torch.FloatTensor,
|
602 |
-
past_key_values: Optional[InferenceParams],
|
603 |
-
attention_mask: Optional[torch.BoolTensor],
|
604 |
-
|
|
|
|
|
|
|
605 |
qkv = self.Wqkv(x)
|
|
|
606 |
|
607 |
-
|
608 |
-
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
609 |
-
|
610 |
-
kv = qkv[..., self.n_head * self.head_dim :]
|
611 |
-
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
612 |
-
|
613 |
-
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
|
614 |
-
causal = None if seqlen_offset == 0 else False
|
615 |
if self.rotary_emb_dim > 0:
|
616 |
-
|
617 |
|
618 |
if past_key_values is not None:
|
619 |
-
kv =
|
620 |
|
621 |
-
if
|
622 |
-
|
623 |
-
|
624 |
-
)
|
625 |
|
626 |
-
|
627 |
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
past_key_values: Optional[InferenceParams] = None,
|
632 |
-
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
633 |
-
**kwargs,
|
634 |
-
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
635 |
-
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
636 |
-
attention_mask = attention_mask.bool()
|
637 |
-
else:
|
638 |
-
attention_mask = None
|
639 |
-
|
640 |
-
# MHA
|
641 |
-
if self.n_head == self.n_head_kv:
|
642 |
-
if past_key_values is None:
|
643 |
-
# If `past_key_values` are not supplied, we run self-attention
|
644 |
-
attn_output = self._forward_self_attn(x, attention_mask)
|
645 |
else:
|
646 |
-
|
647 |
-
# could take advantage of cross-attention
|
648 |
-
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
649 |
-
# MQA / GQA
|
650 |
else:
|
651 |
-
|
652 |
-
|
653 |
-
attn_output = self.
|
654 |
|
655 |
output = rearrange(attn_output, "... h d -> ... (h d)")
|
656 |
output = self.out_proj(output)
|
@@ -768,29 +672,38 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|
768 |
if module.padding_idx is not None:
|
769 |
module.weight.data[module.padding_idx].zero_()
|
770 |
elif isinstance(module, nn.LayerNorm):
|
771 |
-
|
772 |
-
module.bias.data.zero_()
|
773 |
module.weight.data.fill_(1.0)
|
774 |
|
775 |
def prepare_inputs_for_generation(
|
776 |
self,
|
777 |
input_ids: torch.LongTensor,
|
778 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
779 |
-
attention_mask: Optional[
|
780 |
**kwargs,
|
781 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
782 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
783 |
past_key_values = InferenceParams(
|
784 |
-
max_seqlen=self.config.n_positions,
|
785 |
max_batch_size=input_ids.shape[0],
|
786 |
-
|
|
|
787 |
batch_size_offset=0,
|
|
|
788 |
key_value_memory_dict={},
|
789 |
-
lengths_per_sample=None,
|
790 |
)
|
791 |
else:
|
792 |
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
793 |
-
past_key_values.
|
794 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
795 |
|
796 |
return {
|
@@ -799,9 +712,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|
799 |
"attention_mask": attention_mask,
|
800 |
}
|
801 |
|
802 |
-
def _set_gradient_checkpointing(self, module
|
803 |
-
|
804 |
-
|
805 |
|
806 |
|
807 |
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
@@ -843,13 +756,19 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
|
843 |
labels: Optional[torch.LongTensor] = None,
|
844 |
**kwargs,
|
845 |
) -> CausalLMOutputWithPast:
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
850 |
|
851 |
loss = None
|
852 |
if labels is not None:
|
853 |
loss = self.loss(lm_logits, labels)
|
854 |
|
855 |
-
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
|
|
34 |
from __future__ import annotations
|
35 |
|
36 |
import math
|
37 |
+
import copy
|
38 |
from typing import Any, Dict, Optional, Tuple, Union
|
39 |
from dataclasses import dataclass, field
|
40 |
|
41 |
import torch
|
42 |
import torch.nn as nn
|
43 |
|
44 |
+
from einops import rearrange
|
45 |
from transformers.activations import ACT2FN
|
46 |
from transformers import PretrainedConfig, PreTrainedModel
|
47 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
48 |
|
49 |
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
@dataclass
|
52 |
class InferenceParams:
|
53 |
"""Inference parameters passed to model to efficiently calculate
|
|
|
57 |
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
58 |
|
59 |
Args:
|
60 |
+
max_sequence_len: Maximum sequence length.
|
61 |
max_batch_size: Maximum batch size.
|
62 |
+
sequence_len_offset: Sequence length offset.
|
63 |
batch_size_offset: Batch size offset.
|
64 |
key_value_memory_dict: Key value memory dictionary.
|
65 |
+
fused_ft_kernel: Whether to use fused kernel for fast inference.
|
66 |
lengths_per_sample: Lengths per sample.
|
67 |
|
68 |
"""
|
69 |
|
70 |
+
max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
|
71 |
|
72 |
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
73 |
|
74 |
+
sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
75 |
|
76 |
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
77 |
|
|
|
79 |
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
80 |
)
|
81 |
|
82 |
+
fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
|
83 |
+
|
84 |
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
85 |
|
86 |
|
|
|
103 |
return hidden_states
|
104 |
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
class RotaryEmbedding(nn.Module):
|
107 |
+
"""Rotary embeddings.
|
108 |
|
109 |
Reference:
|
110 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
|
111 |
+
|
|
|
112 |
"""
|
113 |
|
114 |
def __init__(
|
|
|
116 |
dim: int,
|
117 |
base: int = 10000,
|
118 |
scale_base: Optional[float] = None,
|
|
|
119 |
device: Optional[str] = None,
|
120 |
**kwargs,
|
121 |
) -> None:
|
|
|
124 |
if scale_base is not None:
|
125 |
raise NotImplementedError
|
126 |
|
127 |
+
# Generate and save the inverse frequency buffer (non-trainable)
|
128 |
self.dim = dim
|
129 |
+
self.base = base
|
130 |
self.scale_base = scale_base
|
|
|
131 |
self.device = device
|
132 |
|
133 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
134 |
+
self.register_buffer("inv_freq", inv_freq)
|
|
|
135 |
|
|
|
136 |
scale = (
|
137 |
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
138 |
if scale_base is not None
|
139 |
else None
|
140 |
)
|
141 |
+
self.register_buffer("scale", scale)
|
142 |
|
143 |
self._seq_len_cached = 0
|
144 |
self._cos_cached = None
|
|
|
146 |
self._cos_k_cached = None
|
147 |
self._sin_k_cached = None
|
148 |
|
149 |
+
def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
|
150 |
+
# Reset the tables if the sequence length has changed,
|
151 |
+
# or if we're on a new device (possibly due to tracing for instance)
|
152 |
+
seqlen = x.shape[1] + seqlen_offset
|
153 |
|
154 |
+
# Re-generate the inverse frequency buffer if it's not fp32
|
155 |
+
# (for instance if model.half() was called)
|
156 |
+
if self.inv_freq.dtype != "torch.float32":
|
157 |
+
self.inv_freq = 1.0 / (
|
158 |
+
self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
|
159 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
+
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
|
162 |
+
self._seq_len_cached = seqlen
|
163 |
+
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
# Don't do einsum, it converts fp32 to fp16
|
166 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
167 |
+
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
|
168 |
if self.scale is None:
|
169 |
+
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
170 |
+
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
171 |
else:
|
172 |
power = (
|
173 |
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
174 |
) / self.scale_base
|
175 |
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
176 |
|
177 |
+
# We want the multiplication by scale to happen in fp32
|
178 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
179 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
180 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
181 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
182 |
|
183 |
+
def _apply_rotary_emb_qkv(
|
184 |
self,
|
185 |
+
qkv: torch.FloatTensor,
|
186 |
+
sin: torch.FloatTensor,
|
187 |
+
cos: torch.FloatTensor,
|
188 |
+
sin_k: Optional[torch.FloatTensor] = None,
|
189 |
+
cos_k: Optional[torch.FloatTensor] = None,
|
190 |
+
) -> torch.FloatTensor:
|
191 |
+
_, seqlen, three, _, headdim = qkv.shape
|
192 |
+
assert three == 3
|
193 |
+
|
194 |
+
rotary_seqlen, rotary_dim = cos.shape
|
195 |
+
rotary_dim *= 2
|
196 |
+
assert rotary_dim <= headdim
|
197 |
+
assert seqlen <= rotary_seqlen
|
198 |
+
|
199 |
+
cos_k = cos if cos_k is None else cos_k
|
200 |
+
sin_k = sin if sin_k is None else sin_k
|
201 |
+
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
202 |
+
|
203 |
+
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
204 |
+
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
205 |
+
|
206 |
+
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
207 |
+
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
208 |
+
|
209 |
+
# Splits the queries and keys in half
|
210 |
+
q1, q2 = q_rot.chunk(2, dim=-1)
|
211 |
+
k1, k2 = k_rot.chunk(2, dim=-1)
|
212 |
+
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
213 |
+
|
214 |
+
# Casts to fp32 are necessary to prevent fp16 overflow issues
|
215 |
+
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
216 |
+
|
217 |
+
# Computes the new keys and queries, recasting to original dtype
|
218 |
+
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
219 |
+
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
220 |
+
|
221 |
+
return torch.cat(
|
222 |
+
[
|
223 |
+
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
224 |
+
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
225 |
+
qkv[:, :, 2:3, :, :],
|
226 |
+
],
|
227 |
+
axis=2,
|
228 |
+
)
|
229 |
|
230 |
+
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
231 |
+
# `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
|
232 |
+
self._update_cos_sin_cache(qkv, seqlen_offset)
|
233 |
+
return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
|
234 |
|
235 |
|
236 |
class MLP(nn.Module):
|
|
|
290 |
attention_mask: Optional[torch.BoolTensor] = None,
|
291 |
**kwargs,
|
292 |
) -> torch.FloatTensor:
|
293 |
+
causal = self.causal if causal is None else causal
|
294 |
+
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
|
295 |
q, k, v = qkv.unbind(dim=2)
|
296 |
|
|
|
297 |
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
|
|
298 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
299 |
|
300 |
if attention_mask is not None:
|
301 |
+
padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
|
302 |
padding_mask.masked_fill_(attention_mask, 0.0)
|
303 |
|
304 |
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
305 |
|
306 |
if causal:
|
307 |
+
causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
|
308 |
scores = scores + causal_mask.to(dtype=scores.dtype)
|
309 |
|
310 |
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
|
|
343 |
attention_mask: Optional[torch.BoolTensor] = None,
|
344 |
**kwargs,
|
345 |
) -> torch.FloatTensor:
|
346 |
+
causal = self.causal if causal is None else causal
|
347 |
+
batch_size, seq_len_q = q.shape[0], q.shape[1]
|
348 |
+
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
|
349 |
|
350 |
+
seq_len_k = kv.shape[1]
|
|
|
351 |
k, v = kv.unbind(dim=2)
|
352 |
|
|
|
353 |
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
|
|
354 |
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
355 |
|
356 |
if attention_mask is not None:
|
357 |
+
padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
|
358 |
padding_mask.masked_fill_(attention_mask, 0.0)
|
359 |
|
360 |
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
361 |
|
362 |
if causal:
|
363 |
+
causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
|
364 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
|
|
|
|
|
|
365 |
|
366 |
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
367 |
attention = self.drop(attention)
|
|
|
371 |
return output
|
372 |
|
373 |
|
374 |
+
def find_mha_dims(
|
375 |
+
config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
|
|
|
|
|
|
|
376 |
) -> Tuple[int, int]:
|
377 |
+
"""Validate and return the number of heads and head dimension for multi-head attention.
|
378 |
+
|
379 |
+
Args:
|
380 |
+
config: Model configuration.
|
381 |
+
n_head: Number of heads.
|
382 |
+
head_dim: Head dimension.
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
Number of heads and head dimension.
|
386 |
+
|
387 |
+
"""
|
388 |
+
|
389 |
assert all(
|
390 |
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
391 |
), "`config` must have `n_embd` and `n_head` attributes."
|
|
|
401 |
elif n_head is None or head_dim is None:
|
402 |
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
403 |
|
404 |
+
return n_head, head_dim
|
|
|
|
|
405 |
|
|
|
406 |
|
407 |
+
def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
408 |
+
"""Update the key-value cache for inference.
|
409 |
+
|
410 |
+
Reference:
|
411 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
412 |
+
|
413 |
+
Args:
|
414 |
+
kv: Key-value tensor.
|
415 |
+
inference_params: Inference parameters.
|
416 |
+
layer_idx: Layer index.
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
Updated key-value tensor.
|
420 |
+
|
421 |
+
"""
|
422 |
|
|
|
423 |
num_heads, head_dim = kv.shape[-2:]
|
424 |
|
425 |
if layer_idx not in inference_params.key_value_memory_dict:
|
426 |
kv_cache = torch.empty(
|
427 |
inference_params.max_batch_size,
|
428 |
+
inference_params.max_sequence_len,
|
429 |
2,
|
430 |
num_heads,
|
431 |
head_dim,
|
|
|
434 |
)
|
435 |
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
436 |
else:
|
437 |
+
if not inference_params.fused_ft_kernel:
|
438 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
439 |
+
else:
|
440 |
+
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
|
441 |
+
kv_cache = None
|
442 |
|
443 |
batch_start = inference_params.batch_size_offset
|
444 |
batch_end = batch_start + kv.shape[0]
|
445 |
+
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
|
446 |
|
447 |
+
sequence_start = inference_params.sequence_len_offset
|
448 |
sequence_end = sequence_start + kv.shape[1]
|
449 |
+
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
|
450 |
|
451 |
+
if not inference_params.fused_ft_kernel:
|
452 |
+
assert kv_cache is not None
|
453 |
+
|
454 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
455 |
+
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
456 |
+
|
457 |
+
return kv
|
458 |
+
|
459 |
+
assert inference_params.sequence_len_offset == 0
|
460 |
+
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
461 |
+
|
462 |
+
packsize = 4 if kv.dtype == torch.float32 else 8
|
463 |
+
|
464 |
+
if kv_cache is not None:
|
465 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
466 |
+
k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
|
467 |
+
v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
|
468 |
+
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
|
469 |
+
else:
|
470 |
+
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
|
471 |
+
kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
|
472 |
+
)
|
473 |
+
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
|
474 |
|
475 |
return kv
|
476 |
|
|
|
486 |
rotary_dim: Optional[int] = None,
|
487 |
rotary_emb_scale_base: Optional[float] = None,
|
488 |
n_head: Optional[int] = None,
|
|
|
489 |
head_dim: Optional[int] = None,
|
490 |
bias: bool = True,
|
491 |
causal: bool = True,
|
492 |
softmax_scale: Optional[float] = None,
|
493 |
+
dropout: float = 0.0,
|
494 |
layer_idx: Optional[int] = None,
|
495 |
return_residual: bool = False,
|
496 |
checkpointing: bool = False,
|
|
|
503 |
rotary_kwargs = {"device": device}
|
504 |
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
505 |
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
506 |
+
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
|
|
|
|
|
|
|
|
507 |
|
508 |
# MLP
|
509 |
+
self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
|
510 |
+
op_size = self.n_head * self.head_dim
|
511 |
hidden_size = config.n_embd
|
512 |
|
513 |
+
self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
|
514 |
+
self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
515 |
|
516 |
# Attention
|
517 |
+
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
518 |
+
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
519 |
|
520 |
self.layer_idx = layer_idx
|
521 |
self.return_residual = return_residual
|
522 |
self.checkpointing = checkpointing
|
523 |
|
524 |
+
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
self,
|
526 |
x: torch.FloatTensor,
|
527 |
+
past_key_values: Optional[InferenceParams] = None,
|
528 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
529 |
+
cu_seqlens: Optional[torch.LongTensor] = None,
|
530 |
+
max_seqlen: Optional[int] = None,
|
531 |
+
**kwargs,
|
532 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
533 |
qkv = self.Wqkv(x)
|
534 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
535 |
|
536 |
+
seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
537 |
if self.rotary_emb_dim > 0:
|
538 |
+
qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
|
539 |
|
540 |
if past_key_values is not None:
|
541 |
+
kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
|
542 |
|
543 |
+
if attention_mask is not None:
|
544 |
+
attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
|
545 |
+
attention_mask = attention_mask.bool().to(qkv.device)
|
|
|
546 |
|
547 |
+
attention_kwargs = {"attention_mask": attention_mask}
|
548 |
|
549 |
+
if past_key_values is None or seqlen_offset == 0:
|
550 |
+
if self.checkpointing:
|
551 |
+
attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
else:
|
553 |
+
attn_output = self.inner_attn(qkv, **attention_kwargs)
|
|
|
|
|
|
|
554 |
else:
|
555 |
+
q = qkv[:, :, 0]
|
556 |
+
causal = None if past_key_values.sequence_len_offset == 0 else False
|
557 |
+
attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
|
558 |
|
559 |
output = rearrange(attn_output, "... h d -> ... (h d)")
|
560 |
output = self.out_proj(output)
|
|
|
672 |
if module.padding_idx is not None:
|
673 |
module.weight.data[module.padding_idx].zero_()
|
674 |
elif isinstance(module, nn.LayerNorm):
|
675 |
+
module.bias.data.zero_()
|
|
|
676 |
module.weight.data.fill_(1.0)
|
677 |
|
678 |
def prepare_inputs_for_generation(
|
679 |
self,
|
680 |
input_ids: torch.LongTensor,
|
681 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
682 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
683 |
**kwargs,
|
684 |
) -> Dict[str, Any]:
|
685 |
+
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
686 |
+
total_seq_len = torch.sum(attention_mask, dim=1)
|
687 |
+
max_seq_len = torch.max(total_seq_len)
|
688 |
+
|
689 |
+
total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
|
690 |
+
cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
|
691 |
+
attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
|
692 |
+
else:
|
693 |
+
attention_mask = None
|
694 |
+
|
695 |
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
696 |
past_key_values = InferenceParams(
|
|
|
697 |
max_batch_size=input_ids.shape[0],
|
698 |
+
max_sequence_len=self.config.n_positions,
|
699 |
+
sequence_len_offset=0,
|
700 |
batch_size_offset=0,
|
701 |
+
fused_ft_kernel=False,
|
702 |
key_value_memory_dict={},
|
|
|
703 |
)
|
704 |
else:
|
705 |
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
706 |
+
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
|
707 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
708 |
|
709 |
return {
|
|
|
712 |
"attention_mask": attention_mask,
|
713 |
}
|
714 |
|
715 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
716 |
+
if isinstance(module, MixFormerSequentialPreTrainedModel):
|
717 |
+
module.gradient_checkpointing = value
|
718 |
|
719 |
|
720 |
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
|
|
756 |
labels: Optional[torch.LongTensor] = None,
|
757 |
**kwargs,
|
758 |
) -> CausalLMOutputWithPast:
|
759 |
+
if attention_mask is not None and self.training:
|
760 |
+
print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")
|
761 |
+
|
762 |
+
if past_key_values is None and attention_mask is None:
|
763 |
+
lm_logits = self.layers(input_ids)
|
764 |
+
else:
|
765 |
+
hidden_layer = self.layers[0](input_ids)
|
766 |
+
for module in self.layers[1:-1]:
|
767 |
+
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
|
768 |
+
lm_logits = self.layers[-1](hidden_layer)
|
769 |
|
770 |
loss = None
|
771 |
if labels is not None:
|
772 |
loss = self.loss(lm_logits, labels)
|
773 |
|
774 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eab6a12a9a2b78cac8f8975aea9f3a5e89ddadcb9e0dad27e40965e57e235a4a
|
3 |
+
size 2836623617
|
special_tokens_map.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|endoftext|>",
|
3 |
+
"eos_token": "<|endoftext|>",
|
4 |
+
"unk_token": "<|endoftext|>"
|
5 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"bos_token": "<|endoftext|>",
|
4 |
+
"clean_up_tokenization_spaces": true,
|
5 |
+
"eos_token": "<|endoftext|>",
|
6 |
+
"model_max_length": 2048,
|
7 |
+
"tokenizer_class": "CodeGenTokenizer",
|
8 |
+
"unk_token": "<|endoftext|>"
|
9 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|