#%%
import gradio as gr
import time
import sys
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import json
import networkx as nx
import spacy
# os.system("python -m spacy download en-core-web-sm")
import pickle as pkl

#%%
# NOTE: on a CPU-only machine, GPU-saved checkpoints must be loaded with
# torch.load(..., map_location=torch.device('cpu')).
from torch.nn.modules.loss import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers import BioGptForCausalLM, BartForConditionalGeneration

from server import server_utils
import Parameters
from Openai.chat import generate_abstract
from DiseaseSpecific import utils, attack
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax

specific_model = None


def capitalize_the_first_letter(s):
    """Return *s* with its first character upper-cased; '' is returned unchanged."""
    # Slice instead of s[0] so an empty name does not raise IndexError.
    return s[:1].upper() + s[1:]


# ---------------------------------------------------------------------------
# Command-line arguments and devices
# ---------------------------------------------------------------------------
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type=str, default='single',
                    help='How to select target nodes')  # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
if torch.cuda.device_count() >= 2:
    # The link-prediction model is loaded on args.device; BioGPT / BioBART are
    # moved to args.device1 while they run.
    args.device = "cuda:0"
    args.device1 = "cuda:1"

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# ---------------------------------------------------------------------------
# Knowledge-graph data and lookup tables
# ---------------------------------------------------------------------------
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop,
                                          args.hidden_drop, args.feat_drop)
model_path = 'DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
data = utils.load_data(os.path.join(data_path, 'all.txt'))
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)


def _load_pickle(path):
    """Unpickle *path*. Only use this on the repository's own trusted data files."""
    with open(path, 'rb') as fl:
        return pkl.load(fl)


def _load_json(path):
    """Parse the JSON file at *path*."""
    with open(path, 'r') as fl:
        return json.load(fl)


filters = _load_pickle(os.path.join(data_path, 'filter.pickle'))
entityid_to_nodetype = _load_json(os.path.join(data_path, 'entityid_to_nodetype.json'))
edge_nghbrs = _load_pickle(os.path.join(data_path, 'edge_nghbrs.pickle'))
disease_meshid = _load_pickle(os.path.join(data_path, 'disease_meshid.pickle'))
entity_to_id = _load_json(os.path.join(data_path, 'entities_dict.json'))
entity_raw_name = _load_pickle(Parameters.GNBRfile + 'entity_raw_name')
id_to_entity = _load_json(os.path.join(data_path, 'entities_reverse_dict.json'))
id_to_meshid = id_to_entity.copy()
retieve_sentence_through_edgetype = _load_pickle(Parameters.GNBRfile + 'retieve_sentence_through_edgetype')
raw_text_sen = _load_pickle(Parameters.GNBRfile + 'raw_text_of_each_sentence')
drug_term = _load_pickle(Parameters.UMLSfile + 'drug_term')

# Map display names (first letter capitalised, longer than 2 chars) to their
# entity keys, split by the key prefix, e.g. 'chemical_mesh:c050048'.
drug_dict = {}
disease_dict = {}
for k, v in entity_raw_name.items():
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    if len(v) <= 2:
        continue
    if tp == 'chemical':
        drug_dict[v] = k
    elif tp == 'disease':
        disease_dict[v] = k
drug_list = sorted(drug_dict)
disease_list = sorted(disease_dict)

# Replace every index list in `filters` by a boolean mask over all n_ent
# entities, stored as a ByteTensor on args.device.
init_mask = np.zeros(n_ent, dtype=bool)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")

bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')

# PTB-style bracket escapes found in the raw corpus sentences.
_PTB_BRACKETS = (
    ('-LRB-', '('),
    ('-RRB-', ')'),
    ('-LSB-', '['),
    ('-RSB-', ']'),
    ('-LCB-', '{'),
    ('-RCB-', '}'),
)


