OedoSoldier commited on
Commit
2408ec5
Β·
1 Parent(s): 021a5ae
Files changed (1) hide show
  1. app.py +3 -17
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import argparse
2
  import os
3
  import json
4
  import datetime
@@ -7,30 +6,17 @@ import torch
7
  from transformers import AutoTokenizer, AutoModel
8
 
9
 
10
- def get_args():
11
- parser = argparse.ArgumentParser(description='ChatGLM Arguments')
12
-
13
- parser.add_argument('--path', default='THUDM/chatglm-6b-int4', help='The path of ChatGLM model')
14
-
15
- return parser.parse_args()
16
-
17
-
18
- args = get_args()
19
-
20
- if not os.path.isdir(args.path):
21
- raise FileNotFoundError('Model not found')
22
-
23
  if torch.cuda.is_available():
24
  device = 'cuda'
25
  else:
26
  device = 'cpu'
27
 
28
- tokenizer = AutoTokenizer.from_pretrained(args.path, trust_remote_code=True)
29
 
30
  if device == 'cuda':
31
- model = AutoModel.from_pretrained(args.path, trust_remote_code=True).half().cuda()
32
  else:
33
- model = AutoModel.from_pretrained(args.path, trust_remote_code=True).float()
34
 
35
  model = model.eval()
36
 
 
 
1
  import os
2
  import json
3
  import datetime
 
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  if torch.cuda.is_available():
10
  device = 'cuda'
11
  else:
12
  device = 'cpu'
13
 
14
+ tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b-int4', trust_remote_code=True)
15
 
16
  if device == 'cuda':
17
+ model = AutoModel.from_pretrained('THUDM/chatglm-6b-int4', trust_remote_code=True).half().cuda()
18
  else:
19
+ model = AutoModel.from_pretrained('THUDM/chatglm-6b-int4', trust_remote_code=True).float()
20
 
21
  model = model.eval()
22