LLM-foundry update June 16, 2023 22:55:57
#51
by
daking
- opened
- custom_embedding.py +1 -2
- modeling_mpt.py +11 -1
custom_embedding.py
CHANGED
@@ -3,10 +3,9 @@ import torch.nn as nn
|
|
3 |
import torch.nn.functional as F
|
4 |
from torch import Tensor
|
5 |
|
6 |
-
|
7 |
class SharedEmbedding(nn.Embedding):
|
8 |
|
9 |
-
def forward(self, input: Tensor, unembed: bool
|
10 |
if unembed:
|
11 |
return F.linear(input, self.weight)
|
12 |
return super().forward(input)
|
|
|
3 |
import torch.nn.functional as F
|
4 |
from torch import Tensor
|
5 |
|
|
|
6 |
class SharedEmbedding(nn.Embedding):
|
7 |
|
8 |
+
def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
|
9 |
if unembed:
|
10 |
return F.linear(input, self.weight)
|
11 |
return super().forward(input)
|
modeling_mpt.py
CHANGED
@@ -40,6 +40,11 @@ class MPTModel(MPTPreTrainedModel):
|
|
40 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
41 |
self.alibi = config.attn_config['alibi']
|
42 |
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
|
|
|
|
|
|
|
|
|
|
43 |
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
|
44 |
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
|
45 |
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
@@ -47,7 +52,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
47 |
self.embedding_fraction = config.embedding_fraction
|
48 |
self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
|
49 |
if not self.alibi:
|
50 |
-
self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
51 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
52 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
53 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
@@ -221,6 +226,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
221 |
if not config.tie_word_embeddings:
|
222 |
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
223 |
self.transformer = MPTModel(config)
|
|
|
|
|
|
|
|
|
|
|
224 |
self.logit_scale = None
|
225 |
if config.logit_scale is not None:
|
226 |
logit_scale = config.logit_scale
|
|
|
40 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
41 |
self.alibi = config.attn_config['alibi']
|
42 |
self.alibi_bias_max = config.attn_config['alibi_bias_max']
|
43 |
+
if config.init_device == 'mixed':
|
44 |
+
if dist.get_local_rank() == 0:
|
45 |
+
config.init_device = 'cpu'
|
46 |
+
else:
|
47 |
+
config.init_device = 'meta'
|
48 |
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
|
49 |
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
|
50 |
raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
|
|
|
52 |
self.embedding_fraction = config.embedding_fraction
|
53 |
self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
|
54 |
if not self.alibi:
|
55 |
+
self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
56 |
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
57 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
58 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
|
|
226 |
if not config.tie_word_embeddings:
|
227 |
raise ValueError('MPTForCausalLM only supports tied word embeddings')
|
228 |
self.transformer = MPTModel(config)
|
229 |
+
for child in self.transformer.children():
|
230 |
+
if isinstance(child, torch.nn.ModuleList):
|
231 |
+
continue
|
232 |
+
if isinstance(child, torch.nn.Module):
|
233 |
+
child._fsdp_wrap = True
|
234 |
self.logit_scale = None
|
235 |
if config.logit_scale is not None:
|
236 |
logit_scale = config.logit_scale
|