#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @author: Kun
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.model.predictor.aquila import aquila_generate
from flagai.data.tokenizer import Tokenizer
import bminf
# Generation and model-loading configuration.
max_token: int = 128        # maximum number of tokens to generate
temperature: float = 0.75   # sampling temperature
top_p: float = 0.9          # nucleus-sampling threshold

state_dict = "./checkpoints_in"   # directory holding the downloaded checkpoint
model_name = 'aquilachat-7b'

def load_model():
    """Load the AquilaChat-7B model and tokenizer, then wrap the model with
    BMInf so it can run within a limited GPU memory budget."""
    loader = AutoLoader(
        "lm",
        model_dir=state_dict,
        model_name=model_name,
        use_cache=True,
        fp16=True)
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()
    cache_dir = os.path.join(state_dict, model_name)  # checkpoint location (currently unused)
    model.eval()

    with torch.cuda.device(0):
        # Offload the model through BMInf with a 2 GiB GPU memory limit (2 << 30 bytes).
        model = bminf.wrapper(model, quantization=False, memory_limit=2 << 30)

    return tokenizer, model
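

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): load the model and run a
    # single generation with the imported aquila_generate helper. The encode_plus and
    # aquila_generate calls below follow the FlagAI AquilaChat examples and are
    # assumptions; check them against the FlagAI version you have installed.
    tokenizer, model = load_model()

    prompt = "What is the capital of France?"
    # Encode the prompt and drop the special begin/end tokens, as in the FlagAI examples.
    tokens = tokenizer.encode_plus(prompt, None, max_length=None)["input_ids"][1:-1]

    with torch.no_grad():
        out = aquila_generate(
            tokenizer,
            model,
            [prompt],
            max_gen_len=max_token,
            top_p=top_p,
            prompts_tokens=[tokens],
        )
    print(out)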