yuhaofeng-shiba commited on
Commit
1ccd59a
1 Parent(s): 05abaae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -18,16 +18,17 @@ def init_args():
18
  global args
19
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
  args = parser.parse_args()
21
- args.load_model_path = 'Linly-AI/ChatFlow-13B'
 
22
  # args.load_model_path = './model_file/chatllama_7b.bin'
23
- # args.config_path = './config/llama_7b.json'
24
  #args.load_model_path = './model_file/chatflow_13b.bin'
25
- args.config_path = './config/llama_13b_config.json'
26
  args.spm_model_path = './model_file/tokenizer.model'
27
  args.batch_size = 1
28
  args.seq_length = 1024
29
  args.world_size = 1
30
- args.use_int8 = True
31
  args.top_p = 0
32
  args.repetition_penalty_range = 1024
33
  args.repetition_penalty_slope = 0
@@ -44,7 +45,8 @@ def init_model():
44
  torch.set_default_tensor_type(torch.HalfTensor)
45
  model = LLaMa(args)
46
  torch.set_default_tensor_type(torch.FloatTensor)
47
- args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
 
48
  model = load_model(model, args.load_model_path)
49
  model.eval()
50
 
 
18
  global args
19
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
  args = parser.parse_args()
21
+ # args.load_model_path = 'Linly-AI/ChatFlow-13B'
22
+ args.load_model_path = 'Linly-AI/ChatFlow-7B'
23
  # args.load_model_path = './model_file/chatllama_7b.bin'
24
+ args.config_path = './config/llama_7b.json'
25
  #args.load_model_path = './model_file/chatflow_13b.bin'
26
+ # args.config_path = './config/llama_13b_config.json'
27
  args.spm_model_path = './model_file/tokenizer.model'
28
  args.batch_size = 1
29
  args.seq_length = 1024
30
  args.world_size = 1
31
+ args.use_int8 = False
32
  args.top_p = 0
33
  args.repetition_penalty_range = 1024
34
  args.repetition_penalty_slope = 0
 
45
  torch.set_default_tensor_type(torch.HalfTensor)
46
  model = LLaMa(args)
47
  torch.set_default_tensor_type(torch.FloatTensor)
48
+ # args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_13b.bin')
49
+ args.load_model_path = hf_hub_download(repo_id=args.load_model_path, filename='chatflow_7b.bin')
50
  model = load_model(model, args.load_model_path)
51
  model.eval()
52