File size: 2,382 Bytes
402c662
 
 
 
 
 
 
 
05abaae
402c662
492f975
 
402c662
 
 
 
 
 
 
 
c611b16
 
05abaae
c611b16
52b97b4
c611b16
402c662
 
748b107
402c662
c611b16
402c662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84a97f6
1ccd59a
 
402c662
 
 
 
 
c562a5d
402c662
 
 
 
 
 
 
c57a7ee
402c662
 
 
 
 
 
 
 
0580eba
402c662
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch

import gradio as gr
import argparse
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration
from huggingface_hub import hf_hub_download

import os
# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that triggered them (a debugging aid; it reduces GPU throughput).
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

# Module-level state, filled in by init_args() and init_model() before the UI starts.
args = lm_generation = None

def init_args():
    """Create the global ``args`` namespace used by model loading and generation.

    No command-line options are declared; ``parse_args`` only supplies an
    ``argparse.Namespace``. Every setting is fixed in code below. The namespace
    is passed through ``load_hyperparam``, which may replace it with the object
    it returns. It then receives a ``Tokenizer`` built from
    ``spm_model_path`` together with that tokenizer's vocabulary size.
    """
    global args

    # Alternatives used in earlier revisions of this script:
    #   load_model_path: 'Linly-AI/ChatFlow-7B', './model_file/chatllama_7b.bin',
    #                    './model_file/chatflow_13b.bin'
    #   config_path:     './config/llama_7b.json'
    fixed_settings = {
        'load_model_path': 'Linly-AI/ChatFlow-13B',
        'config_path': './config/llama_13b_config.json',
        'spm_model_path': './model_file/tokenizer.model',
        'batch_size': 1,
        'seq_length': 1024,
        'world_size': 1,
        'use_int8': True,
        'top_p': 0,
        'repetition_penalty_range': 1024,
        'repetition_penalty_slope': 0,
        'repetition_penalty': 1.15,
    }

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args = parser.parse_args()
    for setting_name, setting_value in fixed_settings.items():
        setattr(args, setting_name, setting_value)

    args = load_hyperparam(args)

    tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.tokenizer = tokenizer
    args.vocab_size = tokenizer.sp_model.vocab_size()


def init_model():
    """Build the LLaMA model, load ChatFlow weights from the Hub, and set ``lm_generation``.

    Reads the module-level ``args`` prepared by ``init_args``.

    Side effects:
      * temporarily switches torch's default tensor type to ``HalfTensor`` while
        the model is constructed, and always restores ``FloatTensor`` afterwards;
      * replaces ``args.load_model_path`` (a Hub repo id on entry) with the local
        path returned by ``hf_hub_download``;
      * moves the model to CUDA when available, otherwise CPU;
      * assigns the global ``lm_generation``.

    The checkpoint file name defaults to ``chatflow_13b.bin``. Set
    ``args.model_filename`` to use a different checkpoint, e.g. when switching
    repo and config to another model size.
    """
    global lm_generation

    # Allocate the model's parameters in fp16. Restore the fp32 default even if
    # construction fails, so later tensor creation is not silently half precision.
    torch.set_default_tensor_type(torch.HalfTensor)
    try:
        model = LLaMa(args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)

    # The checkpoint must match both the repo in args.load_model_path and the
    # architecture described by args.config_path. Previously 'chatflow_7b.bin'
    # was requested from the 13B repo while the 13B config was in use.
    filename = getattr(args, 'model_filename', None) or 'chatflow_13b.bin'
    args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename=filename)
    model = load_model(model, args.load_model_path)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if device.type == "cuda":
        # Peak CUDA memory allocated so far, in GiB; meaningless on CPU.
        print(torch.cuda.max_memory_allocated() / 1024 ** 3)
    lm_generation = LmGeneration(model, args.tokenizer)


def chat(prompt, top_k, temperature):
    """Gradio callback: generate a single reply for ``prompt``.

    The UI's sampling controls are written onto the shared module-level
    ``args``, so they persist into later calls. ``top_k`` is cast to ``int``
    because slider values may arrive as floats. The generator is called with
    a one-element batch. Its first result is echoed to stdout with a ``log:``
    prefix and returned as the output text.
    """
    args.top_k = int(top_k)
    args.temperature = temperature

    outputs = lm_generation.generate(args, [prompt])
    reply = outputs[0]
    print('log:', reply)
    return reply


if __name__ == '__main__':
    # Configure and load the model once, then expose `chat` through a simple
    # text-in / text-out Gradio interface.
    init_args()
    init_model()

    # Positional inputs map onto chat(prompt, top_k, temperature).
    top_k_slider = gr.Slider(1, 60, value=10, step=1)
    temperature_slider = gr.Slider(0.1, 2.0, value=1.0, step=0.1)

    demo = gr.Interface(
        fn=chat,
        inputs=["text", top_k_slider, temperature_slider],
        outputs="text",
    )
    demo.launch()