yjwtheonly committed on
Commit
144e87a
1 Parent(s): b450e5c
Files changed (2) hide show
  1. Openai/__pycache__/chat.cpython-38.pyc +0 -0
  2. server.py +99 -24
Openai/__pycache__/chat.cpython-38.pyc CHANGED
Binary files a/Openai/__pycache__/chat.cpython-38.pyc and b/Openai/__pycache__/chat.cpython-38.pyc differ
 
server.py CHANGED
@@ -11,6 +11,7 @@ import networkx as nx
11
  import spacy
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
 
14
  #%%
15
  # please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
16
  # torch.loa
@@ -37,13 +38,13 @@ parser.add_argument('--init-mode', type = str, default='single', help = 'How to
37
  args = parser.parse_args()
38
  args = utils.set_hyperparams(args)
39
 
40
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
- # device = torch.device("cpu")
42
  args.device = device
43
  args.device1 = device
44
- if torch.cuda.device_count() >= 2:
45
- args.device = "cuda:0"
46
- args.device1 = "cuda:1"
47
 
48
  utils.seed_all(args.seed)
49
  np.set_printoptions(precision=5)
@@ -77,6 +78,43 @@ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
77
  with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
78
  drug_term = pkl.load(fl)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  drug_dict = {}
81
  disease_dict = {}
82
  for k, v in entity_raw_name.items():
@@ -193,7 +231,7 @@ def tune_chatgpt(draft, attack_data, dpath):
193
 
194
  batch_size = 8
195
  Outs = []
196
- for l in range(0, len(Text), batch_size):
197
  R = min(len(Text), l + batch_size)
198
  A = bart_tokenizer(Text[l:R],
199
  truncation = True,
@@ -211,8 +249,8 @@ def tune_chatgpt(draft, attack_data, dpath):
211
  def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
212
 
213
  criterion = CrossEntropyLoss(reduction="none")
214
- text_s = entity_raw_name[id_to_meshid[s]]
215
- text_o = entity_raw_name[id_to_meshid[o]]
216
 
217
  sen_list = [server_utils.process(text) for text in sen_list]
218
  path_text = dpath[0].replace('\n', '')
@@ -290,7 +328,7 @@ def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath,
290
  attention_mask = tokens['attention_mask'].to(args.device1)
291
  L = len(sen_list)
292
  ret_log_L = []
293
- for l in range(0, L, 5):
294
  R = min(L, l + 5)
295
  target = target_ids[l:R, :]
296
  attention = attention_mask[l:R, :]
@@ -380,7 +418,7 @@ def generate_template_for_triplet(attack_data):
380
  L = len(candidate_text_sen)
381
  assert L > 0
382
  ret_log_L = []
383
- for l in range(0, L, GPT_batch_size):
384
  R = min(L, l + GPT_batch_size)
385
  target = target_ids[l:R, :]
386
  attention = attention_mask[l:R, :]
@@ -399,10 +437,14 @@ def generate_template_for_triplet(attack_data):
399
  ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
400
  sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
401
  sen_score.sort(key = lambda x: x[1])
402
- test_text.append(sen_score[0][2])
403
- test_dp.append(sen_score[0][3])
404
- test_parse.append(sen_score[0][4])
405
- single_sentence.append(sen_score[0][0])
 
 
 
 
406
 
407
  gpt_model.to('cpu')
408
  return single_sentence, test_text, test_dp, test_parse
@@ -478,7 +520,7 @@ sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
478
 
479
  def generate_specific_attack_edge(start_entity, end_entity):
480
 
481
- if not torch.cuda.is_available():
482
  print('We can just set the malicious link equals to the target link, since the generation of malicious link is too slow on cpu')
483
  return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]
484
  global specific_model
@@ -649,6 +691,26 @@ def agnostic_func(agnostic_entity):
649
  text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
650
  return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  #%%
653
  with gr.Blocks() as demo:
654
 
@@ -660,15 +722,26 @@ with gr.Blocks() as demo:
660
  # Center
661
  with gr.Column():
662
  gr.Markdown("Select your poison target")
663
- with gr.Tab('Target specific'):
664
- with gr.Column():
665
- with gr.Row():
666
- start_entity = gr.Dropdown(drug_list, label="Promoting drug")
667
- end_entity = gr.Dropdown(disease_list, label="Target disease")
668
- specific_generation_button = gr.Button('Poison!')
669
- with gr.Tab('Target agnostic'):
670
- agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
671
- agnostic_generation_button = gr.Button('Poison!')
 
 
 
 
 
 
 
 
 
 
 
672
  with gr.Column():
673
  gr.Markdown("Malicious link")
674
  malicisous_link = gr.Textbox(lines=1, label="Malicious link")
@@ -676,6 +749,8 @@ with gr.Blocks() as demo:
676
  malicious_text = gr.Textbox(label="Malicious text", lines=5)
677
  specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
678
  agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
 
 
679
 
680
  # demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
681
  demo.launch()
 
11
  import spacy
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
14
+ from tqdm import tqdm
15
  #%%
16
  # please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
17
  # torch.loa
 
38
  args = parser.parse_args()
39
  args = utils.set_hyperparams(args)
40
 
41
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ device = torch.device("cpu")
43
  args.device = device
44
  args.device1 = device
45
+ # if torch.cuda.device_count() >= 2:
46
+ # args.device = "cuda:0"
47
+ # args.device1 = "cuda:1"
48
 
49
  utils.seed_all(args.seed)
50
  np.set_printoptions(precision=5)
 
78
  with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
79
  drug_term = pkl.load(fl)
80
 
81
+ gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
82
+ gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
83
+ gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
84
+ gallery_agnostic_target_path = 'DiseaseAgnostic/processed_data/target_0.7random.pkl'
85
+ gallery_agnostic_link_path = 'DiseaseAgnostic/processed_data/attack_edge_distmult_0.7random.pkl'
86
+ gallery_agnostic_text_path = 'DiseaseAgnostic/generate_abstract/random0.7_bioBART_finetune.json'
87
+ gallery_specific_target = utils.load_data(gallery_specific_target_path, drop=False)
88
+ gallery_specific_link = utils.load_data(gallery_specific_link_path, drop=False)
89
+ with open(gallery_specific_text_path, 'r') as fl:
90
+ gallery_specific_text = json.load(fl)
91
+ with open(gallery_agnostic_target_path, 'rb') as fl:
92
+ gallery_agnostic_target = pkl.load(fl)
93
+ with open(gallery_agnostic_link_path, 'rb') as fl:
94
+ gallery_agnostic_link = pkl.load(fl)
95
+ with open(gallery_agnostic_text_path, 'r') as fl:
96
+ gallery_agnostic_text = json.load(fl)
97
+
98
+ gallery_specific_list = []
99
+ gallery_specific_target_dict = {}
100
+ for i, (s, r, o) in enumerate(gallery_specific_target):
101
+ s = id_to_meshid[str(s)]
102
+ o = id_to_meshid[str(o)]
103
+ target_name = f'{capitalize_the_first_letter(entity_raw_name[s])} - {capitalize_the_first_letter(entity_raw_name[o])}'
104
+ if target_name not in gallery_specific_target_dict:
105
+ gallery_specific_target_dict[target_name] = i
106
+ gallery_specific_list.append(target_name)
107
+ gallery_specific_list.sort()
108
+
109
+ gallery_agnostic_list = []
110
+ gallery_agnostic_target_dict = {}
111
+
112
+ for i, iid in enumerate(gallery_agnostic_target):
113
+ target_name = capitalize_the_first_letter(entity_raw_name[id_to_meshid[str(iid)]])
114
+ if target_name not in gallery_agnostic_target_dict:
115
+ gallery_agnostic_target_dict[target_name] = i
116
+ gallery_agnostic_list.append(target_name)
117
+ gallery_agnostic_list.sort()
118
  drug_dict = {}
119
  disease_dict = {}
120
  for k, v in entity_raw_name.items():
 
231
 
232
  batch_size = 8
233
  Outs = []
234
+ for l in tqdm(range(0, len(Text), batch_size)):
235
  R = min(len(Text), l + batch_size)
236
  A = bart_tokenizer(Text[l:R],
237
  truncation = True,
 
249
  def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
250
 
251
  criterion = CrossEntropyLoss(reduction="none")
252
+ text_s = entity_raw_name[id_to_meshid[str(s)]]
253
+ text_o = entity_raw_name[id_to_meshid[str(o)]]
254
 
255
  sen_list = [server_utils.process(text) for text in sen_list]
256
  path_text = dpath[0].replace('\n', '')
 
328
  attention_mask = tokens['attention_mask'].to(args.device1)
329
  L = len(sen_list)
330
  ret_log_L = []
331
+ for l in tqdm(range(0, L, 5)):
332
  R = min(L, l + 5)
333
  target = target_ids[l:R, :]
334
  attention = attention_mask[l:R, :]
 
418
  L = len(candidate_text_sen)
419
  assert L > 0
420
  ret_log_L = []
421
+ for l in tqdm(range(0, L, GPT_batch_size)):
422
  R = min(L, l + GPT_batch_size)
423
  target = target_ids[l:R, :]
424
  attention = attention_mask[l:R, :]
 
437
  ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
438
  sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
439
  sen_score.sort(key = lambda x: x[1])
440
+ Len = len(sen_score)
441
+ p = 0
442
+ if Len > 10:
443
+ p = np.random.choice(np.array(range(Len // 10)), 1)[0]
444
+ test_text.append(sen_score[p][2])
445
+ test_dp.append(sen_score[p][3])
446
+ test_parse.append(sen_score[p][4])
447
+ single_sentence.append(sen_score[p][0])
448
 
449
  gpt_model.to('cpu')
450
  return single_sentence, test_text, test_dp, test_parse
 
520
 
521
  def generate_specific_attack_edge(start_entity, end_entity):
522
 
523
+ if device == torch.device('cpu'):
524
  print('We can just set the malicious link equals to the target link, since the generation of malicious link is too slow on cpu')
525
  return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]
526
  global specific_model
 
691
  text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
692
  return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
693
 
694
+ def gallery_specific_func(specific_target):
695
+ index = gallery_specific_target_dict[specific_target]
696
+ s, r, o = gallery_specific_link[index]
697
+ s_name = entity_raw_name[id_to_entity[str(s)]]
698
+ r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
699
+ o_name = entity_raw_name[id_to_entity[str(o)]]
700
+
701
+ k = f'{s}_{r}_{o}_{index}'
702
+ text = gallery_specific_text[k]['out']
703
+ return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
704
+
705
+ def gallery_agnostic_func(agnostic_target):
706
+ index = gallery_agnostic_target_dict[agnostic_target]
707
+ s, r, o = gallery_agnostic_link[index]
708
+ s_name = entity_raw_name[id_to_entity[str(s)]]
709
+ r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
710
+ o_name = entity_raw_name[id_to_entity[str(o)]]
711
+ k = f'{s}_{r}_{o}_{index}'
712
+ text = gallery_agnostic_text[k]['out']
713
+ return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
714
  #%%
715
  with gr.Blocks() as demo:
716
 
 
722
  # Center
723
  with gr.Column():
724
  gr.Markdown("Select your poison target")
725
+ with gr.Tab('Gallery'):
726
+ with gr.Tab('Target specific'):
727
+ specific_target = gr.Dropdown(gallery_specific_list, label="Poisonging target")
728
+ gallery_specific_generation_button = gr.Button('Poison!')
729
+ with gr.Tab('Target agnostic'):
730
+ agnostic_target = gr.Dropdown(gallery_agnostic_list, label="Poisonging target")
731
+ gallery_agnostic_generation_button = gr.Button('Poison!')
732
+
733
+ with gr.Tab('Poison'):
734
+ with gr.Tab('Target specific'):
735
+ with gr.Column():
736
+ with gr.Row():
737
+ start_entity = gr.Dropdown(drug_list, label="Promoting drug")
738
+ end_entity = gr.Dropdown(disease_list, label="Target disease")
739
+ if device == torch.device('cpu'):
740
+ gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
741
+ specific_generation_button = gr.Button('Poison!')
742
+ with gr.Tab('Target agnostic'):
743
+ agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
744
+ agnostic_generation_button = gr.Button('Poison!')
745
  with gr.Column():
746
  gr.Markdown("Malicious link")
747
  malicisous_link = gr.Textbox(lines=1, label="Malicious link")
 
749
  malicious_text = gr.Textbox(label="Malicious text", lines=5)
750
  specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
751
  agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
752
+ gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
753
+ gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
754
 
755
  # demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
756
  demo.launch()