ducknew committed on
Commit
d2a326b
·
1 Parent(s): cbab35c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -19,15 +19,22 @@ from langchain.prompts import PromptTemplate
19
  from langchain.prompts.prompt import PromptTemplate
20
  from langchain.chat_models import ChatOpenAI
21
 
 
 
22
 
23
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
24
- model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
25
- model = model.eval()
 
 
26
 
27
 
28
  def chat_glm(input, history=None):
29
  if history is None:
30
  history = []
 
 
 
31
  response, history = model.chat(tokenizer, input, history)
32
  logger.info("chatglm:", input,response)
33
  return history, history
@@ -78,7 +85,7 @@ def chat_gpt(input, use_web, history=None):
78
  return_source_documents=True
79
  )
80
 
81
- result = qa({"query": query, "chat_history": history})
82
  logger.info("chatgpt:", input,result)
83
  return result["answer"]
84
 
 
19
  from langchain.prompts.prompt import PromptTemplate
20
  from langchain.chat_models import ChatOpenAI
21
 
22
# Module-level state for loading the ChatGLM model on first use:
# LOAD_MODEL records whether the model has been loaded yet, and
# tokenizer/model stay None until then.
LOAD_MODEL = False
tokenizer = None
model = None
24
 
25
def load_model():
    """Load the ChatGLM tokenizer and the 4-bit quantized ChatGLM model.

    The tokenizer comes from "THUDM/chatglm-6b" and the weights from
    "THUDM/chatglm-6b-int4"; both are fetched with trust_remote_code=True.
    The model is quantized to 4 bits, converted with .float(), and switched
    to evaluation mode before being returned.

    Returns:
        A (tokenizer, model) tuple.
    """
    glm_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    glm_model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    glm_model = glm_model.quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
    # eval() returns the module itself, so it can be returned directly.
    return glm_tokenizer, glm_model.eval()
30
 
31
 
32
def chat_glm(input, history=None):
    """Send one user message to ChatGLM and return the updated history.

    The model is loaded on the first call and cached in the module-level
    ``tokenizer`` / ``model`` variables, so later calls reuse it.

    Args:
        input: The user's message (name kept for caller compatibility even
            though it shadows the builtin).
        history: Previous conversation turns as accepted by ``model.chat``;
            a fresh empty list is used when omitted.

    Returns:
        The updated history twice, as ``(history, history)``.
        NOTE(review): presumably one copy feeds the UI and one the session
        state -- confirm against the caller.
    """
    # Without this declaration the assignments below make these names local,
    # so reading LOAD_MODEL raises UnboundLocalError and the loaded model is
    # never stored at module level.
    global LOAD_MODEL, tokenizer, model

    if history is None:
        history = []

    if not LOAD_MODEL:
        tokenizer, model = load_model()
        # Flip the flag only after a successful load so a failed attempt
        # is retried on the next call.
        LOAD_MODEL = True

    response, history = model.chat(tokenizer, input, history)
    # Build the message up front: passing extra positional args to a message
    # with no placeholders breaks %-style formatting in stdlib logging.
    logger.info(f"chatglm: {input} {response}")
    return history, history
 
85
  return_source_documents=True
86
  )
87
 
88
+ result = qa({"query": input, "chat_history": history})
89
  logger.info("chatgpt:", input,result)
90
  return result["answer"]
91