File size: 1,091 Bytes
a446b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from comet.src.models.gpt import (LMModel, DEFAULT_CONFIG, load_openai_pretrained_model)
import torch.nn as nn


def make_model(opt, n_vocab, n_ctx, n_special, load=True,
               return_acts=True, return_probs=False,
               clf_token="<CLASS>", answer_size=None):
    """Build the model selected by ``opt.exp`` and optionally load weights.

    Args:
        opt: Options object. ``opt.exp`` selects the model type
            (``"generation"`` or ``"classification"``) and ``opt.net`` is
            forwarded to the model constructor.
        n_vocab: Forwarded to the model constructor (presumably the
            vocabulary size -- confirm against ``LMModel``).
        n_ctx: Forwarded to the model constructor and to
            ``load_openai_pretrained_model`` (presumably the context
            length -- confirm against ``LMModel``).
        n_special: Forwarded to ``load_openai_pretrained_model``.
        load: When true, ``load_openai_pretrained_model`` is called on
            ``model.transformer``.
        return_acts: Forwarded to ``LMModel`` (generation only).
        return_probs: Forwarded to ``LMModel`` (generation only).
        clf_token: Forwarded to ``ClfModel`` (classification only).
        answer_size: Forwarded to ``ClfModel`` (classification only).

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``opt.exp`` is not a supported experiment type.
    """
    print(n_ctx)
    if opt.exp == "generation":
        model = LMModel(
            opt.net, n_vocab, n_ctx, return_acts=return_acts,
            return_probs=return_probs)
    elif opt.exp == "classification":
        # NOTE(review): ClfModel is not imported in the visible part of this
        # file -- confirm it is in scope before relying on this branch.
        model = ClfModel(
            opt.net, n_vocab, n_ctx, clf_token, answer_size)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``model`` further down.
        raise ValueError(
            "Unsupported experiment type {!r}; expected 'generation' or "
            "'classification'".format(opt.exp))
    if load:
        print("LOADING PRETRAINED TRANSFORMER")
        load_openai_pretrained_model(
            model.transformer, n_ctx=n_ctx, n_special=n_special)
    return model


def multi_gpu(model, devices):
    """Wrap ``model`` in ``nn.DataParallel`` with ``devices`` as ``device_ids``."""
    parallel_model = nn.DataParallel(model, device_ids=devices)
    return parallel_model


def load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating a ``"module."`` key prefix.

    The state dict is first loaded as-is. If that raises ``RuntimeError`` and
    at least one key starts with ``"module."``, the prefix is removed from
    those keys (other keys are left untouched) and the load is retried.

    Args:
        model: Object exposing ``load_state_dict(mapping)``.
        state_dict: Mapping from parameter names to values.

    Raises:
        RuntimeError: If the initial load fails and no key carries the
            ``"module."`` prefix (the original error is re-raised), or if the
            retry with stripped keys fails as well.
    """
    prefix = "module."
    try:
        model.load_state_dict(state_dict)
        return
    except RuntimeError:
        # Without any prefixed key the failure has another cause; surface the
        # original error rather than retrying with mangled keys.
        if not any(key.startswith(prefix) for key in state_dict):
            raise
    # Strip the prefix only where it is actually present so that unprefixed
    # keys are not truncated.
    stripped = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(stripped)