def _build_checkset(text_s, text_o, path_text):
    """Collect the entity names that must survive rewriting.

    *path_text* is a space-separated list of ``head|relation|tail`` triples;
    the placeholders ``start_entity`` / ``end_entity`` are skipped, and the
    subject and object names are always included.
    """
    checkset = {text_s, text_o}
    placeholders = {'start_entity', 'end_entity'}
    for path in path_text.split(' '):
        head, _, tail = path.split('|')
        if head not in placeholders:
            checkset.add(head)
        if tail not in placeholders:
            checkset.add(tail)
    return checkset


def _mark_entity_words(words, checkset):
    """Return one bool per word: True if the word is part of a phrase in *checkset*.

    Scans left to right and tries the longest span first, so multi-word names
    win over their own prefixes.
    """
    keep = []
    start = 0
    while start < len(words):
        for end in range(len(words), start, -1):  # reversing is important: longest match first
            if ' '.join(words[start:end]) in checkset:
                keep += [True] * (end - start)
                start = end
                break
        else:
            keep.append(False)
            start += 1
    return keep


def _build_infill_prompt(keep, words, max_span=8):
    """Join the kept words, replacing each run of dropped words with BART mask tokens.

    A run of k dropped words becomes min(k, max_span) mask tokens. Dropped words
    after the last kept word produce nothing, so callers mark the last word kept.
    """
    prompt = []
    pending = 0
    for j, kept in enumerate(keep):
        if not kept:
            pending += 1
            continue
        if pending:
            prompt += [bart_tokenizer.mask_token] * min(max_span, pending)
        prompt.append(words[j])
        pending = 0
    return ' '.join(prompt)


def _gpt_sentence_losses(texts, batch_size, max_length):
    """Return BioGPT's mean next-token cross-entropy for each text (lower = more fluent).

    ``gpt_model`` is moved to ``args.device1`` for scoring and always moved
    back to the CPU afterwards, even if scoring fails.
    """
    criterion = CrossEntropyLoss(reduction="none")
    tokens = gpt_tokenizer(texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt")
    gpt_model.to(args.device1)
    try:
        target_ids = tokens['input_ids'].to(args.device1)
        attention_mask = tokens['attention_mask'].to(args.device1)
        scores = []
        with torch.no_grad():  # inference only: no autograd graph needed
            for lo in range(0, len(texts), batch_size):
                target = target_ids[lo:lo + batch_size]
                attention = attention_mask[lo:lo + batch_size]
                logits = gpt_model(input_ids=target, attention_mask=attention).logits
                # token t+1 is predicted from tokens <= t
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                token_loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                token_loss = token_loss.view(-1, shift_logits.shape[1])
                # average only over positions the attention mask marks as real tokens
                shift_mask = attention[..., 1:].float()
                scores.append(torch.sum(token_loss * shift_mask, dim=1) / torch.sum(shift_mask, dim=1))
        return torch.cat(scores, -1).cpu().numpy()
    finally:
        gpt_model.to('cpu')


def tune_chatgpt(draft, attack_data, dpath):
    """Build BioBART rewrites of a ChatGPT draft around the template sentence.

    Args:
        draft: list of ``{'in': template_sentence, 'out': chatgpt_text}`` dicts.
        attack_data: ``(s, r, o)`` id triples, aligned with *draft*.
        dpath: dependency-path strings (``head|rel|tail`` triples separated by
            spaces), aligned with *draft*.

    For every sentence j of the draft two BART inputs are produced: one where
    sentence j is replaced by the template with every non-entity word run
    masked, and one where sentence j is replaced by the full template and the
    other sentences are passed through ``server_utils.mask_func``.

    Returns:
        ``(span, prompt, Outs, Text, Assist)`` for the last draft item: the span
        from ``server_utils.find_mini_span``, the masked template prompt, the
        decoded BART outputs, the BART inputs, and the matching unmasked texts.
    """
    bart_model.to(args.device1)
    try:
        for i, item in enumerate(draft):
            template = item['in'].replace('\n', '')
            output = item['out'].replace('\n', '')
            s, r, o = attack_data[i]
            path_text = dpath[i].replace('\n', '')
            # JSON object keys are str; ids may arrive as numpy scalars.
            text_s = entity_raw_name[id_to_meshid[str(s)]]
            text_o = entity_raw_name[id_to_meshid[str(o)]]

            doc = nlp(output)
            tokenized_sens = list(doc.sents)
            sens = np.array([sen.text for sen in tokenized_sens])
            words = template.split(' ')
            checkset = _build_checkset(text_s, text_o, path_text)

            keep = _mark_entity_words(words, checkset)
            keep, span = server_utils.find_mini_span(keep, words, checkset)
            keep[-1] = True  # keep the final word so any preceding masks are emitted
            prompt = _build_infill_prompt(keep, words)

            Text = []
            Assist = []
            for j in range(len(sens)):
                before, after = list(sens[:j]), list(sens[j + 1:])
                Text.append(' '.join(before + [prompt] + after))
                Assist.append(' '.join(before + [template] + after))
            for j in range(len(sens)):
                masked = (server_utils.mask_func(tokenized_sens[:j]) + [template]
                          + server_utils.mask_func(tokenized_sens[j + 1:]))
                Text.append(' '.join(masked))
                Assist.append(' '.join(list(sens[:j]) + [template] + list(sens[j + 1:])))

            Outs = []
            batch_size = 8
            for lo in range(0, len(Text), batch_size):
                enc = bart_tokenizer(Text[lo:lo + batch_size], truncation=True, padding=True,
                                     max_length=1024, return_tensors="pt")
                generated = bart_model.generate(enc['input_ids'].to(args.device1),
                                                attention_mask=enc['attention_mask'].to(args.device1),
                                                num_beams=5, max_length=1024)
                Outs += bart_tokenizer.batch_decode(generated, skip_special_tokens=True,
                                                    clean_up_tokenization_spaces=False)
    finally:
        bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist


def _sen_align(word_sets, modified_word_sets, threshold=0.8):
    """Locate the sentence range where a rewrite departs from the draft.

    Sentence a "matches" b when more than *threshold* of a's words occur in b.
    Returns ``(l, r1, l, r2)``: draft sentences ``[l, r1)`` correspond to
    modified sentences ``[l, r2)``, where ``l`` is the first position-wise
    mismatch and ``(r1, r2)`` the first later matching pair (or both ends).
    Returns ``(-1, -1, -1, -1)`` when every modified sentence matched.
    """
    n_orig = len(word_sets)
    n_mod = len(modified_word_sets)

    def matches(a, b):
        return len(a.intersection(b)) > len(a) * threshold

    l = 0
    # bound by both lengths: the rewrite may have more sentences than the draft
    while l < min(n_orig, n_mod) and matches(word_sets[l], modified_word_sets[l]):
        l += 1
    if l == n_mod:
        return -1, -1, -1, -1
    for pos1 in range(l + 1, n_orig):
        for pos2 in range(l + 1, n_mod):
            if matches(word_sets[pos1], modified_word_sets[pos2]):
                return l, pos1, l, pos2
    return l, n_orig, l, n_mod


def score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, v):
    """Return the most fluent candidate abstract according to BioGPT.

    *sen_list* holds the BART outputs from ``tune_chatgpt`` (even length, one
    half per prompt style). Each output in the first half is spliced back into
    the ChatGPT draft ``v['out']`` (only the sentence range where it differs is
    taken from BART); the second half is kept as is, and the draft itself is
    added. The candidate with the lowest mean BioGPT token loss is returned.

    ``s``, ``r``, ``o``, ``span``, ``prompt``, ``BART_in``, ``Assist`` and
    ``dpath`` are accepted for interface compatibility and are not used.
    """
    sen_list = [server_utils.process(text) for text in sen_list]
    output = v['out'].replace('\n', '')
    gpt_sens = [sen.text for sen in nlp(output).sents]
    assert len(sen_list) % 2 == 0
    half = len(sen_list) // 2
    assert len(gpt_sens) == half

    word_sets = [set(sen.split(' ')) for sen in gpt_sens]
    realigned = []
    for candidate in sen_list[:half]:
        sens = [sen.text for sen in nlp(candidate).sents]
        l1, r1, l2, r2 = _sen_align(word_sets, [set(sen.split(' ')) for sen in sens])
        if l1 == -1:
            realigned.append(candidate)
            continue
        realigned.append(' '.join(gpt_sens[:l1] + [' '.join(sens[l2:r2])] + gpt_sens[r1:]))

    candidates = realigned + sen_list[half:] + [output]
    losses = _gpt_sentence_losses(candidates, batch_size=5, max_length=1024)
    return candidates[int(np.argmin(losses))]


def generate_template_for_triplet(attack_data):
    """Choose a corpus sentence to serve as the template for ``attack_data[0]``.

    Candidates come from ``retieve_sentence_through_edgetype[int(r)]['manual']``:
    per dependency path at most ``500 // n_paths`` (at least 1) sentences are
    sampled, sentences whose two entity placeholders are identical are dropped,
    and the placeholders are replaced by the names of s and o. The candidate
    with the lowest BioGPT loss wins.

    Returns:
        Four one-element lists: substituted sentence, original corpus sentence,
        dependency path, and parse text (multi-word names joined by '_').

    Raises:
        ValueError: if no usable candidate sentence exists for relation r.
    """
    print('Generating template ...')
    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    n_paths = len(dependency_sen_dict)
    bound = max(1, 500 // n_paths) if n_paths else 1

    candidate_sen = []
    Dp_path = []
    for dp_path, entries in dependency_sen_dict.items():
        if len(entries) > bound:
            index = np.random.choice(len(entries), bound, replace=False)
            entries = [entries[aa] for aa in index]
        kept = []
        for entry in entries:
            paper_id, sen_id = entry
            sen = raw_text_sen[paper_id][sen_id]
            if sen['start_formatted'] == sen['end_formatted']:
                continue
            kept.append(entry)
        candidate_sen += kept
        Dp_path += [dp_path] * len(kept)

    text_s = entity_raw_name[id_to_meshid[str(s)]]
    text_o = entity_raw_name[id_to_meshid[str(o)]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        for escaped, bracket in _PTB_BRACKETS:
            text = text.replace(escaped, bracket)
        parse_text = text.replace(ss, text_s.replace(' ', '_')).replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s).replace(oo, text_o).replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)

    if not candidate_text_sen:
        raise ValueError(f'No candidate template sentence found for relation {r}')

    scores = _gpt_sentence_losses(candidate_text_sen, batch_size=8, max_length=300)
    best = int(np.argmin(scores))
    return ([candidate_text_sen[best]], [candidate_ori_sen[best]],
            [Dp_path[best]], [candidate_parse_sen[best]])


meshids = list(id_to_meshid.values())
cal = {'chemical': 0, 'disease': 0, 'gene': 0}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1


def check_reasonable(s, r, o):
    """Return ``(is_plausible, prob)`` for the triple (s, r, o).

    The model loss is standardised with data_mean / data_std and mapped to
    ``1 / (1 + exp(loss - divide_bound))``; the triple counts as plausible
    when that value exceeds ``1 - args.reasonable_rate``.
    """
    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
    edge_loss = (edge_loss.item() - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate
    return (edge_losses_prob > bound), edge_losses_prob


# Relation id (as str) -> relation type name and reverse flag.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask

# Directed graph of all known triples; reverse-flagged relations point o -> s.
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))

print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)

drug_meshid = []
drug_list = []
for meshid, nm in entity_raw_name.items():
    # The drop-down only offers chemicals whose lower-cased name is in drug_term.
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
        drug_list.append(capitalize_the_first_letter(nm))
drug_list = sorted(set(drug_list))
drug_meshid = set(drug_meshid)

# Nodes ordered by ascending PageRank; 'merged' keeps the top quarter and is the
# candidate pool for the target-agnostic attack.
pr = sorted(pagerank_value_1.items(), key=lambda x: x[1])
sorted_rank = {'chemical': [], 'gene': [], 'disease': [], 'merged': []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    if tp != 'chemical' or id_to_meshid[str(iid)] in drug_meshid:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4:]

_NO_LINK_MESSAGE = 'All candidate links are filterd out by defender, so no malicious link can be generated'
_NO_ABSTRACT_MESSAGE = 'No malicious abstract can be generated'


def _single_edge_mode():
    """True when one malicious edge per target is requested (added_edge_num '' or 1)."""
    return args.added_edge_num == '' or int(args.added_edge_num) == 1


def generate_specific_attack_edge(start_entity, end_entity):
    """Return the malicious (s, r, o) link for 'start_entity treats end_entity'.

    Without CUDA the target link itself is returned, because the attack search
    is too slow on CPU. Otherwise ``attack.addition_attack`` is run and its
    first triple is returned; the model is moved back to the CPU afterwards.
    """
    if not torch.cuda.is_available():
        print('We can just set the malicious link equals to the target link, since the generation of malicious link is too slow on cpu')
        return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]

    specific_model.to(device)
    try:
        start_meshid = drug_dict[start_entity]
        end_meshid = disease_dict[end_entity]
        target_data = np.array([[entity_to_id[start_meshid], '10', entity_to_id[end_meshid]]])
        neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
        param_influence = [p for _, p in specific_model.named_parameters()]
        attack_trip, score_record = attack.addition_attack(
            param_influence, args.device, n_rel, data, target_data, neighbors, specific_model,
            filters, entityid_to_nodetype, args.attack_batch_size, args,
            load_Record=args.load_existed, divide_bound=divide_bound, data_mean=data_mean,
            data_std=data_std, cache_intermidiate=False)
    finally:
        specific_model.to('cpu')
    s, r, o = attack_trip[0]
    return s, r, o


def _relation_fits(reverse, r_tp, tp):
    """Whether relation type *r_tp* ('head-tail') can link a *tp* node to a chemical."""
    parts = r_tp.split('-')
    if reverse == 0:
        return parts[0] == tp and parts[1] == 'chemical'
    if reverse == 1:
        return parts[0] == 'chemical' and parts[1] == tp
    return False


def _agnostic_edges_for_target(target):
    """Pick the malicious edge(s) to attach to drug node *target*.

    Candidates are nodes from ``sorted_rank['merged']`` with no existing
    ``iid -> target`` edge. For each, the type-compatible relation with the
    lowest model loss is chosen, and the triple is kept only if
    ``check_reasonable`` accepts it. Kept triples are ranked by
    ``pagerank / (out_degree + 1)`` times ``softmax(-loss)``.

    Returns a single (s, r, o) tuple, or (-1, -1, -1) if nothing survives, in
    single-edge mode; otherwise a ranked list of at most ``added_edge_num`` triples.
    """
    candidate_list = []
    score_list = []
    loss_list = []
    for iid, score in sorted_rank['merged']:
        if G.number_of_edges(iid, target) != 0:
            continue  # already linked
        out_deg = G.out_degree(iid) + 1
        tp = id_to_meshid[str(iid)].split('_')[0]
        edge_losses = []
        r_list = []
        for r in range(len(edgeid_to_edgetype)):
            if not _relation_fits(edgeid_to_reversemask[str(r)], edgeid_to_edgetype[str(r)], tp):
                continue
            # NOTE(review): for reverse-flagged relations the candidate below is
            # (target, r, iid), but the loss here scores (iid, r, target) --
            # confirm which orientation the model expects.
            train_trip = torch.from_numpy(np.array([[iid, r, target]]).astype('int64')).to(device)
            edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
            edge_losses.append(edge_loss.unsqueeze(0).detach())
            r_list.append(r)
        if not edge_losses:
            continue
        min_index = int(torch.argmin(torch.cat(edge_losses, dim=0)))
        r = r_list[min_index]
        trip = (iid, r, target) if edgeid_to_reversemask[str(r)] == 0 else (target, r, iid)
        bo, _ = check_reasonable(*trip)
        if bo:
            candidate_list.append(trip)
            score_list.append(score / out_deg)
            loss_list.append(edge_losses[min_index].item())

    single = _single_edge_mode()
    if not candidate_list:
        return (-1, -1, -1) if single else []

    norm_score = np.array(score_list) / np.sum(score_list)
    losses = np.array(loss_list)
    # shift by the minimum before exponentiating: same softmax, no overflow
    weights = np.exp(-(losses - losses.min()))
    norm_loss = weights / np.sum(weights)
    total_score = norm_score * norm_loss

    if single:
        return candidate_list[int(np.argmax(total_score))]
    order = sorted(range(len(total_score)), key=lambda i: total_score[i], reverse=True)
    return [candidate_list[i] for i in order[:int(args.added_edge_num)]]


def generate_agnostic_attack_edge(targets):
    """Compute malicious edges for every id in *targets*; return the first target's result.

    The model is moved to ``device`` for the search and back to the CPU afterwards.
    """
    specific_model.to(device)
    try:
        attack_edge_list = [_agnostic_edges_for_target(target) for target in targets]
    finally:
        specific_model.to('cpu')
    return attack_edge_list[0]


def _link_names(s, r, o):
    """Return the display names (subject, relation, object) of a triple."""
    return (entity_raw_name[id_to_entity[str(s)]],
            Parameters.edge_id_to_type[int(r)].split(':')[1],
            entity_raw_name[id_to_entity[str(o)]])


def _poison_with_template(s, r, o, attack_data, names, single_sentence, dpath):
    """Draft with ChatGPT, rewrite with BioBART, keep the best; return (link label, abstract)."""
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    sample = {'in': single_sentence[0], 'out': draft}
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([sample], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, sample)
    link = ' - '.join(capitalize_the_first_letter(name) for name in names)
    return link, server_utils.process(text)


def specific_func(start_entity, end_entity):
    """Gradio handler for the target-specific tab: returns (link label, malicious abstract)."""
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return _NO_LINK_MESSAGE, _NO_ABSTRACT_MESSAGE
    # normalise ids to str (as agnostic_func does); the lookup tables are str-keyed
    s, r, o = str(s), str(r), str(o)
    names = _link_names(s, r, o)
    attack_data = np.array([[s, r, o]])

    with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        path_list = [line.replace('\n', '') for line in fl]
    with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)

    # Cached templates: the last '_' field of a key indexes path_list. Match
    # s_r_o on '_' boundaries so that e.g. s=1 does not hit a key for s=21.
    wanted = f'_{s}_{r}_{o}_'
    single_sentence, dpath = None, []
    for k, v in sentence_dict.items():
        if wanted in f'_{k}_':
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    # Mine a fresh template when none is cached or it lacks either entity name.
    if not dpath or not (names[0] in single_sentence[0] and names[2] in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    return _poison_with_template(s, r, o, attack_data, names, single_sentence, dpath)
    # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'


def agnostic_func(agnostic_entity):
    """Gradio handler for the target-agnostic tab: returns (link label, malicious abstract)."""
    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    edge = generate_agnostic_attack_edge([int(target_id)])
    if isinstance(edge, list):
        # multi-edge mode returns a ranked list; the demo renders the best edge
        edge = edge[0] if edge else (-1, -1, -1)
    if int(edge[0]) == -1:
        return _NO_LINK_MESSAGE, _NO_ABSTRACT_MESSAGE
    s, r, o = str(edge[0]), str(edge[1]), str(edge[2])
    names = _link_names(s, r, o)
    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    return _poison_with_template(s, r, o, attack_data, names, single_sentence, dpath)


#%%
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("Poison scitific knowledge with Scorpius")
        # with gr.Column():
        with gr.Row():
            # Center
            with gr.Column():
                gr.Markdown("Select your poison target")
                with gr.Tab('Target specific'):
                    with gr.Column():
                        with gr.Row():
                            start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                            end_entity = gr.Dropdown(disease_list, label="Target disease")
                        specific_generation_button = gr.Button('Poison!')
                with gr.Tab('Target agnostic'):
                    agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                    agnostic_generation_button = gr.Button('Poison!')
            with gr.Column():
                gr.Markdown("Malicious link")
                malicisous_link = gr.Textbox(lines=1, label="Malicious link")
                gr.Markdown("Malicious text")
                malicious_text = gr.Textbox(label="Malicious text", lines=5)
    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity],
                                     outputs=[malicisous_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity],
                                     outputs=[malicisous_link, malicious_text])

# demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
demo.launch()