wmpscc committed
Commit: 6c88d0a
Parent: 657102d

Update app.py

Files changed (1): app.py (+16, -9)
app.py CHANGED
@@ -2,6 +2,10 @@ import torch
 
 import gradio as gr
 import argparse
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+# from transformers.generation.utils import GenerationConfig
+
 from utils import load_hyperparam, load_model
 from models.tokenize import Tokenizer
 from models.llama import *
@@ -36,22 +40,25 @@ def init_args():
 
     args = load_hyperparam(args)
 
-    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
+    # args.tokenizer = Tokenizer(model_path=args.spm_model_path)
+    args.tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Linly-ChatFlow", use_fast=False, trust_remote_code=True)
     args.vocab_size = args.tokenizer.sp_model.vocab_size()
 
 
 def init_model():
     global lm_generation
-    torch.set_default_tensor_type(torch.HalfTensor)
-    model = LLaMa(args)
-    torch.set_default_tensor_type(torch.FloatTensor)
+    # torch.set_default_tensor_type(torch.HalfTensor)
+    # model = LLaMa(args)
+    # torch.set_default_tensor_type(torch.FloatTensor)
+    # # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
     # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-    args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
-    model = load_model(model, args.load_model_path)
-    model.eval()
+    # model = load_model(model, args.load_model_path)
+    # model.eval()
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # model.to(device)
+    model = AutoModelForCausalLM.from_pretrained("Linly-AI/Linly-ChatFlow", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
+
     print(torch.cuda.max_memory_allocated() / 1024 ** 3)
     lm_generation = LmGeneration(model, args.tokenizer)
 
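Read as a whole, the commit replaces the repo's custom loading path (Tokenizer, LLaMa, load_model, plus manual half-precision and device handling) with the Hugging Face transformers auto classes. Below is a minimal sketch of the load-and-generate flow the new code implies, assuming the Linly-AI/Linly-ChatFlow repo exposes the standard transformers interfaces; the prompt and generation settings are illustrative and not taken from app.py.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Slow (sentencepiece-backed) tokenizer, matching the commit's use_fast=False;
# the diff later reads the vocab size via tokenizer.sp_model.vocab_size(),
# which only the slow tokenizer exposes.
tokenizer = AutoTokenizer.from_pretrained(
    "Linly-AI/Linly-ChatFlow", use_fast=False, trust_remote_code=True
)

# Half-precision weights placed automatically across available devices; this
# replaces the old set_default_tensor_type / load_model / model.to(device) dance.
model = AutoModelForCausalLM.from_pretrained(
    "Linly-AI/Linly-ChatFlow",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# Illustrative prompt and generation settings (not from app.py).
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Two details worth noting in the diff itself: the added import torch duplicates the import torch already shown as the hunk-header context at the top of the file (harmless, since Python caches imports), and args.vocab_size = args.tokenizer.sp_model.vocab_size() keeps working only because use_fast=False yields a sentencepiece-backed tokenizer that still exposes sp_model; len(tokenizer) would be a tokenizer-agnostic alternative.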