mariordoniez committed on
Commit
77d4d73
1 Parent(s): fc863b3

Upload 14 files

README.md CHANGED
@@ -1,148 +1,143 @@
1
  ---
2
  license: other
3
- base_model: microsoft/phi-1_5
4
- tags:
5
- - generated_from_trainer
6
- - sales
7
- model-index:
8
- - name: salesGPT_v2
9
- results: []
10
- datasets:
11
- - goendalf666/sales-conversations-2
12
- - goendalf666/sales-conversations-instruction-ext
13
- - goendalf666/sales-conversations-instruction-base
14
- - goendalf666/sales-textbook_for_convincing_and_selling
15
  language:
16
  - en
17
  pipeline_tag: text-generation
18
  ---
 
19
 
20
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
21
- should probably proofread and complete it, then remove this comment. -->
22
-
23
- # salesGPT_v2
24
-
25
- **Model Card for salesGPT_v2**
26
-
27
- ### Model Description
28
- salesGPT_v2, derived from microsoft/phi-1_5, is specialized in simulating sales conversations, wherein it understands customer requirements, manages objections, and suggests suitable products or services. It was fine-tuned on a variety of sales-related datasets and seems proficient in initiating conversations, asking pertinent questions, and sustaining interactive dialogues with users.
29
-
30
- ### Related Resources
31
-
32
- Github: https://github.com/tom813/salesGPT_foundation
33
- salesGPT_v1: https://huggingface.co/goendalf666/salesGPT_v1
34
-
35
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63797fcb2cb50dda39d8aec6/re7MmsaYNzTYVH2jEXDDu.png)
36
-
37
- ### Intended Uses & Limitations
38
- **Intended Uses:**
39
- - Simulating sales conversations for training or evaluation purposes.
40
- - Providing guidelines or suggested dialogues for sales representatives.
41
-
42
- **Limitations:**
43
- - The model might repetitively ask questions in certain scenarios.
44
- - May struggle with handling customers who lack specific preferences or knowledge about products.
45
- - The objection handling could be more focused on convincing techniques rather than objective criteria.
46
- - Challenges in providing appropriate suggestions for customers without specific needs.
47
- - Limited effectiveness in handling financial and budgetary conversations or sensitivities.
48
-
49
- ### Training and Evaluation Data
50
- **Training Data:**
51
- 1. **Textbook v1 Dataset**
52
- - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling)
53
- - Content: Textbook content for sales, derived from structural points and detailed subpoints created through API calls.
54
-
55
- 2. **Sales Conversation Dataset**
56
- - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations)
57
- - Content: Sales conversations, generated based on the chapters of the textbook.
58
-
59
- 3. **Sales Conversations Instruction Base Dataset**
60
- - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-base)
61
- - Content: Extended sales conversations with structured dialogues.
62
-
63
- 4. **Sales Conversations Instruction Extension Dataset**
64
- - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-ext)
65
- - Content: Updates based on real conversations with the model, added to improve its performance in cases where it failed to convince.
66
-
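- The four datasets above are hosted on the Hugging Face Hub; a minimal sketch (not part of the original card) of loading them with the `datasets` library, assuming the default `train` split:
-
- ```python
- from datasets import load_dataset
-
- # Dataset ids as listed above; the split name is an assumption
- textbook = load_dataset("goendalf666/sales-textbook_for_convincing_and_selling", split="train")
- conversations = load_dataset("goendalf666/sales-conversations", split="train")
- instruction_base = load_dataset("goendalf666/sales-conversations-instruction-base", split="train")
- instruction_ext = load_dataset("goendalf666/sales-conversations-instruction-ext", split="train")
- ```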
67
- **Evaluation Data:**
68
- - More information is needed regarding how and where the model was evaluated. If it was assessed on a separate test set, providing access to and details of that dataset would be crucial.
69
-
70
- ### Training Procedure
71
- Fine-tuning of salesGPT_v2 was executed in three phases using the LoRA approach with rank 64:
72
- 1. Training on a textbook for 20k steps.
73
- 2. Training on sales conversations for 40k steps, resulting in salesGPT_v1.
74
- 3. Training on sales conversations instruction for 40k steps, evolving into salesGPT_v2.
75
-
76
- Hyperparameters used during training:
77
- - Learning rate: 0.0002
78
- - Train batch size: 2
79
- - Eval batch size: 8
80
- - Seed: 42
81
- - Optimizer: Adam
82
- - LR scheduler type: Cosine
83
- - Training steps: 40,000
84
- - [Additional parameters...](https://huggingface.co/goendalf666/salesGPT_v2)
85
-
86
- Model size details:
87
- - Trainable params: 18,874,368
88
- - All params: 1,437,145,088
89
- - Trainable: ~1.31%
90
-
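- A minimal sketch of the LoRA setup described above, using the `peft` library; the target modules, alpha, and dropout are assumptions (though rank-64 adapters on the attention projections are consistent with the trainable-parameter count above):
-
- ```python
- import torch
- from peft import LoraConfig, get_peft_model
- from transformers import AutoModelForCausalLM
-
- base = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True, torch_dtype=torch.float32)
-
- lora_config = LoraConfig(
-     r=64,                                 # rank 64, as stated above
-     lora_alpha=16,                        # assumed value
-     lora_dropout=0.05,                    # assumed value
-     target_modules=["Wqkv", "out_proj"],  # assumed: phi-1.5 attention projections
-     task_type="CAUSAL_LM",
- )
-
- model = get_peft_model(base, lora_config)
- model.print_trainable_parameters()  # ~18.9M trainable of ~1.44B total (~1.31%)
- ```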
91
- ### Results
92
- While salesGPT_v2 showcases notable capabilities like initiating a structured dialogue and asking relevant questions, there are some key areas of improvement and observations:
93
- - The model tends to keep asking questions if a user provides vague information.
94
- - In some cases, it's not assertive enough to suggest a product when a customer lacks preferences.
95
- - It tends to rely on objective criteria rather than persuasive strategies when handling objections.
96
- - Needs better handling of conversations regarding budgetary or financial concerns.
97
-
98
- ### Example Usage & Conversations
99
- #### Example 1:
100
- **Customer**: I am interested in buying a new laptop.
101
- **salesGPT_v2**: I understand. Could you please tell me more about your specific requirements and budget?
102
-
103
- #### Example 2:
104
- **Customer**: Oh I don't know how I should explain this to my wife. I said we can't go to vacation because we don't have any money.
105
- **salesGPT_v2**: I see. Can you please provide more details about your financial situation and the specific expenses you are concerned about?
106
-
107
- These examples underline the model's capability to engage in sales dialogues and also hint at its limitations, such as possibly prying too much into a customer's personal financial situation.
108
-
109
- ### Future Directions & Improvement
110
- - Enhancement in handling objections by focusing more on persuasive techniques and emotional intelligence.
111
- - Improving suggestion-making capability especially when customers are indecisive.
112
- - Better navigation through the conversation that involves budgetary and financial aspects without seeming insensitive or intrusive.
113
- - Striking a balance between being informative and being too technical in its product descriptions.
114
- - Possible implementation of more ethical and privacy-guided conversation guidelines, especially in discussing customers' financial capacities.
115
-
116
- ### Ethical Considerations
117
- The model’s tendency to repeatedly ask for specific information, especially related to personal financial details, raises ethical concerns regarding privacy and data sensitivity. Care must be taken to ensure the model respects user privacy and does not persistently probe for personal or sensitive information.
118
-
119
- ### Conclusion
120
- salesGPT_v2 offers a foundation for simulating sales conversations with potential for future refinement in handling objections, making product suggestions, and managing conversations delicately around financial discussions. Future versions might seek to refine its balance between being convincingly persuasive and remaining ethically and emotionally intelligent within dialogues.
121
-
122
- ### Inference
-
  ```python
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- # Initialize the model and tokenizer
- cuda = "cuda:0" if torch.cuda.is_available() else "cpu"
- model = AutoModelForCausalLM.from_pretrained("goendalf666/salesGPT_v2", trust_remote_code=True, torch_dtype=torch.float32, device_map={"": 0})
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
-
- # `conversation_text` holds the sales dialogue so far,
- # e.g. "Customer: I am interested in buying a new laptop."
- inputs = tokenizer(conversation_text, return_tensors="pt", return_attention_mask=False)
- inputs = inputs.to(cuda)
-
- # Generate the model's response
- outputs = model.generate(**inputs, max_length=512)
- response_text = tokenizer.batch_decode(outputs)[0]
  ```
-
- Alternatively, use the provided inference script: https://github.com/tom813/salesGPT_foundation/blob/main/inference.py
-
- ### Framework versions
-
- - Transformers 4.32.1
- - Pytorch 2.1.0.dev20230829+cu121
- - Datasets 2.14.5
- - Tokenizers 0.13.3
1
  ---
2
  license: other
3
+ license_name: microsoft-research-license
4
+ license_link: https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx
5
  language:
6
  - en
7
  pipeline_tag: text-generation
8
  ---
9
+ ## Model Summary
10
 
11
+ The language model phi-1.5 is a Transformer with **1.3 billion** parameters. It was trained using the same data sources as [phi-1](https://huggingface.co/microsoft/phi-1), augmented with a new data source that consists of various NLP synthetic texts. When assessed against benchmarks testing common sense, language understanding, and logical reasoning, phi-1.5 demonstrates a nearly state-of-the-art performance among models with less than 10 billion parameters.
12
 
13
+ We **did not** fine-tune phi-1.5 for **instruction following**, nor with **reinforcement learning from human feedback**. The intention behind crafting this open-source model is to provide the research community with a non-restricted small model to explore vital safety challenges, such as reducing toxicity, understanding societal biases, enhancing controllability, and more.
14
+
15
+ For a safer model release, we exclude generic web-crawl data sources such as common-crawl from the training. This strategy prevents direct exposure to potentially harmful online content, enhancing the model's safety without RLHF. However, the model is still vulnerable to generating harmful content. We hope the model can help the research community to further study the safety of language models.
16
+
17
+ phi-1.5 can write poems, draft emails, create stories, summarize texts, write Python code (such as downloading a Hugging Face transformer model), etc.
18
+
19
+ ## Intended Uses
20
+ Given the nature of the training data, phi-1.5 is best suited for prompts using the QA format, the chat format, and the code format. Note that phi-1.5, being a base model, often produces irrelevant text following the main answer. In the following example, we've truncated the answer for illustrative purposes only.
21
+
22
+ #### QA format:
23
+
24
+ ```markdown
25
+ Write a detailed analogy between mathematics and a lighthouse.
26
+
27
+ Answer: Mathematics is like a lighthouse, guiding us through the vast ocean of numbers and calculations. Just as a lighthouse illuminates the darkness, mathematics provides us with a clear path to navigate through complex problems. It helps us make sense of the world around us, just like a lighthouse helps ships find their way home.
28
  ```
29
+ where the model generates the text after "Answer:".
30
+
31
+ #### Chat format:
32
+
33
+ ```markdown
34
+ Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?
35
+
36
+ Bob: Have you tried using a timer? It can help you stay on track and avoid distractions.
37
+
38
+ Alice: That's a good idea. I'll give it a try.
39
 
40
+ Charlie: Another thing that can help is to break up your study sessions into smaller chunks. It's easier to concentrate on one thing at a time.
 
 
 
41
 
42
+ Alice: That makes sense. I'll try that too.
 
43
 
44
+ Bob: And don't forget to take breaks! It's important to give your brain a rest so you can come back to your studies with a fresh perspective.
45
+
46
+ Alice: Thanks for the advice, guys. I feel more motivated now.
47
+
48
+ Charlie: No problem, Alice. We're all in this together.
49
+
50
+ Bob: Yeah, and remember that it's okay to ask for help if you need it. We're here to support each other.
51
+ ```
52
+ where the model generates the text after the first "Bob:".
53
+
54
+ #### Code format:
55
+ ```python
+ def print_prime(n):
+     """
+     Print all primes between 1 and n
+     """
+     primes = []
+     for num in range(2, n+1):
+         is_prime = True
+         for i in range(2, int(math.sqrt(num))+1):
+             if num % i == 0:
+                 is_prime = False
+                 break
+         if is_prime:
+             primes.append(num)
+     print(primes)
+ ```
71
+ where the model generates the text after the comments.
72
+
73
+ **Notes**
74
+ * phi-1.5 is intended for research purposes. The model-generated text/code should be treated as a starting point rather than a definitive solution for potential use cases. Users should be cautious when employing these models in their applications.
75
+ * Direct adoption for production tasks is out of the scope of this research project. As a result, phi-1.5 has not been tested to ensure that it performs adequately for any production-level application. Please refer to the limitation sections of this document for more details.
76
+
77
+ ## Limitations of phi-1.5
78
+
79
+ * Generate Inaccurate Code and Facts: The model often produces incorrect code snippets and statements. Users should treat these outputs as suggestions or starting points, not as definitive or accurate solutions.
80
+ * Limited Scope for code: If the model generates Python scripts that utilize uncommon packages or scripts in other languages, we strongly recommend users manually verify all API uses.
81
+ * Unreliable Responses to Instruction: The model has not undergone instruction fine-tuning. As a result, it may struggle or fail to adhere to intricate or nuanced instructions provided by users.
82
+ * Language Limitations: The model is primarily designed to understand standard English. Informal English, slang, or any other language outside of English might pose challenges to its comprehension, leading to potential misinterpretations or errors in response.
83
+ * Potential Societal Biases: Despite the safe data used for its training, the model is not entirely free from societal biases. There's a possibility it may generate content that mirrors these societal biases, particularly if prompted or instructed to do so. We urge users to be aware of this and to exercise caution and critical thinking when interpreting model outputs.
84
+ * Toxicity: Although the model is trained with carefully selected data, it can still produce harmful content if explicitly prompted or instructed to do so. We chose to release the model for research purposes only -- we hope to help the open-source community develop the most effective ways to reduce the toxicity of a model directly after pretraining.
85
+
86
+ ## Training
87
+
88
+ ### Model
89
+ * Architecture: a Transformer-based model with next-word prediction objective
90
+ * Dataset size: 30B tokens
91
+ * Training tokens: 150B tokens
92
+ * Precision: fp16
93
+ * GPUs: 32xA100-40G
94
+ * Training time: 8 days
95
+
96
+ ### Software
97
+ * [PyTorch](https://github.com/pytorch/pytorch)
98
+ * [DeepSpeed](https://github.com/microsoft/DeepSpeed)
99
+ * [flash-attention](https://github.com/HazyResearch/flash-attention)
100
+
101
+ ### License
102
+ The model is licensed under the [Research License](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx).
103
+
104
+ ### Sample Code
105
+ ```python
106
+ import torch
107
+ from transformers import AutoModelForCausalLM, AutoTokenizer
108
+
109
+ torch.set_default_device("cuda")
110
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
111
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
112
+ inputs = tokenizer('''```python
+ def print_prime(n):
+     """
+     Print all primes between 1 and n
+     """''', return_tensors="pt", return_attention_mask=False)
117
+
118
+ outputs = model.generate(**inputs, max_length=200)
119
+ text = tokenizer.batch_decode(outputs)[0]
120
+ print(text)
121
  ```
 
122
 
123
+ If you need to use the model in a lower precision (e.g., FP16), please wrap the model's forward pass with `torch.autocast()`, as follows:
124
+ ```python
125
+ with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
126
+ outputs = model.generate(**inputs, max_length=200)
127
+ ```
128
+
129
+ **Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
130
+ Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
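+
+ For reference, a generation call that stays within these constraints (greedy or sampling-based decoding rather than beam search); the sampling parameters below are illustrative only:
+
+ ```python
+ outputs = model.generate(
+     **inputs,
+     max_length=200,
+     do_sample=True,   # sampling instead of beam search, since num_beams > 1 is unsupported
+     temperature=0.7,
+     top_p=0.9,
+ )
+ ```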
131
+
132
+ ### Citation
133
 
134
+ You can find the paper at https://arxiv.org/abs/2309.05463
135
 
136
+ ```bib
137
+ @article{textbooks2,
138
+ title={Textbooks Are All You Need II: \textbf{phi-1.5} technical report},
139
+ author={Li, Yuanzhi and Bubeck, S{\'e}bastien and Eldan, Ronen and Del Giorno, Allie and Gunasekar, Suriya and Lee, Yin Tat},
140
+ journal={arXiv preprint arXiv:2309.05463},
141
+ year={2023}
142
+ }
143
+ ```
Research License.docx ADDED
Binary file (38.9 kB).
 
added_tokens.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "\t\t": 50294,
3
+ "\t\t\t": 50293,
4
+ "\t\t\t\t": 50292,
5
+ "\t\t\t\t\t": 50291,
6
+ "\t\t\t\t\t\t": 50290,
7
+ "\t\t\t\t\t\t\t": 50289,
8
+ "\t\t\t\t\t\t\t\t": 50288,
9
+ "\t\t\t\t\t\t\t\t\t": 50287,
10
+ " ": 50286,
11
+ " ": 50285,
12
+ " ": 50284,
13
+ " ": 50283,
14
+ " ": 50282,
15
+ " ": 50281,
16
+ " ": 50280,
17
+ " ": 50279,
18
+ " ": 50278,
19
+ " ": 50277,
20
+ " ": 50276,
21
+ " ": 50275,
22
+ " ": 50274,
23
+ " ": 50273,
24
+ " ": 50272,
25
+ " ": 50271,
26
+ " ": 50270,
27
+ " ": 50269,
28
+ " ": 50268,
29
+ " ": 50267,
30
+ " ": 50266,
31
+ " ": 50265,
32
+ " ": 50264,
33
+ " ": 50263,
34
+ " ": 50262,
35
+ " ": 50261,
36
+ " ": 50260,
37
+ " ": 50259,
38
+ " ": 50258,
39
+ " ": 50257
40
+ }
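
These added tokens map runs of tabs and spaces (ids 50257-50294) to single vocabulary entries, which keeps indented code compact. A small sketch of the effect; the exact ids produced are an assumption about the tokenizer's greedy matching of added tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)

# The leading 8-space run should be covered by a single added whitespace token
# rather than by several shorter space tokens.
ids = tokenizer("        return x")["input_ids"]
print(ids)
```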
config.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
- "_name_or_path": "mariordoniez/phi",
3
  "activation_function": "gelu_new",
4
  "architectures": [
5
  "MixFormerSequentialForCausalLM"
6
  ],
7
  "auto_map": {
8
- "AutoConfig": "mariordoniez/phi--configuration_mixformer_sequential.MixFormerSequentialConfig",
9
- "AutoModelForCausalLM": "mariordoniez/phi--modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
10
  },
11
  "embd_pdrop": 0.0,
12
  "initializer_range": 0.02,
@@ -20,7 +20,7 @@
20
  "resid_pdrop": 0.0,
21
  "rotary_dim": 32,
22
  "tie_word_embeddings": false,
23
- "torch_dtype": "float32",
24
  "transformers_version": "4.32.1",
25
  "vocab_size": 51200
26
  }
 
1
  {
2
+ "_name_or_path": "phi-1.5-half",
3
  "activation_function": "gelu_new",
4
  "architectures": [
5
  "MixFormerSequentialForCausalLM"
6
  ],
7
  "auto_map": {
8
+ "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
9
+ "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
10
  },
11
  "embd_pdrop": 0.0,
12
  "initializer_range": 0.02,
 
20
  "resid_pdrop": 0.0,
21
  "rotary_dim": 32,
22
  "tie_word_embeddings": false,
23
+ "torch_dtype": "float16",
24
  "transformers_version": "4.32.1",
25
  "vocab_size": 51200
26
  }
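
The updated config stores the weights in float16 and resolves the custom `MixFormerSequential*` classes from the files added in this commit rather than from an external repo. A minimal sketch (the repo path and loading options are assumptions, not part of this commit) of loading the half-precision checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/this/repo",        # placeholder: local path or Hub id of this repository
    trust_remote_code=True,     # needed to load the custom MixFormer classes
    torch_dtype=torch.float16,  # matches "torch_dtype": "float16" in config.json
)
```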
configuration_mixformer_sequential.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class MixFormerSequentialConfig(PretrainedConfig):
11
+ """MixFormer (sequential for DeepSpeed) configuration."""
12
+
13
+ model_type = "mixformer-sequential"
14
+
15
+ attribute_map = {
16
+ "max_position_embeddings": "n_positions",
17
+ "hidden_size": "n_embd",
18
+ "num_attention_heads": "n_head",
19
+ "num_hidden_layers": "n_layer",
20
+ }
21
+
22
+ def __init__(
23
+ self,
24
+ vocab_size: Optional[int] = 50304,
25
+ n_positions: Optional[int] = 2048,
26
+ n_embd: Optional[int] = 1024,
27
+ n_layer: Optional[int] = 20,
28
+ n_inner: Optional[int] = None,
29
+ n_head: Optional[int] = 16,
30
+ rotary_dim: Optional[int] = 32,
31
+ activation_function: Optional[str] = "gelu_new",
32
+ embd_pdrop: Optional[float] = 0.0,
33
+ resid_pdrop: Optional[float] = 0.0,
34
+ layer_norm_epsilon: Optional[float] = 1e-5,
35
+ initializer_range: Optional[float] = 0.02,
36
+ tie_word_embeddings: Optional[bool] = False,
37
+ pad_vocab_size_multiple: Optional[int] = 64,
38
+ **kwargs
39
+ ) -> None:
40
+ self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
41
+ self.n_positions = n_positions
42
+ self.n_embd = n_embd
43
+ self.n_layer = n_layer
44
+ self.n_inner = n_inner
45
+ self.n_head = n_head
46
+ self.rotary_dim = min(rotary_dim, n_embd // n_head)
47
+ self.activation_function = activation_function
48
+ self.embd_pdrop = embd_pdrop
49
+ self.resid_pdrop = resid_pdrop
50
+ self.layer_norm_epsilon = layer_norm_epsilon
51
+ self.initializer_range = initializer_range
52
+
53
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
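
For illustration, a small sketch of how this configuration rounds the vocabulary size up to a multiple of `pad_vocab_size_multiple`; the constructor arguments are illustrative, not this checkpoint's values:

```python
# Assumes the file added above is importable (e.g., run from this repo's directory)
from configuration_mixformer_sequential import MixFormerSequentialConfig

cfg = MixFormerSequentialConfig(vocab_size=50295, n_embd=2048, n_layer=24, n_head=32)
print(cfg.vocab_size)  # 50304 = ceil(50295 / 64) * 64
print(cfg.rotary_dim)  # 32 = min(rotary_dim, n_embd // n_head)
```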
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_mixformer_sequential.py CHANGED
@@ -34,28 +34,20 @@
34
  from __future__ import annotations
35
 
36
  import math
 
37
  from typing import Any, Dict, Optional, Tuple, Union
38
  from dataclasses import dataclass, field
39
 
40
  import torch
41
  import torch.nn as nn
42
 
43
- from einops import rearrange, repeat
44
  from transformers.activations import ACT2FN
45
  from transformers import PretrainedConfig, PreTrainedModel
46
  from transformers.modeling_outputs import CausalLMOutputWithPast
47
 
48
  from .configuration_mixformer_sequential import MixFormerSequentialConfig
49
 
50
-
51
- try:
52
- from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
53
- from flash_attn.ops.fused_dense import FusedDense
54
- except:
55
- FlashRotaryEmbedding = None
56
- FusedDense = None
57
-
58
-
59
  @dataclass
60
  class InferenceParams:
61
  """Inference parameters passed to model to efficiently calculate
@@ -65,20 +57,21 @@ class InferenceParams:
65
  https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
66
 
67
  Args:
68
- max_seqlen: Maximum sequence length.
69
  max_batch_size: Maximum batch size.
70
- seqlen_offset: Sequence length offset.
71
  batch_size_offset: Batch size offset.
72
  key_value_memory_dict: Key value memory dictionary.
 
73
  lengths_per_sample: Lengths per sample.
74
 
75
  """
76
 
77
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
78
 
79
  max_batch_size: int = field(metadata={"help": "Maximum batch size."})
80
 
81
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
82
 
83
  batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
84
 
@@ -86,6 +79,8 @@ class InferenceParams:
86
  default_factory=dict, metadata={"help": "Key value memory dictionary."}
87
  )
88
 
 
 
89
  lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
90
 
91
 
@@ -108,112 +103,12 @@ class Embedding(nn.Module):
108
  return hidden_states
109
 
110
 
111
- def _apply_rotary_emb(
112
- x: torch.FloatTensor,
113
- cos: torch.FloatTensor,
114
- sin: torch.FloatTensor,
115
- ) -> torch.FloatTensor:
116
- _, seqlen, _, head_dim = x.shape
117
- rotary_seqlen, rotary_dim = cos.shape
118
- rotary_dim *= 2
119
-
120
- assert rotary_dim <= head_dim
121
- assert seqlen <= rotary_seqlen
122
- assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
123
-
124
- x_rot = x[:, :, :, :rotary_dim]
125
- x_pass = x[:, :, :, rotary_dim:]
126
-
127
- x1, x2 = x_rot.chunk(2, dim=-1)
128
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
129
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
130
-
131
- x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
132
-
133
- return torch.cat([x_rot, x_pass], axis=-1)
134
-
135
-
136
- def _apply_rotary_emb_kv(
137
- kv: torch.FloatTensor,
138
- cos: torch.FloatTensor,
139
- sin: torch.FloatTensor,
140
- cos_k: Optional[torch.FloatTensor] = None,
141
- sin_k: Optional[torch.FloatTensor] = None,
142
- ) -> torch.FloatTensor:
143
- _, seqlen, two, _, head_dim = kv.shape
144
- assert two == 2
145
-
146
- rotary_seqlen, rotary_dim = cos.shape
147
- rotary_dim *= 2
148
- assert rotary_dim <= head_dim
149
- assert seqlen <= rotary_seqlen
150
- assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
151
-
152
- k_rot = kv[:, :, 0, :, :rotary_dim]
153
- k_pass = kv[:, :, 0, :, rotary_dim:]
154
-
155
- k1, k2 = k_rot.chunk(2, dim=-1)
156
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
157
- k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
158
-
159
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
160
-
161
- return torch.cat(
162
- [
163
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
164
- kv[:, :, 1:2, :, :],
165
- ],
166
- axis=2,
167
- )
168
-
169
-
170
- def _apply_rotary_emb_qkv(
171
- qkv: torch.FloatTensor,
172
- cos: torch.FloatTensor,
173
- sin: torch.FloatTensor,
174
- cos_k: Optional[torch.FloatTensor] = None,
175
- sin_k: Optional[torch.FloatTensor] = None,
176
- ) -> torch.FloatTensor:
177
- _, seqlen, three, _, head_dim = qkv.shape
178
- assert three == 3
179
-
180
- rotary_seqlen, rotary_dim = cos.shape
181
- rotary_dim *= 2
182
- assert rotary_dim <= head_dim
183
- assert seqlen <= rotary_seqlen
184
- assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
185
-
186
- q_rot = qkv[:, :, 0, :, :rotary_dim]
187
- q_pass = qkv[:, :, 0, :, rotary_dim:]
188
-
189
- k_rot = qkv[:, :, 1, :, :rotary_dim]
190
- k_pass = qkv[:, :, 1, :, rotary_dim:]
191
-
192
- q1, q2 = q_rot.chunk(2, dim=-1)
193
- k1, k2 = k_rot.chunk(2, dim=-1)
194
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
195
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
196
-
197
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
198
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
199
-
200
- return torch.cat(
201
- [
202
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
203
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
204
- qkv[:, :, 2:3, :, :],
205
- ],
206
- axis=2,
207
- )
208
-
209
-
210
  class RotaryEmbedding(nn.Module):
211
- """Rotary positional embedding (RoPE).
212
 
213
  Reference:
214
- RoFormer: Enhanced Transformer with Rotary Position Embedding.
215
- https://arxiv.org/pdf/2104.09864.pdf.
216
-
217
  """
218
 
219
  def __init__(
@@ -221,7 +116,6 @@ class RotaryEmbedding(nn.Module):
221
  dim: int,
222
  base: int = 10000,
223
  scale_base: Optional[float] = None,
224
- pos_idx_in_fp32: bool = True,
225
  device: Optional[str] = None,
226
  **kwargs,
227
  ) -> None:
@@ -230,23 +124,21 @@ class RotaryEmbedding(nn.Module):
230
  if scale_base is not None:
231
  raise NotImplementedError
232
 
 
233
  self.dim = dim
234
- self.base = float(base)
235
  self.scale_base = scale_base
236
- self.pos_idx_in_fp32 = pos_idx_in_fp32
237
  self.device = device
238
 
239
- # Generate and save the inverse frequency buffer (non-trainable)
240
- inv_freq = self._compute_inv_freq(device)
241
- self.register_buffer("inv_freq", inv_freq, persistent=False)
242
 
243
- # Generate and save the scale buffer (non-trainable)
244
  scale = (
245
  (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
246
  if scale_base is not None
247
  else None
248
  )
249
- self.register_buffer("scale", scale, persistent=False)
250
 
251
  self._seq_len_cached = 0
252
  self._cos_cached = None
@@ -254,73 +146,91 @@ class RotaryEmbedding(nn.Module):
254
  self._cos_k_cached = None
255
  self._sin_k_cached = None
256
 
257
- def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
258
- return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
 
 
259
 
260
- def _update_cos_sin_cache(
261
- self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
262
- ) -> None:
263
- # Reset the tables if the sequence length has changed, if we are on a
264
- # new device or if we are switching from inference mode to training
265
- if (
266
- seqlen > self._seq_len_cached
267
- or self._cos_cached is None
268
- or self._cos_cached.device != device
269
- or self._cos_cached.dtype != dtype
270
- or (self.training and self._cos_cached.is_inference())
271
- ):
272
- self._seq_len_cached = seqlen
273
 
274
- # fp32 is preferred since the output of `torch.arange` can be quite large
275
- # and bf16 would lose a lot of precision
276
- if self.pos_idx_in_fp32:
277
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
278
- if self.inv_freq.dtype != torch.float32:
279
- inv_freq = self._compute_inv_freq(device=device)
280
- else:
281
- inv_freq = self.inv_freq
282
- else:
283
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
284
- inv_freq = self.inv_freq
285
 
286
- # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
287
- freqs = torch.outer(t, inv_freq)
 
288
  if self.scale is None:
289
- self._cos_cached = torch.cos(freqs).to(dtype)
290
- self._sin_cached = torch.sin(freqs).to(dtype)
291
  else:
292
  power = (
293
  torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
294
  ) / self.scale_base
295
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
296
 
297
- # Force the scale multiplication to happen in fp32
298
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
299
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
300
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
301
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
302
 
303
- def forward(
304
  self,
305
- qkv: torch.Tensor,
306
- kv: Optional[torch.Tensor] = None,
307
- seqlen_offset: int = 0,
308
- max_seqlen: Optional[int] = None,
309
- ) -> Tuple[torch.Tensor, torch.Tensor]:
310
- seqlen = qkv.shape[1]
311
-
312
- if max_seqlen is not None:
313
- self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
314
- else:
315
- self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
316
-
317
- if kv is None:
318
- return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
319
- else:
320
- q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
321
- kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
- return q, kv
 
 
 
324
 
325
 
326
  class MLP(nn.Module):
@@ -380,22 +290,21 @@ class SelfAttention(nn.Module):
380
  attention_mask: Optional[torch.BoolTensor] = None,
381
  **kwargs,
382
  ) -> torch.FloatTensor:
383
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
 
384
  q, k, v = qkv.unbind(dim=2)
385
 
386
- causal = self.causal if causal is None else causal
387
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
388
-
389
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
390
 
391
  if attention_mask is not None:
392
- padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
393
  padding_mask.masked_fill_(attention_mask, 0.0)
394
 
395
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
396
 
397
  if causal:
398
- causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
399
  scores = scores + causal_mask.to(dtype=scores.dtype)
400
 
401
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
@@ -434,31 +343,25 @@ class CrossAttention(nn.Module):
434
  attention_mask: Optional[torch.BoolTensor] = None,
435
  **kwargs,
436
  ) -> torch.FloatTensor:
437
- batch_size, seqlen_q = q.shape[0], q.shape[1]
438
- seqlen_k = kv.shape[1]
439
- assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
440
 
441
- if kv.shape[3] != q.shape[2]:
442
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
443
  k, v = kv.unbind(dim=2)
444
 
445
- causal = self.causal if causal is None else causal
446
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
447
-
448
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
449
 
450
  if attention_mask is not None:
451
- padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
452
  padding_mask.masked_fill_(attention_mask, 0.0)
453
 
454
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
455
 
456
  if causal:
457
- rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
458
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
459
- causal_mask = cols > rows + seqlen_k - seqlen_q
460
-
461
- scores = scores.masked_fill(causal_mask, -10000.0)
462
 
463
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
464
  attention = self.drop(attention)
@@ -468,12 +371,21 @@ class CrossAttention(nn.Module):
468
  return output
469
 
470
 
471
- def _find_mha_dims(
472
- config: PretrainedConfig,
473
- n_head: Optional[int] = None,
474
- n_head_kv: Optional[int] = None,
475
- head_dim: Optional[int] = None,
476
  ) -> Tuple[int, int]:
 
 
 
477
  assert all(
478
  hasattr(config, attr) for attr in ["n_embd", "n_head"]
479
  ), "`config` must have `n_embd` and `n_head` attributes."
@@ -489,20 +401,31 @@ def _find_mha_dims(
489
  elif n_head is None or head_dim is None:
490
  raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
491
 
492
- if n_head_kv is None:
493
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
494
- assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
495
 
496
- return n_head, n_head_kv, head_dim
497
 
 
498
 
499
- def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
500
  num_heads, head_dim = kv.shape[-2:]
501
 
502
  if layer_idx not in inference_params.key_value_memory_dict:
503
  kv_cache = torch.empty(
504
  inference_params.max_batch_size,
505
- inference_params.max_seqlen,
506
  2,
507
  num_heads,
508
  head_dim,
@@ -511,19 +434,43 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
511
  )
512
  inference_params.key_value_memory_dict[layer_idx] = kv_cache
513
  else:
514
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
 
 
 
 
515
 
516
  batch_start = inference_params.batch_size_offset
517
  batch_end = batch_start + kv.shape[0]
518
- assert batch_end <= kv_cache.shape[0]
519
 
520
- sequence_start = inference_params.seqlen_offset
521
  sequence_end = sequence_start + kv.shape[1]
522
- assert sequence_end <= kv_cache.shape[1]
523
 
524
- assert kv_cache is not None
525
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
526
- kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
 
527
 
528
  return kv
529
 
@@ -539,11 +486,11 @@ class MHA(nn.Module):
539
  rotary_dim: Optional[int] = None,
540
  rotary_emb_scale_base: Optional[float] = None,
541
  n_head: Optional[int] = None,
542
- n_head_kv: Optional[int] = None,
543
  head_dim: Optional[int] = None,
544
  bias: bool = True,
545
  causal: bool = True,
546
  softmax_scale: Optional[float] = None,
 
547
  layer_idx: Optional[int] = None,
548
  return_residual: bool = False,
549
  checkpointing: bool = False,
@@ -556,101 +503,58 @@ class MHA(nn.Module):
556
  rotary_kwargs = {"device": device}
557
  if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
558
  rotary_kwargs["scale_base"] = rotary_emb_scale_base
559
-
560
- rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
561
- if rotary_cls is None:
562
- rotary_cls = RotaryEmbedding
563
- self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
564
 
565
  # MLP
566
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
567
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
568
  hidden_size = config.n_embd
569
 
570
- linear_cls = FusedDense if config.fused_dense else nn.Linear
571
- if linear_cls is None:
572
- linear_cls = nn.Linear
573
-
574
- self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
575
- self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
576
 
577
  # Attention
578
- self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
579
- self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
580
 
581
  self.layer_idx = layer_idx
582
  self.return_residual = return_residual
583
  self.checkpointing = checkpointing
584
 
585
- def _forward_self_attn(
586
- self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
587
- ) -> torch.FloatTensor:
588
- qkv = self.Wqkv(x)
589
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
590
-
591
- if self.rotary_emb_dim > 0:
592
- qkv = self.rotary_emb(qkv)
593
-
594
- if self.checkpointing:
595
- return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
596
-
597
- return self.inner_attn(qkv, attention_mask=attention_mask)
598
-
599
- def _forward_cross_attn(
600
  self,
601
  x: torch.FloatTensor,
602
- past_key_values: Optional[InferenceParams],
603
- attention_mask: Optional[torch.BoolTensor],
604
- ) -> torch.FloatTensor:
 
 
 
605
  qkv = self.Wqkv(x)
 
606
 
607
- q = qkv[..., : self.n_head * self.head_dim]
608
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
609
-
610
- kv = qkv[..., self.n_head * self.head_dim :]
611
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
612
-
613
- seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
614
- causal = None if seqlen_offset == 0 else False
615
  if self.rotary_emb_dim > 0:
616
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
617
 
618
  if past_key_values is not None:
619
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
620
 
621
- if self.checkpointing:
622
- return torch.utils.checkpoint.checkpoint(
623
- self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
624
- )
625
 
626
- return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
627
 
628
- def forward(
629
- self,
630
- x: torch.FloatTensor,
631
- past_key_values: Optional[InferenceParams] = None,
632
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
633
- **kwargs,
634
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
635
- if attention_mask is not None and torch.any(~attention_mask.bool()):
636
- attention_mask = attention_mask.bool()
637
- else:
638
- attention_mask = None
639
-
640
- # MHA
641
- if self.n_head == self.n_head_kv:
642
- if past_key_values is None:
643
- # If `past_key_values` are not supplied, we run self-attention
644
- attn_output = self._forward_self_attn(x, attention_mask)
645
  else:
646
- # If `past_key_values` are supplied, it means that we might have cached values and
647
- # could take advantage of cross-attention
648
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
649
- # MQA / GQA
650
  else:
651
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
652
- # because `q` and `kv` lengths might be different
653
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
654
 
655
  output = rearrange(attn_output, "... h d -> ... (h d)")
656
  output = self.out_proj(output)
@@ -768,29 +672,38 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
768
  if module.padding_idx is not None:
769
  module.weight.data[module.padding_idx].zero_()
770
  elif isinstance(module, nn.LayerNorm):
771
- if module.bias is not None:
772
- module.bias.data.zero_()
773
  module.weight.data.fill_(1.0)
774
 
775
  def prepare_inputs_for_generation(
776
  self,
777
  input_ids: torch.LongTensor,
778
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
779
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
780
  **kwargs,
781
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
782
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
783
  past_key_values = InferenceParams(
784
- max_seqlen=self.config.n_positions,
785
  max_batch_size=input_ids.shape[0],
786
- seqlen_offset=0,
 
787
  batch_size_offset=0,
 
788
  key_value_memory_dict={},
789
- lengths_per_sample=None,
790
  )
791
  else:
792
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
793
- past_key_values.seqlen_offset = len(input_ids[0]) - 1
794
  input_ids = input_ids[:, -1].unsqueeze(-1)
795
 
796
  return {
@@ -799,9 +712,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
799
  "attention_mask": attention_mask,
800
  }
801
 
802
- def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
803
- if isinstance(module, MixFormerSequentialPreTrainedModel):
804
- module.gradient_checkpointing = value
805
 
806
 
807
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
@@ -843,13 +756,19 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
843
  labels: Optional[torch.LongTensor] = None,
844
  **kwargs,
845
  ) -> CausalLMOutputWithPast:
846
- hidden_layer = self.layers[0](input_ids)
847
- for module in self.layers[1:-1]:
848
- hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
849
- lm_logits = self.layers[-1](hidden_layer)
 
 
 
 
 
 
850
 
851
  loss = None
852
  if labels is not None:
853
  loss = self.loss(lm_logits, labels)
854
 
855
- return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
 
34
  from __future__ import annotations
35
 
36
  import math
37
+ import copy
38
  from typing import Any, Dict, Optional, Tuple, Union
39
  from dataclasses import dataclass, field
40
 
41
  import torch
42
  import torch.nn as nn
43
 
44
+ from einops import rearrange
45
  from transformers.activations import ACT2FN
46
  from transformers import PretrainedConfig, PreTrainedModel
47
  from transformers.modeling_outputs import CausalLMOutputWithPast
48
 
49
  from .configuration_mixformer_sequential import MixFormerSequentialConfig
50
 
 
 
 
 
 
 
 
 
 
51
  @dataclass
52
  class InferenceParams:
53
  """Inference parameters passed to model to efficiently calculate
 
57
  https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
58
 
59
  Args:
60
+ max_sequence_len: Maximum sequence length.
61
  max_batch_size: Maximum batch size.
62
+ sequence_len_offset: Sequence length offset.
63
  batch_size_offset: Batch size offset.
64
  key_value_memory_dict: Key value memory dictionary.
65
+ fused_ft_kernel: Whether to use fused kernel for fast inference.
66
  lengths_per_sample: Lengths per sample.
67
 
68
  """
69
 
70
+ max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
71
 
72
  max_batch_size: int = field(metadata={"help": "Maximum batch size."})
73
 
74
+ sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
75
 
76
  batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
77
 
 
79
  default_factory=dict, metadata={"help": "Key value memory dictionary."}
80
  )
81
 
82
+ fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
83
+
84
  lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
85
 
86
 
 
103
  return hidden_states
104
 
105
 
 
 
 
 
 
 
 
 
106
  class RotaryEmbedding(nn.Module):
107
+ """Rotary embeddings.
108
 
109
  Reference:
110
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
111
+
 
112
  """
113
 
114
  def __init__(
 
116
  dim: int,
117
  base: int = 10000,
118
  scale_base: Optional[float] = None,
 
119
  device: Optional[str] = None,
120
  **kwargs,
121
  ) -> None:
 
124
  if scale_base is not None:
125
  raise NotImplementedError
126
 
127
+ # Generate and save the inverse frequency buffer (non-trainable)
128
  self.dim = dim
129
+ self.base = base
130
  self.scale_base = scale_base
 
131
  self.device = device
132
 
133
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
134
+ self.register_buffer("inv_freq", inv_freq)
 
135
 
 
136
  scale = (
137
  (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
138
  if scale_base is not None
139
  else None
140
  )
141
+ self.register_buffer("scale", scale)
142
 
143
  self._seq_len_cached = 0
144
  self._cos_cached = None
 
146
  self._cos_k_cached = None
147
  self._sin_k_cached = None
148
 
149
+ def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
150
+ # Reset the tables if the sequence length has changed,
151
+ # or if we're on a new device (possibly due to tracing for instance)
152
+ seqlen = x.shape[1] + seqlen_offset
153
 
154
+ # Re-generate the inverse frequency buffer if it's not fp32
155
+ # (for instance if model.half() was called)
156
+ if self.inv_freq.dtype != torch.float32:
157
+ self.inv_freq = 1.0 / (
158
+ self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
159
+ )
 
 
 
 
 
 
 
160
 
161
+ if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
162
+ self._seq_len_cached = seqlen
163
+ t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
 
 
 
 
 
 
 
 
164
 
165
+ # Don't do einsum, it converts fp32 to fp16
166
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
167
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
168
  if self.scale is None:
169
+ self._cos_cached = torch.cos(freqs).to(x.dtype)
170
+ self._sin_cached = torch.sin(freqs).to(x.dtype)
171
  else:
172
  power = (
173
  torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
174
  ) / self.scale_base
175
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
176
 
177
+ # We want the multiplication by scale to happen in fp32
178
+ self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
179
+ self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
180
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
181
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
182
 
183
+ def _apply_rotary_emb_qkv(
184
  self,
185
+ qkv: torch.FloatTensor,
186
+ sin: torch.FloatTensor,
187
+ cos: torch.FloatTensor,
188
+ sin_k: Optional[torch.FloatTensor] = None,
189
+ cos_k: Optional[torch.FloatTensor] = None,
190
+ ) -> torch.FloatTensor:
191
+ _, seqlen, three, _, headdim = qkv.shape
192
+ assert three == 3
193
+
194
+ rotary_seqlen, rotary_dim = cos.shape
195
+ rotary_dim *= 2
196
+ assert rotary_dim <= headdim
197
+ assert seqlen <= rotary_seqlen
198
+
199
+ cos_k = cos if cos_k is None else cos_k
200
+ sin_k = sin if sin_k is None else sin_k
201
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
202
+
203
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
204
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
205
+
206
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
207
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
208
+
209
+ # Splits the queries and keys in half
210
+ q1, q2 = q_rot.chunk(2, dim=-1)
211
+ k1, k2 = k_rot.chunk(2, dim=-1)
212
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
213
+
214
+ # Casts to fp32 are necessary to prevent fp16 overflow issues
215
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
216
+
217
+ # Computes the new keys and queries, recasting to original dtype
218
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
219
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
220
+
221
+ return torch.cat(
222
+ [
223
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
224
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
225
+ qkv[:, :, 2:3, :, :],
226
+ ],
227
+ axis=2,
228
+ )
229
 
230
+ def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
231
+ # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
232
+ self._update_cos_sin_cache(qkv, seqlen_offset)
233
+ return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
234
 
235
 
236
  class MLP(nn.Module):
 
290
  attention_mask: Optional[torch.BoolTensor] = None,
291
  **kwargs,
292
  ) -> torch.FloatTensor:
293
+ causal = self.causal if causal is None else causal
294
+ batch_size, seq_len = qkv.shape[0], qkv.shape[1]
295
  q, k, v = qkv.unbind(dim=2)
296
 
 
297
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
298
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
299
 
300
  if attention_mask is not None:
301
+ padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
302
  padding_mask.masked_fill_(attention_mask, 0.0)
303
 
304
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
305
 
306
  if causal:
307
+ causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
308
  scores = scores + causal_mask.to(dtype=scores.dtype)
309
 
310
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
 
343
  attention_mask: Optional[torch.BoolTensor] = None,
344
  **kwargs,
345
  ) -> torch.FloatTensor:
346
+ causal = self.causal if causal is None else causal
347
+ batch_size, seq_len_q = q.shape[0], q.shape[1]
348
+ assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
349
 
350
+ seq_len_k = kv.shape[1]
 
351
  k, v = kv.unbind(dim=2)
352
 
 
353
  softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
354
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
355
 
356
  if attention_mask is not None:
357
+ padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
358
  padding_mask.masked_fill_(attention_mask, 0.0)
359
 
360
  scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
361
 
362
  if causal:
363
+ causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
364
+ scores = scores + causal_mask.to(dtype=scores.dtype)
 
 
 
365
 
366
  attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
367
  attention = self.drop(attention)
 
371
  return output
372
 
373
 
374
+ def find_mha_dims(
375
+ config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
 
 
 
376
  ) -> Tuple[int, int]:
377
+ """Validate and return the number of heads and head dimension for multi-head attention.
378
+
379
+ Args:
380
+ config: Model configuration.
381
+ n_head: Number of heads.
382
+ head_dim: Head dimension.
383
+
384
+ Returns:
385
+ Number of heads and head dimension.
386
+
387
+ """
388
+
389
  assert all(
390
  hasattr(config, attr) for attr in ["n_embd", "n_head"]
391
  ), "`config` must have `n_embd` and `n_head` attributes."
 
401
  elif n_head is None or head_dim is None:
402
  raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
403
 
404
+ return n_head, head_dim
 
 
405
 
 
406
 
407
+ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
408
+ """Update the key-value cache for inference.
409
+
410
+ Reference:
411
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
412
+
413
+ Args:
414
+ kv: Key-value tensor.
415
+ inference_params: Inference parameters.
416
+ layer_idx: Layer index.
417
+
418
+ Returns:
419
+ Updated key-value tensor.
420
+
421
+ """
422
 
 
423
  num_heads, head_dim = kv.shape[-2:]
424
 
425
  if layer_idx not in inference_params.key_value_memory_dict:
426
  kv_cache = torch.empty(
427
  inference_params.max_batch_size,
428
+ inference_params.max_sequence_len,
429
  2,
430
  num_heads,
431
  head_dim,
 
434
  )
435
  inference_params.key_value_memory_dict[layer_idx] = kv_cache
436
  else:
437
+ if not inference_params.fused_ft_kernel:
438
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
439
+ else:
440
+ k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
441
+ kv_cache = None
442
 
443
  batch_start = inference_params.batch_size_offset
444
  batch_end = batch_start + kv.shape[0]
445
+ assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
446
 
447
+ sequence_start = inference_params.sequence_len_offset
448
  sequence_end = sequence_start + kv.shape[1]
449
+ assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
450
 
451
+ if not inference_params.fused_ft_kernel:
452
+ assert kv_cache is not None
453
+
454
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
455
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
456
+
457
+ return kv
458
+
459
+ assert inference_params.sequence_len_offset == 0
460
+ assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
461
+
462
+ packsize = 4 if kv.dtype == torch.float32 else 8
463
+
464
+ if kv_cache is not None:
465
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
466
+ k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
467
+ v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
468
+ inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
469
+ else:
470
+ k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
471
+ kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
472
+ )
473
+ v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
474
 
475
  return kv
476
 
 
486
  rotary_dim: Optional[int] = None,
487
  rotary_emb_scale_base: Optional[float] = None,
488
  n_head: Optional[int] = None,
 
489
  head_dim: Optional[int] = None,
490
  bias: bool = True,
491
  causal: bool = True,
492
  softmax_scale: Optional[float] = None,
493
+ dropout: float = 0.0,
494
  layer_idx: Optional[int] = None,
495
  return_residual: bool = False,
496
  checkpointing: bool = False,
 
503
  rotary_kwargs = {"device": device}
504
  if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
505
  rotary_kwargs["scale_base"] = rotary_emb_scale_base
506
+ self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
 
 
 
 
507
 
508
  # MLP
509
+ self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
510
+ op_size = self.n_head * self.head_dim
511
  hidden_size = config.n_embd
512
 
513
+ self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
514
+ self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
 
 
 
515
 
516
  # Attention
517
+ self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
518
+ self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
519
 
520
  self.layer_idx = layer_idx
521
  self.return_residual = return_residual
522
  self.checkpointing = checkpointing
523
 
524
+ def forward(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  self,
526
  x: torch.FloatTensor,
527
+ past_key_values: Optional[InferenceParams] = None,
528
+ attention_mask: Optional[torch.BoolTensor] = None,
529
+ cu_seqlens: Optional[torch.LongTensor] = None,
530
+ max_seqlen: Optional[int] = None,
531
+ **kwargs,
532
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
533
  qkv = self.Wqkv(x)
534
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
535
 
536
+ seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
 
 
 
 
 
 
 
537
  if self.rotary_emb_dim > 0:
538
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
539
 
540
  if past_key_values is not None:
541
+ kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
542
 
543
+ if attention_mask is not None:
544
+ attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
545
+ attention_mask = attention_mask.bool().to(qkv.device)
 
546
 
547
+ attention_kwargs = {"attention_mask": attention_mask}
548
 
549
+ if past_key_values is None or seqlen_offset == 0:
550
+ if self.checkpointing:
551
+ attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  else:
553
+ attn_output = self.inner_attn(qkv, **attention_kwargs)
 
 
 
554
  else:
555
+ q = qkv[:, :, 0]
556
+ causal = None if past_key_values.sequence_len_offset == 0 else False
557
+ attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
558
 
559
  output = rearrange(attn_output, "... h d -> ... (h d)")
560
  output = self.out_proj(output)
 
672
  if module.padding_idx is not None:
673
  module.weight.data[module.padding_idx].zero_()
674
  elif isinstance(module, nn.LayerNorm):
675
+ module.bias.data.zero_()
 
676
  module.weight.data.fill_(1.0)
677
 
678
  def prepare_inputs_for_generation(
679
  self,
680
  input_ids: torch.LongTensor,
681
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
682
+ attention_mask: Optional[torch.BoolTensor] = None,
683
  **kwargs,
684
  ) -> Dict[str, Any]:
685
+ if attention_mask is not None and torch.any(~attention_mask.bool()):
686
+ total_seq_len = torch.sum(attention_mask, dim=1)
687
+ max_seq_len = torch.max(total_seq_len)
688
+
689
+ total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
690
+ cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
691
+ attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
692
+ else:
693
+ attention_mask = None
694
+
695
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
696
  past_key_values = InferenceParams(
 
697
  max_batch_size=input_ids.shape[0],
698
+ max_sequence_len=self.config.n_positions,
699
+ sequence_len_offset=0,
700
  batch_size_offset=0,
701
+ fused_ft_kernel=False,
702
  key_value_memory_dict={},
 
703
  )
704
  else:
705
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
706
+ past_key_values.sequence_len_offset = len(input_ids[0]) - 1
707
  input_ids = input_ids[:, -1].unsqueeze(-1)
708
 
709
  return {
 
712
  "attention_mask": attention_mask,
713
  }
714
 
715
+ def _set_gradient_checkpointing(self, module, value=False):
716
+ if isinstance(module, MixFormerSequentialPreTrainedModel):
717
+ module.gradient_checkpointing = value
718
 
719
 
720
  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
 
756
  labels: Optional[torch.LongTensor] = None,
757
  **kwargs,
758
  ) -> CausalLMOutputWithPast:
759
+ if attention_mask is not None and self.training:
760
+ print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")
761
+
762
+ if past_key_values is None and attention_mask is None:
763
+ lm_logits = self.layers(input_ids)
764
+ else:
765
+ hidden_layer = self.layers[0](input_ids)
766
+ for module in self.layers[1:-1]:
767
+ hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
768
+ lm_logits = self.layers[-1](hidden_layer)
769
 
770
  loss = None
771
  if labels is not None:
772
  loss = self.loss(lm_logits, labels)
773
 
774
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d121d287c708fc6d08043ed171921e4b9fb68d00f452c1d23ea1c55292bd1d5c
3
- size 5673168010
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eab6a12a9a2b78cac8f8975aea9f3a5e89ddadcb9e0dad27e40965e57e235a4a
3
+ size 2836623617
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": "<|endoftext|>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|endoftext|>",
6
+ "model_max_length": 2048,
7
+ "tokenizer_class": "CodeGenTokenizer",
8
+ "unk_token": "<|endoftext|>"
9
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff