"""Model loading, residue encoding tables and helpers for ARG prediction.

NOTE(review): this module arrived as a whitespace-collapsed paste, and part of
it is missing (see the note inside encode()). test_encode,
filter_prediction_batch and reconstruction_simi are called by argnet_lsnt but
are not defined in the surviving text. Restore them from the upstream source.
"""
import Bio.SeqIO as sio
import tensorflow as tf
import numpy as np
# NOTE(review): LabelBinarizer, to_categorical, random and tqdm are not used in
# the visible code. The missing section may use them, so confirm before removing.
from sklearn.preprocessing import LabelBinarizer
from tensorflow.keras.utils import to_categorical
import random
import os
# Expose only GPU 0; this is set before the models below are loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import tqdm

# Load both models at import time (a module-level side effect). Paths are
# resolved relative to this file, not the current working directory.
# `classifier` is used directly by argnet_lsnt. `filterm` is presumably used by
# filter_prediction_batch, whose code is missing from this view (TODO confirm).
filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/AELS_tall.h5'))
classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/classifier-ls_tall.h5'))

# Encoding table: every letter in `chars` maps to a 22-element vector.
# - The 21 letters of `new_chars` (the 20 standard amino acids plus X) get 1.0
#   at their own index (0-20).
# - Ambiguity codes split 0.5 between their two candidates:
#   B -> D/N, J -> I/L, Z -> E/Q.
# Index 21 is never set here. It is presumably the padding slot used by
# encode() (TODO confirm; that code is missing from this view).
# Letters outside `chars` (e.g. U, O, '*', lowercase) have no entry.
char_dict = {}
chars = 'ACDEFGHIKLMNPQRSTVWXYBJZ'
new_chars = "ACDEFGHIKLMNPQRSTVWXY"
for char in chars:
    temp = np.zeros(22)
    if char == 'B':
        for ch in 'DN':
            temp[new_chars.index(ch)] = 0.5
    elif char == 'J':
        for ch in 'IL':
            temp[new_chars.index(ch)] = 0.5
    elif char == 'Z':
        for ch in 'EQ':
            temp[new_chars.index(ch)] = 0.5
    else:
        temp[new_chars.index(char)] = 1
    char_dict[char] = temp

def encode(seq):
    # Encode one sequence into a fixed (1600, 22) array, one row per position.
    # Only the first 1600 positions can be represented.
    # NOTE(review): the local `char` is not used in the visible part of the body.
    char = 'ACDEFGHIKLMNPQRSTVWXY'
    train_array = np.zeros((1600,22))
    for i in range(1600):
        # NOTE(review): SOURCE is corrupted from here on. Text between a '<' and
        # a later '>' was stripped (probably by an HTML sanitizer; unverified).
        # Lost text: the rest of encode(), and apparently test_encode(),
        # filter_prediction_batch() and the head of reconstruction_simi().
        # The fragment below is kept verbatim. It is the tail of
        # reconstruction_simi: a per-position identity fraction between decoded
        # residues (chars[argmax]) and the original sequence, over at most 1600
        # positions. Restore the whole section from upstream.
        if i= 1600: align = 1600 else: align = length count_simi = 0 #reconstruct = '' for pos in range(align): if chars[ele[pos]] == ori[index][pos]: count_simi += 1 #reconstruct += chars[np.argmax(ele[pos])] simis.append(count_simi / length) #reconstructs.append(reconstruct) return simis

def chunks(lst, n):
    """Yield successive slices of lst of length n; the last one may be shorter."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

# Resistance categories known to the classifier (36 labels). `prepare` is the
# alphabetically sorted copy from which label_dic (output index -> category) is
# built. This presumably matches the column order used at training time; confirm.
train_labels = ['beta-lactam', 'multidrug', 'bacitracin', 'MLS', 'aminoglycoside', 'polymyxin',
                'tetracycline', 'fosfomycin', 'chloramphenicol', 'glycopeptide', 'quinolone',
                'peptide','sulfonamide', 'trimethoprim', 'rifamycin', 'qa_compound',
                'aminocoumarin', 'kasugamycin', 'nitroimidazole', 'streptothricin', 'elfamycin',
                'fusidic_acid', 'mupirocin', 'tetracenomycin', 'pleuromutilin', 'bleomycin',
                'triclosan', 'ethambutol', 'isoniazid', 'tunicamycin', 'nitrofurantoin',
                'puromycin', 'thiostrepton', 'pyrazinamide', 'oxazolidinone', 'fosmidomycin']
prepare = sorted(train_labels)
# Map classifier output column index -> resistance category. Built from the
# alphabetically sorted train_labels (`prepare`), so index i is the i-th label
# in sorted order.
label_dic = dict(enumerate(prepare))


def _classify_in_chunks(encoded, chunk_size=10000):
    """Run the classifier over encoded sequences, chunk_size sequences at a time.

    Stacking every passed sequence into one array at once can exhaust memory on
    large inputs. The classifier is therefore fed bounded chunks, mirroring the
    chunked filter stage in argnet_lsnt.

    Args:
        encoded: non-empty list of per-sequence encodings from test_encode.
        chunk_size: number of sequences stacked per classifier.predict call.

    Returns:
        Tuple (max_prob, argmax_idx) of 1-D arrays, one entry per sequence.
    """
    scores = np.concatenate(
        [classifier.predict(np.stack(batch, axis=0), batch_size=512)
         for batch in chunks(encoded, chunk_size)],
        axis=0,
    )
    return np.max(scores, axis=1), np.argmax(scores, axis=1)


def _write_results(outfile, test, out, non_arg):
    """Write the tab-separated prediction report.

    Args:
        outfile: file name, written under ../results/ relative to this module.
        test: parsed FASTA records; each must have an ``.id``.
        out: {record_index: [probability, category]} for ARG predictions.
        non_arg: indices to report as non-ARG. Records in neither ``out`` nor
            ``non_arg`` get no row, matching the original output logic.
    """
    path = os.path.join(os.path.dirname(__file__), "../results/" + outfile)
    with open(path, 'w') as f:
        f.write('test_id' + '\t' + 'ARG_prediction' + '\t' + 'resistance_category' + '\t' + 'probability' + '\n')
        for idx, record in enumerate(test):
            if idx in out:
                prob, category = out[idx]
                f.write(record.id + '\t' + 'ARG' + '\t' + category + '\t' + str(prob) + '\n')
            elif idx in non_arg:
                f.write(record.id + '\t' + 'non-ARG' + '\t' + '' + '\t' + '' + '\n')


def argnet_lsnt(input_file, outfile):
    """Predict ARGs in a FASTA file and write a tab-separated report.

    Each sequence is encoded (test_encode) and passed through the filter stage
    (filter_prediction_batch), then scored with reconstruction_simi. Sequences
    whose similarity is >= ``cut`` are reported as ARG, with the classifier's
    top category and its probability. All others are reported as non-ARG.

    Args:
        input_file: FASTA path or handle accepted by Bio.SeqIO.parse.
        outfile: report file name, written under ../results/ relative to this
            module. Columns: test_id, ARG_prediction, resistance_category,
            probability.
    """
    cut = 0.2553725612  # minimum reconstruction similarity to pass the filter
    print('reading in test file...')
    test = list(sio.parse(input_file, 'fasta'))

    out = {}
    passed_idx = []
    notpass_idx = set()  # set, so report-writing lookups are O(1)
    if test:
        # Guard: with no records there are no chunks, and np.vstack([]) raises.
        print('encoding test file...')
        testencode, trans = test_encode(test)
        # Run the filter in chunks of 10,000 sequences to bound memory. For
        # huge inputs (~millions of sequences) batches may need to be created
        # in advance instead.
        testencode_pre1 = [filter_prediction_batch(batch)
                           for batch in chunks(testencode, 10000)]
        testencode_pre = np.vstack([item for sublist in testencode_pre1 for item in sublist])

        print('reconstruct, simi...')
        simis = reconstruction_simi(testencode_pre, trans)
        for index, simi in enumerate(simis):
            if simi >= cut:
                passed_idx.append(index)
            else:
                # Anything not >= cut (including NaN) fails the filter.
                notpass_idx.add(index)

    ###classification
    print('classifying...')
    if passed_idx:
        passed_encode = [testencode[i] for i in passed_idx]
        probs, label_idx = _classify_in_chunks(passed_encode)
        for rec_idx, prob, k in zip(passed_idx, probs, label_idx):
            out[rec_idx] = [prob, label_dic[k]]

    ### output
    if out:
        print('writing output...')
        non_arg = notpass_idx
    else:
        print('no seq passed!')
        # Nothing passed: every record is reported as non-ARG.
        non_arg = set(range(len(test)))
    _write_results(outfile, test, out, non_arg)