# nanogpt/infer.py
import os
import pickle
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT
device = 'cpu'
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # sampling temperature: 1.0 = unchanged logits, < 1.0 = less random, > 1.0 = more random
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
ctx = nullcontext()  # no autocast context needed for CPU inference
ckpt_path = 'ckpt.pt'
checkpoint = torch.load(ckpt_path, map_location='cpu')
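# the checkpoint is assumed to follow nanoGPT's train.py format: 'model_args' holds
# the GPTConfig kwargs and 'model' holds the state dict (see the accesses below)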
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
# strip the '_orig_mod.' prefix that torch.compile prepends to saved weight names
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
# model = torch.compile(model) # requires PyTorch 2.0 (optional)
print("model loaded !!")
meta_path = 'meta.pkl'
print(f"Loading meta from {meta_path}...")
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
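# illustrative round trip, assuming a character-level meta.pkl:
# encode('hi') -> [stoi['h'], stoi['i']] and decode(encode('hi')) == 'hi';
# characters missing from stoi will raise KeyError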
def run(prompt):
    """Generate up to max_new_tokens of text continuing the given prompt."""
    input_ids = encode(prompt)
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...]  # add batch dim
    with torch.no_grad():
        with ctx:
            y = model.generate(input_ids, max_new_tokens, temperature=temperature, top_k=top_k)
    response = decode(y[0].tolist())
    return response
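
# Hypothetical usage sketch (an assumption, not part of the original script):
# run this file directly for a quick smoke test; the prompt is illustrative, and
# "\n" mirrors nanoGPT's default start string for character-level models.
if __name__ == "__main__":
    print(run("\n"))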