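"""Gradio chat demo for Linly-ChatFlow.

Loads the Linly-AI/Linly-ChatFlow checkpoint via Hugging Face Transformers
and serves a simple text-in / text-out web UI with top-k and temperature
sliders. Commented-out code paths for loading a locally converted checkpoint
are kept from the original for reference.
"""
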
import gradio as gr
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# from transformers.generation.utils import GenerationConfig

# Tokenizer, load_model, the models.llama classes and hf_hub_download are
# only referenced by the commented-out local-checkpoint loading path below.
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration
from huggingface_hub import hf_hub_download

import os
# Synchronous CUDA kernel launches make CUDA errors easier to trace.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

args = None
lm_generation = None

def init_args():
    global args
    # No command-line flags are defined; the parser only provides an
    # argparse.Namespace that is filled in manually below.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args = parser.parse_args()
    # The 13B ChatFlow checkpoint is the default; the commented lines switch
    # to the 7B checkpoint or to locally converted model files.
    args.load_model_path = 'Linly-AI/ChatFlow-13B'
    # args.load_model_path = 'Linly-AI/ChatFlow-7B'
    # args.load_model_path = './model_file/chatllama_7b.bin'
    # args.config_path = './config/llama_7b.json'
    # args.load_model_path = './model_file/chatflow_13b.bin'
    args.config_path = './config/llama_13b_config.json'
    args.spm_model_path = './model_file/tokenizer.model'
    # Fixed generation settings; top_k and temperature are set per request in chat().
    args.batch_size = 1
    args.seq_length = 1024
    args.world_size = 1
    args.use_int8 = True
    args.top_p = 0
    args.repetition_penalty_range = 1024
    args.repetition_penalty_slope = 0
    args.repetition_penalty = 1.15

    args = load_hyperparam(args)

    # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    # use_fast=False keeps the SentencePiece-backed tokenizer, whose sp_model
    # attribute exposes the vocabulary size used below.
    args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Linly-ChatFlow", use_fast=False, trust_remote_code=True)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()


def init_model():
    global lm_generation
    # Original local-checkpoint loading path, kept for reference:
    # torch.set_default_tensor_type(torch.HalfTensor)
    # model = LLaMa(args)
    # torch.set_default_tensor_type(torch.FloatTensor)
    # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
    # model = load_model(model, args.load_model_path)
    # model.eval()
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.to(device)

    # Load the published checkpoint from the Hugging Face Hub in fp16,
    # letting device_map="auto" place layers across available devices.
    model = AutoModelForCausalLM.from_pretrained("Linly-AI/Linly-ChatFlow", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)

    # Report the peak GPU memory allocated so far, in GiB.
    print(torch.cuda.max_memory_allocated() / 1024 ** 3)
    lm_generation = LmGeneration(model, args.tokenizer)


def chat(prompt, top_k, temperature):
    # top_k and temperature come from the Gradio sliders; the remaining
    # generation settings were fixed in init_args().
    args.top_k = int(top_k)
    args.temperature = temperature
    response = lm_generation.generate(args, [prompt])
    print('log:', response[0])
    return response[0]


if __name__ == '__main__':
    init_args()
    init_model()
    # Inputs match chat(): a text prompt, a top-k slider (1-60), and a
    # temperature slider (0.1-2.0).
    demo = gr.Interface(
        fn=chat,
        inputs=["text", gr.Slider(1, 60, value=10, step=1), gr.Slider(0.1, 2.0, value=1.0, step=0.1)],
        outputs="text",
    )
    demo.launch()