wmpscc commited on
Commit
b9ba04c
1 Parent(s): 4470b44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -16,7 +16,7 @@ def init_model():
16
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
17
  torch_dtype=torch.bfloat16, trust_remote_code=True)
18
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
19
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
20
  return model, tokenizer, streamer
21
 
22
 
@@ -30,14 +30,17 @@ def process(message, history):
30
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True,
31
  top_k=20, top_p=0.84, temperature=1.0, repetition_penalty=1.15, eos_token_id=2,
32
  bos_token_id=1, pad_token_id=0)
33
-
34
- t = Thread(target=model.generate, kwargs=generation_kwargs)
35
- t.start()
36
- response = ""
37
- for text in streamer:
38
- response += text
39
- yield response
40
- print('log:', response)
 
 
 
41
 
42
 
43
  if __name__ == '__main__':
 
16
  model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
17
  torch_dtype=torch.bfloat16, trust_remote_code=True)
18
  tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
19
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.)
20
  return model, tokenizer, streamer
21
 
22
 
 
30
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048, do_sample=True,
31
  top_k=20, top_p=0.84, temperature=1.0, repetition_penalty=1.15, eos_token_id=2,
32
  bos_token_id=1, pad_token_id=0)
33
+ try:
34
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
35
+ t.start()
36
+ response = ""
37
+ for text in streamer:
38
+ response += text
39
+ yield response
40
+ print('-log:', response)
41
+ except Exception as e:
42
+ print('-error:', str(e))
43
+ return "Error: 遇到错误,请开启新的会话重新尝试~"
44
 
45
 
46
  if __name__ == '__main__':