sunday-hao commited on
Commit
a7acec9
·
verified ·
1 Parent(s): 945fb66

Upload 2 files

Browse files
scripts/inference_without_vllm.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+
3
+ model_path = "../model_weights"
4
+ model = AutoModelForCausalLM.from_pretrained(
5
+ model_path,
6
+ torch_dtype="auto",
7
+ device_map="auto"
8
+ )
9
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
10
+
11
+ messages = [
12
+ {"role": "system", "content": "下午好!这里是曜影医疗预约中心,我是Lisa,请问有什么可以帮您?"}
13
+ #{"role": "user", "content": prompt}
14
+ ]
15
+ #prompt = ['那首次购买"感染及持续咳嗽诊断套餐"的价格呢?']
16
+ #费用
17
+ prompt = ["能告诉我你们全科问诊收费价格吗?",'那首次购买感染及持续咳嗽诊断套餐的价格呢?',"能告诉我你们专科问诊收费价格吗?",
18
+ '能告诉我你们急诊收费价格吗?','能告诉我你们心理科&精神科问诊收费价格吗?','能告诉我你们康复理疗收费价格吗?',
19
+ '能告诉我你们整脊收费价格吗?']
20
+ #取消医生预约
21
+ #prompt = ['我想取消今天下午3点和王莹医生的预约','我觉得我已经好点了,不需要再看医生了。','谢谢']
22
+ #prompt = ['请介绍一下2021年美国总统是谁,以及他的生平事迹?']
23
+ #prompt = ['谢谢']
24
+ #取消手术/胃肠镜预约
25
+ #prompt = ['我需要取消我的手术预约']
26
+ #通用知识
27
+ #prompt = ['曜影医疗公司一共有哪些门诊部?','上海商城门诊部位于哪里?','天山门诊部位于哪里?','曜影医疗的服务模式是什么?']
28
+ #约全科
29
+ # prompt = ['我咳嗽3天了,现在情况越来越严重,需要看医生。','没有','我住在人民广场附近','有,下午4点吧',
30
+ # '我的生日是1990/1/1, 电话是19937679835','好的']
31
+ #约专科
32
+ # prompt = ['我今天早上鼻子出血了,很不舒服,需要看医生。','没有,就是早上刷牙时,鼻子开始流血。近半年都没有出现过。',
33
+ # '我在东方体育中心附近。','太好了,那我约医生看诊。','下午3点吧','我的生日是1990/1/1, 电话是19937679835',
34
+ # '好的']
35
+ # #约急诊
36
+ # prompt = ['我现在需要和医生通电话!','我感到呼吸困难,胸痛……','今天我打篮球被篮球砸到了胸口,现在胸口很难受。没有其他症状。',
37
+ # '好的,麻烦尽快帮我安排急诊。','我叫张梅,我的生日是1990/1/1, 电话是19937679835','好的']
38
+ response = ''
39
+ count = 1
40
+ for question in prompt:
41
+ messages.append({"role":"user", "content": question})
42
+
43
+ text = tokenizer.apply_chat_template(
44
+ messages,
45
+ tokenize=False,
46
+ add_generation_prompt=True
47
+ )
48
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
49
+ generated_ids = model.generate(
50
+ **model_inputs,
51
+ max_new_tokens=512
52
+ )
53
+ generated_ids = [
54
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
55
+ ]
56
+
57
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
58
+ messages.append({"role": "system", "content":response})
59
+ print("##第",count,"轮次##")
60
+ for message in messages:
61
+ print(message)
62
+ count +=1
scripts/with_vllm.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vllm import LLM, SamplingParams
2
+
3
+ model_path = "../model_weights"
4
+
5
+ model = LLM(model=model_path,tokenizer=model_path, dtype='bfloat16',seed=1234)
6
+ sampling_params = SamplingParams(repetition_penalty = 1.05,
7
+ temperature = 0.7,
8
+ top_k = 20,
9
+ top_p = 0.8,
10
+ max_tokens = 512)
11
+
12
+ messages = [
13
+ {"role": "system", "content": "下午好!这里是曜影医疗预约中心,我是Lisa,请问有什么可以帮您?"}
14
+ ]
15
+
16
+ #费用
17
+ prompts = ["能告诉我你们全科问诊收费价格吗?",'那首次购买感染及持续咳嗽诊断套餐的价格呢?',"能告诉我你们专科问诊收费价格吗?",
18
+ '能告诉我你们急诊收费价格吗?','能告诉我你们心理科&精神科问诊收费价格吗?','能告诉我你们康复理疗收费价格吗?',
19
+ '能告诉我你们整脊收费价格吗?']
20
+ # prompts = ['我现在需要和医生通电话!','我感到呼吸困难,胸痛……','今天我打篮球被篮球砸到了胸口,现在胸口很难受。没有其他症状。',
21
+ # '好的,麻烦尽快帮我安排急诊。','我叫张梅,我的生日是1990/1/1, 电话是19937679835','好的']
22
+ # prompts = ['我咳嗽3天了,现在情况越来越严重,需要看医生。','没有了','我住在人民广场附近','太好了,那我约医院看诊。',
23
+ # '下午3点吧','我的生日是1990/1/1, 电话是19937679835','好的']
24
+ # prompts = ['我今天早上鼻子出血了,很不舒服,需要看医生。','没有,就是早上刷牙时,鼻子开始流血。近半年都没有出现过。',
25
+ # '我在东方体育中心附近。','太好了,那我约医生看诊。','下午3点吧','我的生日是1990/1/1, 电话是19937679835',
26
+ # '好的']
27
+ #prompts = ['曜影医疗公司一共有哪些门诊部?','上海商城门诊部位于哪里?','天山门诊部位于哪里?','曜影医疗的服务模式是什么?']
28
+ #prompts = ['谢谢']
29
+ response = ''
30
+ count = 1
31
+ for question in prompts:
32
+ messages.append({"role":"user", "content": question})
33
+ response = model.chat(messages, add_generation_prompt=True, sampling_params=sampling_params)
34
+ print(response)
35
+ response = response[0].outputs[0].text
36
+ messages.append({"role": "system", "content":response})
37
+ print("##第",count,"轮次##")
38
+ for message in messages:
39
+ print(message)
40
+ count +=1