patrickvonplaten commited on
Commit
7fd184b
·
1 Parent(s): 7e14fa6
model/dict.txt ADDED
The diff for this file is too large to render. See raw diff
 
model/gpt2-merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model/gpt2-vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
model/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model/restored.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04468adeae5a767f6bd9eebdcfeabb8c3b097029c6766e5ca839e0ff3743a476
3
+ size 250622585
model/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}}
model/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"errors": "replace", "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "tokenizer_class": "GPT2Tokenizer"}
model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
run.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #!/usr/bin/env bash
2
+ CUDA_VISIBLE_DEVICES="0" torchrun run_model.py --pipeline-model-parallel-size 1 --tensor-model-parallel-size 1
run_model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ #!/usr/bin/env python3
3
+ import os
4
+ from transformers import AutoTokenizer, GPT2Tokenizer
5
+ from megatron.initialize import initialize_megatron
6
+ from metaseq import checkpoint_utils
7
+ import torch
8
+
9
+ path = "./model"
10
+
11
+ # just need to initialize args with something,
12
+ # => doesn't need to correspond to the "correct" architecture for this checkpoint
13
+ initialize_megatron(args_defaults={
14
+ "micro_batch_size": 1,
15
+ "num_layers": 12,
16
+ "hidden_size": 768,
17
+ "num_attention_heads": 12,
18
+ "max_position_embeddings": 2048,
19
+ "encoder_seq_length": 2048
20
+ })
21
+
22
+ vocab_file = os.path.join(path, "gpt2-vocab.json")
23
+ merges_file = os.path.join(path, "gpt2-merges.txt")
24
+
25
+ tokenizer = GPT2Tokenizer(vocab_file, merges_file)
26
+ tokenizer.save_pretrained(path)
27
+
28
+ checkpoint = checkpoint_utils.load_model_ensemble_and_task(
29
+ [os.path.join(path, "restored.pt")],
30
+ arg_overrides={
31
+ "vocab_filename": vocab_file,
32
+ "merges_filename": merges_file,
33
+ }
34
+ )
35
+
36
+ model = checkpoint[0][0].eval()
37
+ model = model.cuda().half()
38
+
39
+
40
+ # forward passes
41
+ def single_batch_forward_logits(prompts):
42
+ input_ids = tokenizer(prompts, return_tensors="pt").input_ids
43
+ input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
44
+ input_ids = input_ids.cuda()
45
+ with torch.no_grad():
46
+ logits = model(input_ids)[0]
47
+ return logits
48
+
49
+ prompts = [
50
+ "Today is a beautiful day and I want to",
51
+ "In the city of",
52
+ "Paris is the capital of France and",
53
+ "Computers and mobile phones have taken",
54
+ ]
55
+
56
+ print("Next word generation")
57
+ for prompt in prompts:
58
+ print("-------------")
59
+ print(f"Prompt: {prompt}...\n")
60
+ logits = single_batch_forward_logits(prompt)
61
+ pred_next_token = torch.argmax(logits[0, -1], -1)
62
+ next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
63
+ next_token = next_token[0].replace("Ġ", "")
64
+ print(f"Next word: {next_token}")
65
+ print("-------------")