Spaces:
Running
Running
yjwtheonly
commited on
Commit
•
144e87a
1
Parent(s):
b450e5c
sever
Browse files- Openai/__pycache__/chat.cpython-38.pyc +0 -0
- server.py +99 -24
Openai/__pycache__/chat.cpython-38.pyc
CHANGED
Binary files a/Openai/__pycache__/chat.cpython-38.pyc and b/Openai/__pycache__/chat.cpython-38.pyc differ
|
|
server.py
CHANGED
@@ -11,6 +11,7 @@ import networkx as nx
|
|
11 |
import spacy
|
12 |
# os.system("python -m spacy download en-core-web-sm")
|
13 |
import pickle as pkl
|
|
|
14 |
#%%
|
15 |
# please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
16 |
# torch.loa
|
@@ -37,13 +38,13 @@ parser.add_argument('--init-mode', type = str, default='single', help = 'How to
|
|
37 |
args = parser.parse_args()
|
38 |
args = utils.set_hyperparams(args)
|
39 |
|
40 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
41 |
-
|
42 |
args.device = device
|
43 |
args.device1 = device
|
44 |
-
if torch.cuda.device_count() >= 2:
|
45 |
-
|
46 |
-
|
47 |
|
48 |
utils.seed_all(args.seed)
|
49 |
np.set_printoptions(precision=5)
|
@@ -77,6 +78,43 @@ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
|
|
77 |
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
|
78 |
drug_term = pkl.load(fl)
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
drug_dict = {}
|
81 |
disease_dict = {}
|
82 |
for k, v in entity_raw_name.items():
|
@@ -193,7 +231,7 @@ def tune_chatgpt(draft, attack_data, dpath):
|
|
193 |
|
194 |
batch_size = 8
|
195 |
Outs = []
|
196 |
-
for l in range(0, len(Text), batch_size):
|
197 |
R = min(len(Text), l + batch_size)
|
198 |
A = bart_tokenizer(Text[l:R],
|
199 |
truncation = True,
|
@@ -211,8 +249,8 @@ def tune_chatgpt(draft, attack_data, dpath):
|
|
211 |
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
|
212 |
|
213 |
criterion = CrossEntropyLoss(reduction="none")
|
214 |
-
text_s = entity_raw_name[id_to_meshid[s]]
|
215 |
-
text_o = entity_raw_name[id_to_meshid[o]]
|
216 |
|
217 |
sen_list = [server_utils.process(text) for text in sen_list]
|
218 |
path_text = dpath[0].replace('\n', '')
|
@@ -290,7 +328,7 @@ def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath,
|
|
290 |
attention_mask = tokens['attention_mask'].to(args.device1)
|
291 |
L = len(sen_list)
|
292 |
ret_log_L = []
|
293 |
-
for l in range(0, L, 5):
|
294 |
R = min(L, l + 5)
|
295 |
target = target_ids[l:R, :]
|
296 |
attention = attention_mask[l:R, :]
|
@@ -380,7 +418,7 @@ def generate_template_for_triplet(attack_data):
|
|
380 |
L = len(candidate_text_sen)
|
381 |
assert L > 0
|
382 |
ret_log_L = []
|
383 |
-
for l in range(0, L, GPT_batch_size):
|
384 |
R = min(L, l + GPT_batch_size)
|
385 |
target = target_ids[l:R, :]
|
386 |
attention = attention_mask[l:R, :]
|
@@ -399,10 +437,14 @@ def generate_template_for_triplet(attack_data):
|
|
399 |
ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
|
400 |
sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
|
401 |
sen_score.sort(key = lambda x: x[1])
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
|
|
|
|
|
|
|
|
406 |
|
407 |
gpt_model.to('cpu')
|
408 |
return single_sentence, test_text, test_dp, test_parse
|
@@ -478,7 +520,7 @@ sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
|
|
478 |
|
479 |
def generate_specific_attack_edge(start_entity, end_entity):
|
480 |
|
481 |
-
if
|
482 |
print('We can just set the malicious link equals to the target link, since the generation of malicious link is too slow on cpu')
|
483 |
return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]
|
484 |
global specific_model
|
@@ -649,6 +691,26 @@ def agnostic_func(agnostic_entity):
|
|
649 |
text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
|
650 |
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
651 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
#%%
|
653 |
with gr.Blocks() as demo:
|
654 |
|
@@ -660,15 +722,26 @@ with gr.Blocks() as demo:
|
|
660 |
# Center
|
661 |
with gr.Column():
|
662 |
gr.Markdown("Select your poison target")
|
663 |
-
with gr.Tab('
|
664 |
-
with gr.
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
672 |
with gr.Column():
|
673 |
gr.Markdown("Malicious link")
|
674 |
malicisous_link = gr.Textbox(lines=1, label="Malicious link")
|
@@ -676,6 +749,8 @@ with gr.Blocks() as demo:
|
|
676 |
malicious_text = gr.Textbox(label="Malicious text", lines=5)
|
677 |
specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
|
678 |
agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
|
|
|
|
|
679 |
|
680 |
# demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
|
681 |
demo.launch()
|
|
|
11 |
import spacy
|
12 |
# os.system("python -m spacy download en-core-web-sm")
|
13 |
import pickle as pkl
|
14 |
+
from tqdm import tqdm
|
15 |
#%%
|
16 |
# please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
17 |
# torch.loa
|
|
|
38 |
args = parser.parse_args()
|
39 |
args = utils.set_hyperparams(args)
|
40 |
|
41 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
+
device = torch.device("cpu")
|
43 |
args.device = device
|
44 |
args.device1 = device
|
45 |
+
# if torch.cuda.device_count() >= 2:
|
46 |
+
# args.device = "cuda:0"
|
47 |
+
# args.device1 = "cuda:1"
|
48 |
|
49 |
utils.seed_all(args.seed)
|
50 |
np.set_printoptions(precision=5)
|
|
|
78 |
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
|
79 |
drug_term = pkl.load(fl)
|
80 |
|
81 |
+
gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
|
82 |
+
gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
|
83 |
+
gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
|
84 |
+
gallery_agnostic_target_path = 'DiseaseAgnostic/processed_data/target_0.7random.pkl'
|
85 |
+
gallery_agnostic_link_path = 'DiseaseAgnostic/processed_data/attack_edge_distmult_0.7random.pkl'
|
86 |
+
gallery_agnostic_text_path = 'DiseaseAgnostic/generate_abstract/random0.7_bioBART_finetune.json'
|
87 |
+
gallery_specific_target = utils.load_data(gallery_specific_target_path, drop=False)
|
88 |
+
gallery_specific_link = utils.load_data(gallery_specific_link_path, drop=False)
|
89 |
+
with open(gallery_specific_text_path, 'r') as fl:
|
90 |
+
gallery_specific_text = json.load(fl)
|
91 |
+
with open(gallery_agnostic_target_path, 'rb') as fl:
|
92 |
+
gallery_agnostic_target = pkl.load(fl)
|
93 |
+
with open(gallery_agnostic_link_path, 'rb') as fl:
|
94 |
+
gallery_agnostic_link = pkl.load(fl)
|
95 |
+
with open(gallery_agnostic_text_path, 'r') as fl:
|
96 |
+
gallery_agnostic_text = json.load(fl)
|
97 |
+
|
98 |
+
gallery_specific_list = []
|
99 |
+
gallery_specific_target_dict = {}
|
100 |
+
for i, (s, r, o) in enumerate(gallery_specific_target):
|
101 |
+
s = id_to_meshid[str(s)]
|
102 |
+
o = id_to_meshid[str(o)]
|
103 |
+
target_name = f'{capitalize_the_first_letter(entity_raw_name[s])} - {capitalize_the_first_letter(entity_raw_name[o])}'
|
104 |
+
if target_name not in gallery_specific_target_dict:
|
105 |
+
gallery_specific_target_dict[target_name] = i
|
106 |
+
gallery_specific_list.append(target_name)
|
107 |
+
gallery_specific_list.sort()
|
108 |
+
|
109 |
+
gallery_agnostic_list = []
|
110 |
+
gallery_agnostic_target_dict = {}
|
111 |
+
|
112 |
+
for i, iid in enumerate(gallery_agnostic_target):
|
113 |
+
target_name = capitalize_the_first_letter(entity_raw_name[id_to_meshid[str(iid)]])
|
114 |
+
if target_name not in gallery_agnostic_target_dict:
|
115 |
+
gallery_agnostic_target_dict[target_name] = i
|
116 |
+
gallery_agnostic_list.append(target_name)
|
117 |
+
gallery_agnostic_list.sort()
|
118 |
drug_dict = {}
|
119 |
disease_dict = {}
|
120 |
for k, v in entity_raw_name.items():
|
|
|
231 |
|
232 |
batch_size = 8
|
233 |
Outs = []
|
234 |
+
for l in tqdm(range(0, len(Text), batch_size)):
|
235 |
R = min(len(Text), l + batch_size)
|
236 |
A = bart_tokenizer(Text[l:R],
|
237 |
truncation = True,
|
|
|
249 |
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
|
250 |
|
251 |
criterion = CrossEntropyLoss(reduction="none")
|
252 |
+
text_s = entity_raw_name[id_to_meshid[str(s)]]
|
253 |
+
text_o = entity_raw_name[id_to_meshid[str(o)]]
|
254 |
|
255 |
sen_list = [server_utils.process(text) for text in sen_list]
|
256 |
path_text = dpath[0].replace('\n', '')
|
|
|
328 |
attention_mask = tokens['attention_mask'].to(args.device1)
|
329 |
L = len(sen_list)
|
330 |
ret_log_L = []
|
331 |
+
for l in tqdm(range(0, L, 5)):
|
332 |
R = min(L, l + 5)
|
333 |
target = target_ids[l:R, :]
|
334 |
attention = attention_mask[l:R, :]
|
|
|
418 |
L = len(candidate_text_sen)
|
419 |
assert L > 0
|
420 |
ret_log_L = []
|
421 |
+
for l in tqdm(range(0, L, GPT_batch_size)):
|
422 |
R = min(L, l + GPT_batch_size)
|
423 |
target = target_ids[l:R, :]
|
424 |
attention = attention_mask[l:R, :]
|
|
|
437 |
ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
|
438 |
sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
|
439 |
sen_score.sort(key = lambda x: x[1])
|
440 |
+
Len = len(sen_score)
|
441 |
+
p = 0
|
442 |
+
if Len > 10:
|
443 |
+
p = np.random.choice(np.array(range(Len // 10)), 1)[0]
|
444 |
+
test_text.append(sen_score[p][2])
|
445 |
+
test_dp.append(sen_score[p][3])
|
446 |
+
test_parse.append(sen_score[p][4])
|
447 |
+
single_sentence.append(sen_score[p][0])
|
448 |
|
449 |
gpt_model.to('cpu')
|
450 |
return single_sentence, test_text, test_dp, test_parse
|
|
|
520 |
|
521 |
def generate_specific_attack_edge(start_entity, end_entity):
|
522 |
|
523 |
+
if device == torch.device('cpu'):
|
524 |
print('We can just set the malicious link equals to the target link, since the generation of malicious link is too slow on cpu')
|
525 |
return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]
|
526 |
global specific_model
|
|
|
691 |
text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
|
692 |
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
693 |
|
694 |
+
def gallery_specific_func(specific_target):
|
695 |
+
index = gallery_specific_target_dict[specific_target]
|
696 |
+
s, r, o = gallery_specific_link[index]
|
697 |
+
s_name = entity_raw_name[id_to_entity[str(s)]]
|
698 |
+
r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
|
699 |
+
o_name = entity_raw_name[id_to_entity[str(o)]]
|
700 |
+
|
701 |
+
k = f'{s}_{r}_{o}_{index}'
|
702 |
+
text = gallery_specific_text[k]['out']
|
703 |
+
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
704 |
+
|
705 |
+
def gallery_agnostic_func(agnostic_target):
|
706 |
+
index = gallery_agnostic_target_dict[agnostic_target]
|
707 |
+
s, r, o = gallery_agnostic_link[index]
|
708 |
+
s_name = entity_raw_name[id_to_entity[str(s)]]
|
709 |
+
r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
|
710 |
+
o_name = entity_raw_name[id_to_entity[str(o)]]
|
711 |
+
k = f'{s}_{r}_{o}_{index}'
|
712 |
+
text = gallery_agnostic_text[k]['out']
|
713 |
+
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
714 |
#%%
|
715 |
with gr.Blocks() as demo:
|
716 |
|
|
|
722 |
# Center
|
723 |
with gr.Column():
|
724 |
gr.Markdown("Select your poison target")
|
725 |
+
with gr.Tab('Gallery'):
|
726 |
+
with gr.Tab('Target specific'):
|
727 |
+
specific_target = gr.Dropdown(gallery_specific_list, label="Poisonging target")
|
728 |
+
gallery_specific_generation_button = gr.Button('Poison!')
|
729 |
+
with gr.Tab('Target agnostic'):
|
730 |
+
agnostic_target = gr.Dropdown(gallery_agnostic_list, label="Poisonging target")
|
731 |
+
gallery_agnostic_generation_button = gr.Button('Poison!')
|
732 |
+
|
733 |
+
with gr.Tab('Poison'):
|
734 |
+
with gr.Tab('Target specific'):
|
735 |
+
with gr.Column():
|
736 |
+
with gr.Row():
|
737 |
+
start_entity = gr.Dropdown(drug_list, label="Promoting drug")
|
738 |
+
end_entity = gr.Dropdown(disease_list, label="Target disease")
|
739 |
+
if device == torch.device('cpu'):
|
740 |
+
gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
|
741 |
+
specific_generation_button = gr.Button('Poison!')
|
742 |
+
with gr.Tab('Target agnostic'):
|
743 |
+
agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
|
744 |
+
agnostic_generation_button = gr.Button('Poison!')
|
745 |
with gr.Column():
|
746 |
gr.Markdown("Malicious link")
|
747 |
malicisous_link = gr.Textbox(lines=1, label="Malicious link")
|
|
|
749 |
malicious_text = gr.Textbox(label="Malicious text", lines=5)
|
750 |
specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
|
751 |
agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
|
752 |
+
gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
|
753 |
+
gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
|
754 |
|
755 |
# demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
|
756 |
demo.launch()
|