#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @author: Kun

import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.model.predictor.aquila import aquila_generate
from flagai.data.tokenizer import Tokenizer
import bminf



max_token: int = 128       # maximum number of new tokens to generate
temperature: float = 0.75  # sampling temperature
top_p: float = 0.9         # nucleus-sampling cutoff

state_dict = "./checkpoints_in"  # directory containing the downloaded checkpoint
model_name = 'aquilachat-7b'

def load_model():
    """Load the AquilaChat-7B checkpoint in fp16 and wrap it with BMInf
    so inference fits within a limited GPU memory budget."""
    loader = AutoLoader(
        "lm",
        model_dir=state_dict,
        model_name=model_name,
        use_cache=True,
        fp16=True)
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()

    model.eval()

    # BMInf swaps parameters between CPU and GPU on demand, capping resident
    # GPU memory at 2 GiB (2 << 30 bytes); quantization stays disabled.
    with torch.cuda.device(0):
        model = bminf.wrapper(model, quantization=False, memory_limit=2 << 30)

    return tokenizer, model
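

# --- Usage sketch: an assumption, not part of the original module ---
# The Predictor / aquila_generate imports above suggest this module is meant
# to be driven the way the FlagAI Aquila-chat examples are. The
# aquila_generate call below follows those examples; its exact signature may
# differ across FlagAI versions, so treat this as an illustration rather
# than the project's confirmed API.
if __name__ == "__main__":
    tokenizer, model = load_model()
    prompt = "What is the capital of France?"  # hypothetical prompt
    with torch.no_grad():
        out = aquila_generate(
            tokenizer,
            model,
            [prompt],               # aquila_generate takes a list of prompts
            max_gen_len=max_token,  # reuse the module-level hyperparameters
            temperature=temperature,
            top_p=top_p)
    print(out)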