# Last commit by Minh Q. Le: "Pushed COSMIC code" (commit a446b0b)
# (source page metadata: raw | history blame | 1.09 kB)
from comet.src.models.gpt import (LMModel, DEFAULT_CONFIG, load_openai_pretrained_model)
import torch.nn as nn
def make_model(opt, n_vocab, n_ctx, n_special, load=True,
               return_acts=True, return_probs=False,
               clf_token="<CLASS>", answer_size=None):
    """Build the transformer model selected by ``opt.exp``.

    Args:
        opt: Options object. ``opt.exp`` selects the model type
            (``"generation"`` or ``"classification"``); ``opt.net`` is
            passed through as the first constructor argument.
        n_vocab: Vocabulary size, forwarded to the model constructor.
        n_ctx: Context length, forwarded to the model constructor and to
            ``load_openai_pretrained_model``.
        n_special: Forwarded to ``load_openai_pretrained_model`` when
            ``load`` is true.
        load: If true, load pretrained weights into ``model.transformer``.
        return_acts: Forwarded to ``LMModel`` (generation only).
        return_probs: Forwarded to ``LMModel`` (generation only).
        clf_token: Forwarded to ``ClfModel`` (classification only).
        answer_size: Forwarded to ``ClfModel`` (classification only).

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``opt.exp`` is not a supported experiment type.
    """
    print(n_ctx)
    if opt.exp == "generation":
        model = LMModel(
            opt.net, n_vocab, n_ctx, return_acts=return_acts,
            return_probs=return_probs)
    elif opt.exp == "classification":
        # ClfModel was never imported at module level, so this branch always
        # failed with NameError. Import it lazily from the module that
        # provides LMModel so a missing class surfaces as an explicit
        # ImportError. NOTE(review): confirm comet.src.models.gpt defines
        # ClfModel.
        from comet.src.models.gpt import ClfModel
        model = ClfModel(
            opt.net, n_vocab, n_ctx, clf_token, answer_size)
    else:
        # Without this branch an unknown value fell through and failed later
        # with an UnboundLocalError on `model`.
        raise ValueError(
            "Unsupported opt.exp {!r}; expected 'generation' or "
            "'classification'".format(opt.exp))
    if load:
        print("LOADING PRETRAINED TRANSFORMER")
        load_openai_pretrained_model(
            model.transformer, n_ctx=n_ctx, n_special=n_special)
    return model
def multi_gpu(model, devices):
    """Wrap *model* in ``nn.DataParallel`` over the device ids *devices*.

    Args:
        model: The module to replicate.
        devices: Device ids forwarded as ``device_ids``.

    Returns:
        The ``nn.DataParallel`` wrapper around *model*.
    """
    wrapped = nn.DataParallel(module=model, device_ids=devices)
    return wrapped
def load_state_dict(model, state_dict):
    """Load *state_dict* into *model*, tolerating a ``"module."`` key prefix.

    A direct ``model.load_state_dict`` is tried first. If it raises
    ``RuntimeError`` and some keys start with ``"module."`` (the prefix that
    ``nn.DataParallel``, as used by ``multi_gpu``, adds to parameter names),
    the prefix is stripped from those keys and the load is retried.

    Args:
        model: Module to load the weights into.
        state_dict: Mapping from parameter/buffer names to tensors.

    Raises:
        RuntimeError: If the direct load fails and there is no prefix to
            strip (the original error is re-raised), or if the retry with
            stripped keys also fails.
    """
    prefix = "module."
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        if not any(key.startswith(prefix) for key in state_dict):
            # Nothing to strip; retrying would only mask the real error.
            raise
        # Strip the prefix only where it is present; slicing every key
        # unconditionally corrupts names such as "bias" into "".
        # NOTE(review): any `_metadata` attribute on the original state dict
        # is not carried over to this plain dict.
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(stripped)