mariordoniez committed
Commit fc863b3 · 1 Parent(s): d92db69

Upload 9 files

README.md CHANGED
@@ -1,3 +1,148 @@
---
license: mit
---
---
license: other
base_model: microsoft/phi-1_5
tags:
- generated_from_trainer
- sales
model-index:
- name: salesGPT_v2
  results: []
datasets:
- goendalf666/sales-conversations-2
- goendalf666/sales-conversations-instruction-ext
- goendalf666/sales-conversations-instruction-base
- goendalf666/sales-textbook_for_convincing_and_selling
language:
- en
pipeline_tag: text-generation
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# salesGPT_v2

**Model Card for salesGPT_v2**

### Model Description
salesGPT_v2, derived from microsoft/phi-1_5, is specialized in simulating sales conversations: it elicits customer requirements, handles objections, and suggests suitable products or services. It was fine-tuned on a variety of sales-related datasets and appears proficient at initiating conversations, asking pertinent questions, and sustaining interactive dialogues with users.

### Related Resources

- GitHub: https://github.com/tom813/salesGPT_foundation
- salesGPT_v1: https://huggingface.co/goendalf666/salesGPT_v1

![image/png](https://cdn-uploads.huggingface.co/production/uploads/63797fcb2cb50dda39d8aec6/re7MmsaYNzTYVH2jEXDDu.png)

### Intended Uses & Limitations
**Intended Uses:**
- Simulating sales conversations for training or evaluation purposes.
- Providing guidelines or suggested dialogues for sales representatives.

**Limitations:**
- The model can ask questions repetitively in certain scenarios.
- It may struggle with customers who lack specific preferences or knowledge about products.
- Objection handling relies on objective criteria rather than persuasion techniques.
- It has difficulty making appropriate suggestions for customers without specific needs.
- It is of limited effectiveness in conversations about budgets, finances, or related sensitivities.

### Training and Evaluation Data
**Training Data:**
1. **Textbook v1 Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling)
   - Content: A sales textbook, generated from structural points and detailed subpoints created through API calls.

2. **Sales Conversation Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations)
   - Content: Sales conversations, generated based on the chapters of the textbook.

3. **Sales Conversations Instruction Base Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-base)
   - Content: Extended sales conversations with structured dialogues.

4. **Sales Conversations Instruction Extension Dataset**
   - URL: [Dataset](https://huggingface.co/datasets/goendalf666/sales-conversations-instruction-ext)
   - Content: Updates based on real conversations with the model, added to improve its handling of unconvincing cases.

**Evaluation Data:**
- More information is needed on how and where the model was evaluated. If it was assessed on a separate test set, access to and details of that dataset should be provided.

### Training Procedure
Fine-tuning of salesGPT_v2 was carried out in three phases using the LoRA approach with rank 64:
1. Training on the textbook for 20k steps.
2. Training on sales conversations for 40k steps, producing salesGPT_v1.
3. Training on sales-conversations instructions for 40k steps, producing salesGPT_v2.

Hyperparameters used during training (see the configuration sketch below):
- Learning rate: 0.0002
- Train batch size: 2
- Eval batch size: 8
- Seed: 42
- Optimizer: Adam
- LR scheduler type: cosine
- Training steps: 40,000
- [Additional parameters...](https://huggingface.co/goendalf666/salesGPT_v2)

Model size details:
- Trainable params: 18,874,368
- All params: 1,437,145,088
- Trainable: ~1.31% of all parameters
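
The exact training script is not part of this repository, but a minimal sketch of an equivalent setup with `transformers` and `peft` is shown below. It assumes the adapter settings recorded in `adapter_config.json` (r=64, lora_alpha=16, dropout 0.05, targeting `Wqkv` and `out_proj`) together with the hyperparameters listed above; dataset preparation and the actual `Trainer` call are omitted.

```python
# Sketch only: mirrors the adapter and trainer settings described in this card,
# not the actual script used to train salesGPT_v2.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5", trust_remote_code=True, torch_dtype=torch.float32
)

# Adapter settings taken from adapter_config.json in this repository
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["Wqkv", "out_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # ~18.9M trainable of ~1.44B total params

# Hyperparameters listed above
training_args = TrainingArguments(
    output_dir="salesGPT_v2",
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=8,
    seed=42,
    lr_scheduler_type="cosine",
    max_steps=40_000,
)
```
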
### Results
While salesGPT_v2 shows notable capabilities, such as initiating a structured dialogue and asking relevant questions, there are several observations and areas for improvement:
- The model tends to keep asking questions when a user provides only vague information.
- In some cases it is not assertive enough to suggest a product when a customer lacks preferences.
- It tends to rely on objective criteria rather than convincing strategies when handling objections.
- It needs better handling of conversations about budgetary or financial concerns.

### Example Usage & Conversations
#### Example 1:
**Customer**: I am interested in buying a new laptop.
**salesGPT_v2**: I understand. Could you please tell me more about your specific requirements and budget?

#### Example 2:
**Customer**: Oh, I don't know how I should explain this to my wife. I said we can't go on vacation because we don't have any money.
**salesGPT_v2**: I see. Can you please provide more details about your financial situation and the specific expenses you are concerned about?

These examples illustrate the model's ability to engage in sales dialogues and also hint at its limitations, such as probing too deeply into a customer's personal financial situation.

### Future Directions & Improvement
- Enhance objection handling by focusing more on persuasive techniques and emotional intelligence.
- Improve suggestion-making, especially when customers are indecisive.
- Navigate conversations involving budgetary and financial aspects without seeming insensitive or intrusive.
- Strike a balance between being informative and being too technical in product descriptions.
- Consider more ethics- and privacy-guided conversation guidelines, especially when discussing customers' financial capacities.

### Ethical Considerations
The model's tendency to repeatedly ask for specific information, especially personal financial details, raises concerns regarding privacy and data sensitivity. Care must be taken to ensure the model respects user privacy and does not persistently probe for personal or sensitive information.

### Conclusion
salesGPT_v2 offers a foundation for simulating sales conversations, with room for refinement in handling objections, making product suggestions, and managing conversations around financial topics with tact. Future versions might aim to balance persuasive ability with ethical and emotionally intelligent dialogue.

### Inference

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    "goendalf666/salesGPT_v2", trust_remote_code=True, torch_dtype=torch.float32
).to(device)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)

# Conversation history to continue (replace with your own dialogue)
conversation_text = "Customer: I am interested in buying a new laptop.\nsalesGPT_v2:"

inputs = tokenizer(conversation_text, return_tensors="pt", return_attention_mask=False)
inputs = inputs.to(device)

# Generate a response
outputs = model.generate(**inputs, max_length=512)
response_text = tokenizer.batch_decode(outputs)[0]
print(response_text)
```

Alternatively, use the provided inference script: https://github.com/tom813/salesGPT_foundation/blob/main/inference.py

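
Since this repository also ships the LoRA adapter files (`adapter_config.json` and `adapter_model.bin`, listed below), the adapter can alternatively be applied on top of the base model with `peft`. This is a sketch under the assumption that the adapter files load cleanly with the standard `peft` API; the merged weights used above are the more direct path.

```python
# Alternative (sketch): load the base model and apply the LoRA adapter from this repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5", trust_remote_code=True, torch_dtype=torch.float32
)
model = PeftModel.from_pretrained(base, "goendalf666/salesGPT_v2")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
```
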
### Framework versions

- Transformers 4.32.1
- PyTorch 2.1.0.dev20230829+cu121
- Datasets 2.14.5
- Tokenizers 0.13.3
adapter_config.json ADDED
@@ -0,0 +1,20 @@
{
  "auto_mapping": null,
  "base_model_name_or_path": "microsoft/phi-1_5",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "lora_alpha": 16,
  "lora_dropout": 0.05,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 64,
  "target_modules": [
    "Wqkv",
    "out_proj"
  ],
  "task_type": "CAUSAL_LM"
}
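
These are standard `peft` LoRA adapter settings and can be inspected programmatically, for example (a sketch, assuming the `goendalf666/salesGPT_v2` repo id from the card above):

```python
# Sketch: read the adapter configuration without loading any weights
from peft import PeftConfig

cfg = PeftConfig.from_pretrained("goendalf666/salesGPT_v2")
print(cfg.r, cfg.lora_alpha, cfg.target_modules)  # e.g. 64, 16, ['Wqkv', 'out_proj']
```
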
adapter_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e896326bf0827004d90f3ddd361d14ec98a8bf8e62aa1b490f90eab86cc9e10
size 75531342
config.json ADDED
@@ -0,0 +1,26 @@
{
  "_name_or_path": "mariordoniez/phi",
  "activation_function": "gelu_new",
  "architectures": [
    "MixFormerSequentialForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "mariordoniez/phi--configuration_mixformer_sequential.MixFormerSequentialConfig",
    "AutoModelForCausalLM": "mariordoniez/phi--modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
  },
  "embd_pdrop": 0.0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mixformer-sequential",
  "n_embd": 2048,
  "n_head": 32,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary_dim": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.32.1",
  "vocab_size": 51200
}
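
The architecture is the custom MixFormer (phi-1.5) stack defined in `modeling_mixformer_sequential.py` below: 24 parallel blocks, hidden size 2048, 32 heads, rotary dimension 32, and a 51,200-token vocabulary. As a quick sanity check, the config can be loaded with `trust_remote_code` (a sketch, assuming the `goendalf666/salesGPT_v2` repo id):

```python
# Sketch: load the custom config and confirm the key dimensions above
from transformers import AutoConfig

config = AutoConfig.from_pretrained("goendalf666/salesGPT_v2", trust_remote_code=True)
assert (config.n_layer, config.n_embd, config.n_head, config.vocab_size) == (24, 2048, 32, 51200)
```
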
generation_config.json ADDED
@@ -0,0 +1,4 @@
{
  "_from_model_config": true,
  "transformers_version": "4.32.1"
}
modeling_mixformer_sequential.py ADDED
@@ -0,0 +1,855 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+ #
4
+ # BSD 3-Clause License
5
+ #
6
+ # Copyright (c) 2022, Tri Dao, [email protected].
7
+ # All rights reserved.
8
+ #
9
+ # Redistribution and use in source and binary forms, with or without
10
+ # modification, are permitted provided that the following conditions are met:
11
+ #
12
+ # * Redistributions of source code must retain the above copyright notice, this
13
+ # list of conditions and the following disclaimer.
14
+ #
15
+ # * Redistributions in binary form must reproduce the above copyright notice,
16
+ # this list of conditions and the following disclaimer in the documentation
17
+ # and/or other materials provided with the distribution.
18
+ #
19
+ # * Neither the name of the copyright holder nor the names of its
20
+ # contributors may be used to endorse or promote products derived from
21
+ # this software without specific prior written permission.
22
+ #
23
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
+
34
+ from __future__ import annotations
35
+
36
+ import math
37
+ from typing import Any, Dict, Optional, Tuple, Union
38
+ from dataclasses import dataclass, field
39
+
40
+ import torch
41
+ import torch.nn as nn
42
+
43
+ from einops import rearrange, repeat
44
+ from transformers.activations import ACT2FN
45
+ from transformers import PretrainedConfig, PreTrainedModel
46
+ from transformers.modeling_outputs import CausalLMOutputWithPast
47
+
48
+ from .configuration_mixformer_sequential import MixFormerSequentialConfig
49
+
50
+
51
+ try:
52
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
53
+ from flash_attn.ops.fused_dense import FusedDense
54
+ except:
55
+ FlashRotaryEmbedding = None
56
+ FusedDense = None
57
+
58
+
59
+ @dataclass
60
+ class InferenceParams:
61
+ """Inference parameters passed to model to efficiently calculate
62
+ and store context during inference.
63
+
64
+ Reference:
65
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
66
+
67
+ Args:
68
+ max_seqlen: Maximum sequence length.
69
+ max_batch_size: Maximum batch size.
70
+ seqlen_offset: Sequence length offset.
71
+ batch_size_offset: Batch size offset.
72
+ key_value_memory_dict: Key value memory dictionary.
73
+ lengths_per_sample: Lengths per sample.
74
+
75
+ """
76
+
77
+ max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
78
+
79
+ max_batch_size: int = field(metadata={"help": "Maximum batch size."})
80
+
81
+ seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
82
+
83
+ batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
84
+
85
+ key_value_memory_dict: Dict[str, Any] = field(
86
+ default_factory=dict, metadata={"help": "Key value memory dictionary."}
87
+ )
88
+
89
+ lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
90
+
91
+
92
+ class Embedding(nn.Module):
93
+ """Token embedding with dropout."""
94
+
95
+ def __init__(self, config: PretrainedConfig) -> None:
96
+ super().__init__()
97
+
98
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
99
+ self.drop = nn.Dropout(config.embd_pdrop)
100
+
101
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
102
+ input_shape = input_ids.size()
103
+ input_ids = input_ids.view(-1, input_shape[-1])
104
+
105
+ hidden_states = self.wte(input_ids)
106
+ hidden_states = self.drop(hidden_states)
107
+
108
+ return hidden_states
109
+
110
+
111
+ def _apply_rotary_emb(
112
+ x: torch.FloatTensor,
113
+ cos: torch.FloatTensor,
114
+ sin: torch.FloatTensor,
115
+ ) -> torch.FloatTensor:
116
+ _, seqlen, _, head_dim = x.shape
117
+ rotary_seqlen, rotary_dim = cos.shape
118
+ rotary_dim *= 2
119
+
120
+ assert rotary_dim <= head_dim
121
+ assert seqlen <= rotary_seqlen
122
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
123
+
124
+ x_rot = x[:, :, :, :rotary_dim]
125
+ x_pass = x[:, :, :, rotary_dim:]
126
+
127
+ x1, x2 = x_rot.chunk(2, dim=-1)
128
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
129
+ x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
130
+
131
+ x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
132
+
133
+ return torch.cat([x_rot, x_pass], axis=-1)
134
+
135
+
136
+ def _apply_rotary_emb_kv(
137
+ kv: torch.FloatTensor,
138
+ cos: torch.FloatTensor,
139
+ sin: torch.FloatTensor,
140
+ cos_k: Optional[torch.FloatTensor] = None,
141
+ sin_k: Optional[torch.FloatTensor] = None,
142
+ ) -> torch.FloatTensor:
143
+ _, seqlen, two, _, head_dim = kv.shape
144
+ assert two == 2
145
+
146
+ rotary_seqlen, rotary_dim = cos.shape
147
+ rotary_dim *= 2
148
+ assert rotary_dim <= head_dim
149
+ assert seqlen <= rotary_seqlen
150
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
151
+
152
+ k_rot = kv[:, :, 0, :, :rotary_dim]
153
+ k_pass = kv[:, :, 0, :, rotary_dim:]
154
+
155
+ k1, k2 = k_rot.chunk(2, dim=-1)
156
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
157
+ k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
158
+
159
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
160
+
161
+ return torch.cat(
162
+ [
163
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
164
+ kv[:, :, 1:2, :, :],
165
+ ],
166
+ axis=2,
167
+ )
168
+
169
+
170
+ def _apply_rotary_emb_qkv(
171
+ qkv: torch.FloatTensor,
172
+ cos: torch.FloatTensor,
173
+ sin: torch.FloatTensor,
174
+ cos_k: Optional[torch.FloatTensor] = None,
175
+ sin_k: Optional[torch.FloatTensor] = None,
176
+ ) -> torch.FloatTensor:
177
+ _, seqlen, three, _, head_dim = qkv.shape
178
+ assert three == 3
179
+
180
+ rotary_seqlen, rotary_dim = cos.shape
181
+ rotary_dim *= 2
182
+ assert rotary_dim <= head_dim
183
+ assert seqlen <= rotary_seqlen
184
+ assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
185
+
186
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
187
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
188
+
189
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
190
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
191
+
192
+ q1, q2 = q_rot.chunk(2, dim=-1)
193
+ k1, k2 = k_rot.chunk(2, dim=-1)
194
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
195
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
196
+
197
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
198
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
199
+
200
+ return torch.cat(
201
+ [
202
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
203
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
204
+ qkv[:, :, 2:3, :, :],
205
+ ],
206
+ axis=2,
207
+ )
208
+
209
+
210
+ class RotaryEmbedding(nn.Module):
211
+ """Rotary positional embedding (RoPE).
212
+
213
+ Reference:
214
+ RoFormer: Enhanced Transformer with Rotary Position Embedding.
215
+ https://arxiv.org/pdf/2104.09864.pdf.
216
+
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ dim: int,
222
+ base: int = 10000,
223
+ scale_base: Optional[float] = None,
224
+ pos_idx_in_fp32: bool = True,
225
+ device: Optional[str] = None,
226
+ **kwargs,
227
+ ) -> None:
228
+ super().__init__()
229
+
230
+ if scale_base is not None:
231
+ raise NotImplementedError
232
+
233
+ self.dim = dim
234
+ self.base = float(base)
235
+ self.scale_base = scale_base
236
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
237
+ self.device = device
238
+
239
+ # Generate and save the inverse frequency buffer (non-trainable)
240
+ inv_freq = self._compute_inv_freq(device)
241
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
242
+
243
+ # Generate and save the scale buffer (non-trainable)
244
+ scale = (
245
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
246
+ if scale_base is not None
247
+ else None
248
+ )
249
+ self.register_buffer("scale", scale, persistent=False)
250
+
251
+ self._seq_len_cached = 0
252
+ self._cos_cached = None
253
+ self._sin_cached = None
254
+ self._cos_k_cached = None
255
+ self._sin_k_cached = None
256
+
257
+ def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
258
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
259
+
260
+ def _update_cos_sin_cache(
261
+ self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
262
+ ) -> None:
263
+ # Reset the tables if sequence length has been chaned, if we are on a
264
+ # new device or if we are switching from inference mode to training
265
+ if (
266
+ seqlen > self._seq_len_cached
267
+ or self._cos_cached is None
268
+ or self._cos_cached.device != device
269
+ or self._cos_cached.dtype != dtype
270
+ or (self.training and self._cos_cached.is_inference())
271
+ ):
272
+ self._seq_len_cached = seqlen
273
+
274
+ # fp32 is preferred since the output of `torch.arange` can be quite large
275
+ # and bf16 would lose a lot of precision
276
+ if self.pos_idx_in_fp32:
277
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
278
+ if self.inv_freq.dtype != torch.float32:
279
+ inv_freq = self._compute_inv_freq(device=device)
280
+ else:
281
+ inv_freq = self.inv_freq
282
+ else:
283
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
284
+ inv_freq = self.inv_freq
285
+
286
+ # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
287
+ freqs = torch.outer(t, inv_freq)
288
+ if self.scale is None:
289
+ self._cos_cached = torch.cos(freqs).to(dtype)
290
+ self._sin_cached = torch.sin(freqs).to(dtype)
291
+ else:
292
+ power = (
293
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
294
+ ) / self.scale_base
295
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
296
+
297
+ # Force the scale multiplication to happen in fp32
298
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
299
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
300
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
301
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
302
+
303
+ def forward(
304
+ self,
305
+ qkv: torch.Tensor,
306
+ kv: Optional[torch.Tensor] = None,
307
+ seqlen_offset: int = 0,
308
+ max_seqlen: Optional[int] = None,
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ seqlen = qkv.shape[1]
311
+
312
+ if max_seqlen is not None:
313
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
314
+ else:
315
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
316
+
317
+ if kv is None:
318
+ return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
319
+ else:
320
+ q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
321
+ kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
322
+
323
+ return q, kv
324
+
325
+
326
+ class MLP(nn.Module):
327
+ """Multi-Layer Perceptron.
328
+
329
+ Reference:
330
+ Attention Is All You Need.
331
+ https://arxiv.org/pdf/1706.03762.pdf.
332
+
333
+ """
334
+
335
+ def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None) -> None:
336
+ super().__init__()
337
+
338
+ act_fn = config.activation_function if act_fn is None else act_fn
339
+ assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
340
+
341
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
342
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
343
+
344
+ self.fc1 = nn.Linear(config.n_embd, n_inner)
345
+ self.fc2 = nn.Linear(n_inner, config.n_embd)
346
+ self.act = ACT2FN[act_fn]
347
+
348
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
349
+ hidden_states = self.fc1(hidden_states)
350
+ hidden_states = self.act(hidden_states)
351
+ hidden_states = self.fc2(hidden_states)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class SelfAttention(nn.Module):
357
+ """Self-attention layer (compatible with PyTorch).
358
+
359
+ Reference:
360
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
361
+
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ causal: bool = True,
367
+ softmax_scale: Optional[float] = None,
368
+ attention_dropout: float = 0.0,
369
+ ) -> None:
370
+ super().__init__()
371
+
372
+ self.causal = causal
373
+ self.softmax_scale = softmax_scale
374
+ self.drop = nn.Dropout(attention_dropout)
375
+
376
+ def forward(
377
+ self,
378
+ qkv: torch.FloatTensor,
379
+ causal: bool = None,
380
+ attention_mask: Optional[torch.BoolTensor] = None,
381
+ **kwargs,
382
+ ) -> torch.FloatTensor:
383
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
384
+ q, k, v = qkv.unbind(dim=2)
385
+
386
+ causal = self.causal if causal is None else causal
387
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
388
+
389
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
390
+
391
+ if attention_mask is not None:
392
+ padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
393
+ padding_mask.masked_fill_(attention_mask, 0.0)
394
+
395
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
396
+
397
+ if causal:
398
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
399
+ scores = scores + causal_mask.to(dtype=scores.dtype)
400
+
401
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
402
+ attention = self.drop(attention)
403
+
404
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
405
+
406
+ return output
407
+
408
+
409
+ class CrossAttention(nn.Module):
410
+ """Cross-attention layer (compatible with PyTorch).
411
+
412
+ Reference:
413
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
414
+
415
+ """
416
+
417
+ def __init__(
418
+ self,
419
+ causal: bool = True,
420
+ softmax_scale: Optional[float] = None,
421
+ attention_dropout: float = 0.0,
422
+ ) -> None:
423
+ super().__init__()
424
+
425
+ self.causal = causal
426
+ self.softmax_scale = softmax_scale
427
+ self.drop = nn.Dropout(attention_dropout)
428
+
429
+ def forward(
430
+ self,
431
+ q: torch.FloatTensor,
432
+ kv: torch.FloatTensor,
433
+ causal: bool = None,
434
+ attention_mask: Optional[torch.BoolTensor] = None,
435
+ **kwargs,
436
+ ) -> torch.FloatTensor:
437
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
438
+ seqlen_k = kv.shape[1]
439
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
440
+
441
+ if kv.shape[3] != q.shape[2]:
442
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
443
+ k, v = kv.unbind(dim=2)
444
+
445
+ causal = self.causal if causal is None else causal
446
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
447
+
448
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
449
+
450
+ if attention_mask is not None:
451
+ padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
452
+ padding_mask.masked_fill_(attention_mask, 0.0)
453
+
454
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
455
+
456
+ if causal:
457
+ rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
458
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
459
+ causal_mask = cols > rows + seqlen_k - seqlen_q
460
+
461
+ scores = scores.masked_fill(causal_mask, -10000.0)
462
+
463
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
464
+ attention = self.drop(attention)
465
+
466
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
467
+
468
+ return output
469
+
470
+
471
+ def _find_mha_dims(
472
+ config: PretrainedConfig,
473
+ n_head: Optional[int] = None,
474
+ n_head_kv: Optional[int] = None,
475
+ head_dim: Optional[int] = None,
476
+ ) -> Tuple[int, int]:
477
+ assert all(
478
+ hasattr(config, attr) for attr in ["n_embd", "n_head"]
479
+ ), "`config` must have `n_embd` and `n_head` attributes."
480
+
481
+ if head_dim is None:
482
+ assert (
483
+ config.n_embd % config.n_head == 0
484
+ ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
485
+
486
+ if n_head is None and head_dim is None:
487
+ head_dim = config.n_embd // config.n_head
488
+ n_head = config.n_head
489
+ elif n_head is None or head_dim is None:
490
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
491
+
492
+ if n_head_kv is None:
493
+ n_head_kv = getattr(config, "n_head_kv", None) or n_head
494
+ assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
495
+
496
+ return n_head, n_head_kv, head_dim
497
+
498
+
499
+ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
500
+ num_heads, head_dim = kv.shape[-2:]
501
+
502
+ if layer_idx not in inference_params.key_value_memory_dict:
503
+ kv_cache = torch.empty(
504
+ inference_params.max_batch_size,
505
+ inference_params.max_seqlen,
506
+ 2,
507
+ num_heads,
508
+ head_dim,
509
+ dtype=kv.dtype,
510
+ device=kv.device,
511
+ )
512
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
513
+ else:
514
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
515
+
516
+ batch_start = inference_params.batch_size_offset
517
+ batch_end = batch_start + kv.shape[0]
518
+ assert batch_end <= kv_cache.shape[0]
519
+
520
+ sequence_start = inference_params.seqlen_offset
521
+ sequence_end = sequence_start + kv.shape[1]
522
+ assert sequence_end <= kv_cache.shape[1]
523
+
524
+ assert kv_cache is not None
525
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
526
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
527
+
528
+ return kv
529
+
530
+
531
+ class MHA(nn.Module):
532
+ """Multi-head attention layer."""
533
+
534
+ def __init__(
535
+ self,
536
+ config: PretrainedConfig,
537
+ dtype: Optional[torch.dtype] = None,
538
+ device: Optional[str] = None,
539
+ rotary_dim: Optional[int] = None,
540
+ rotary_emb_scale_base: Optional[float] = None,
541
+ n_head: Optional[int] = None,
542
+ n_head_kv: Optional[int] = None,
543
+ head_dim: Optional[int] = None,
544
+ bias: bool = True,
545
+ causal: bool = True,
546
+ softmax_scale: Optional[float] = None,
547
+ layer_idx: Optional[int] = None,
548
+ return_residual: bool = False,
549
+ checkpointing: bool = False,
550
+ ) -> None:
551
+ super().__init__()
552
+
553
+ # Rotary embedding
554
+ self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
555
+ if self.rotary_emb_dim > 0:
556
+ rotary_kwargs = {"device": device}
557
+ if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
558
+ rotary_kwargs["scale_base"] = rotary_emb_scale_base
559
+
560
+ rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
561
+ if rotary_cls is None:
562
+ rotary_cls = RotaryEmbedding
563
+ self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
564
+
565
+ # MLP
566
+ self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
567
+ op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
568
+ hidden_size = config.n_embd
569
+
570
+ linear_cls = FusedDense if config.fused_dense else nn.Linear
571
+ if linear_cls is None:
572
+ linear_cls = nn.Linear
573
+
574
+ self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
575
+ self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
576
+
577
+ # Attention
578
+ self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
579
+ self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
580
+
581
+ self.layer_idx = layer_idx
582
+ self.return_residual = return_residual
583
+ self.checkpointing = checkpointing
584
+
585
+ def _forward_self_attn(
586
+ self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
587
+ ) -> torch.FloatTensor:
588
+ qkv = self.Wqkv(x)
589
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
590
+
591
+ if self.rotary_emb_dim > 0:
592
+ qkv = self.rotary_emb(qkv)
593
+
594
+ if self.checkpointing:
595
+ return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
596
+
597
+ return self.inner_attn(qkv, attention_mask=attention_mask)
598
+
599
+ def _forward_cross_attn(
600
+ self,
601
+ x: torch.FloatTensor,
602
+ past_key_values: Optional[InferenceParams],
603
+ attention_mask: Optional[torch.BoolTensor],
604
+ ) -> torch.FloatTensor:
605
+ qkv = self.Wqkv(x)
606
+
607
+ q = qkv[..., : self.n_head * self.head_dim]
608
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
609
+
610
+ kv = qkv[..., self.n_head * self.head_dim :]
611
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
612
+
613
+ seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
614
+ causal = None if seqlen_offset == 0 else False
615
+ if self.rotary_emb_dim > 0:
616
+ q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
617
+
618
+ if past_key_values is not None:
619
+ kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
620
+
621
+ if self.checkpointing:
622
+ return torch.utils.checkpoint.checkpoint(
623
+ self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
624
+ )
625
+
626
+ return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
627
+
628
+ def forward(
629
+ self,
630
+ x: torch.FloatTensor,
631
+ past_key_values: Optional[InferenceParams] = None,
632
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
633
+ **kwargs,
634
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
635
+ if attention_mask is not None and torch.any(~attention_mask.bool()):
636
+ attention_mask = attention_mask.bool()
637
+ else:
638
+ attention_mask = None
639
+
640
+ # MHA
641
+ if self.n_head == self.n_head_kv:
642
+ if past_key_values is None:
643
+ # If `past_key_values` are not supplied, we run self-attention
644
+ attn_output = self._forward_self_attn(x, attention_mask)
645
+ else:
646
+ # If `past_key_values` are supplied, it means that we might have cached values and
647
+ # could take advantage of cross-attention
648
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
649
+ # MQA / GQA
650
+ else:
651
+ # Regardless of whether `past_key_values` are supplied, it always uses cross-attention
652
+ # because `q` and `kv` lengths might be different
653
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
654
+
655
+ output = rearrange(attn_output, "... h d -> ... (h d)")
656
+ output = self.out_proj(output)
657
+
658
+ return output if not self.return_residual else (output, x)
659
+
660
+
661
+ class ParallelBlock(nn.Module):
662
+ """Parallel block.
663
+
664
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
665
+
666
+ """
667
+
668
+ def __init__(
669
+ self,
670
+ config: PretrainedConfig,
671
+ block_idx: Optional[int] = None,
672
+ ) -> None:
673
+ super().__init__()
674
+
675
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
676
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
677
+ self.block_idx = block_idx
678
+
679
+ self.mixer = MHA(config, layer_idx=block_idx)
680
+ self.mlp = MLP(config)
681
+
682
+ def forward(
683
+ self,
684
+ hidden_states: torch.FloatTensor,
685
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
686
+ attention_mask: Optional[torch.BoolTensor] = None,
687
+ **kwargs,
688
+ ) -> torch.FloatTensor:
689
+ residual = hidden_states
690
+ hidden_states = self.ln(hidden_states)
691
+
692
+ attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
693
+ if isinstance(attn_outputs, tuple):
694
+ attn_outputs = attn_outputs[0]
695
+
696
+ attn_outputs = self.resid_dropout(attn_outputs)
697
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
698
+
699
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
700
+
701
+ return hidden_states
702
+
703
+
704
+ class CausalLMHead(nn.Module):
705
+ """Causal Language Modeling head.
706
+
707
+ Reference:
708
+ Improving Language Understanding by Generative Pre-Training.
709
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
710
+
711
+ """
712
+
713
+ def __init__(self, config: PretrainedConfig) -> None:
714
+ super().__init__()
715
+
716
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
717
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
718
+
719
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
720
+ hidden_states = self.ln(hidden_states)
721
+ logits = self.linear(hidden_states).to(torch.float32)
722
+
723
+ return logits
724
+
725
+
726
+ class CausalLMLoss(nn.Module):
727
+ """Causal Language Modeling loss.
728
+
729
+ Reference:
730
+ Improving Language Understanding by Generative Pre-Training.
731
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
732
+
733
+ """
734
+
735
+ def __init__(self, shift_labels: bool = True) -> None:
736
+ super().__init__()
737
+
738
+ self.shift_labels = shift_labels
739
+ self.loss_fct = nn.CrossEntropyLoss()
740
+
741
+ def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor:
742
+ if self.shift_labels:
743
+ logits = logits[..., :-1, :].contiguous()
744
+ labels = labels[..., 1:].contiguous()
745
+
746
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
747
+
748
+ return loss
749
+
750
+
751
+ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
752
+ """MixFormer (sequential for DeepSpeed) pre-trained model."""
753
+
754
+ config_class = MixFormerSequentialConfig
755
+ base_model_prefix = "transformer"
756
+ supports_gradient_checkpointing = True
757
+
758
+ def __init__(self, *inputs, **kwargs) -> None:
759
+ super().__init__(*inputs, **kwargs)
760
+
761
+ def _init_weights(self, module: nn.Module) -> None:
762
+ if isinstance(module, (nn.Linear,)):
763
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
764
+ if module.bias is not None:
765
+ module.bias.data.zero_()
766
+ elif isinstance(module, nn.Embedding):
767
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
768
+ if module.padding_idx is not None:
769
+ module.weight.data[module.padding_idx].zero_()
770
+ elif isinstance(module, nn.LayerNorm):
771
+ if module.bias is not None:
772
+ module.bias.data.zero_()
773
+ module.weight.data.fill_(1.0)
774
+
775
+ def prepare_inputs_for_generation(
776
+ self,
777
+ input_ids: torch.LongTensor,
778
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
779
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
780
+ **kwargs,
781
+ ) -> Dict[str, Any]:
782
+ if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
783
+ past_key_values = InferenceParams(
784
+ max_seqlen=self.config.n_positions,
785
+ max_batch_size=input_ids.shape[0],
786
+ seqlen_offset=0,
787
+ batch_size_offset=0,
788
+ key_value_memory_dict={},
789
+ lengths_per_sample=None,
790
+ )
791
+ else:
792
+ # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
793
+ past_key_values.seqlen_offset = len(input_ids[0]) - 1
794
+ input_ids = input_ids[:, -1].unsqueeze(-1)
795
+
796
+ return {
797
+ "input_ids": input_ids,
798
+ "past_key_values": past_key_values,
799
+ "attention_mask": attention_mask,
800
+ }
801
+
802
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
803
+ if isinstance(module, MixFormerSequentialPreTrainedModel):
804
+ module.gradient_checkpointing = value
805
+
806
+
807
+ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
808
+ """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
809
+
810
+ _keys_to_ignore_on_load_missing = [""]
811
+ _keys_to_ignore_on_load_unexpected = [r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
812
+ _no_split_modules = ["ParallelBlock"]
813
+
814
+ def __init__(self, config: MixFormerSequentialConfig) -> None:
815
+ super().__init__(config)
816
+
817
+ modules = [Embedding(config)]
818
+ modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
819
+ modules.append(CausalLMHead(config))
820
+
821
+ self.layers = nn.Sequential(*modules)
822
+ self.loss = CausalLMLoss()
823
+
824
+ self.post_init()
825
+
826
+ def get_input_embeddings(self) -> nn.Embedding:
827
+ return self.layers[0].wte
828
+
829
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
830
+ self.layers[0].wte = new_embeddings
831
+
832
+ def get_output_embeddings(self) -> nn.Linear:
833
+ return self.layers[-1].linear
834
+
835
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
836
+ self.layers[-1].linear = new_embeddings
837
+
838
+ def forward(
839
+ self,
840
+ input_ids: torch.LongTensor,
841
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
842
+ attention_mask: Optional[torch.BoolTensor] = None,
843
+ labels: Optional[torch.LongTensor] = None,
844
+ **kwargs,
845
+ ) -> CausalLMOutputWithPast:
846
+ hidden_layer = self.layers[0](input_ids)
847
+ for module in self.layers[1:-1]:
848
+ hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
849
+ lm_logits = self.layers[-1](hidden_layer)
850
+
851
+ loss = None
852
+ if labels is not None:
853
+ loss = self.loss(lm_logits, labels)
854
+
855
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d121d287c708fc6d08043ed171921e4b9fb68d00f452c1d23ea1c55292bd1d5c
size 5673168010
training_args.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4acc51c52c33dccf606b129498bc828aef164175ad72fb3176ccedff193d49b0
size 4536