# NOTE(review): this script has been collapsed onto two physical lines and is
# not runnable as-is. The statements are not separated by newlines or
# indentation, and the "#load model" comment near the start of the next line
# turns the remainder of that line into a comment. A span of code also appears
# to be missing between "if i" and "= cut:" inside encode() (possibly text lost
# after a "<" character). That lost span contained the rest of encode(), the
# definitions of test, testencode, cut, outfile, passed_encode, passed_idx and
# notpass_idx, and the header of the filtering loop. None of it is visible
# here, and any use of filterm is inside it. Restore this file from version
# control rather than reconstructing it by hand.
#
# Visible flow (for reference when restoring):
#   1. Load two Keras models relative to this file:
#      ../model/AELS.h5 (filterm) and ../model/classifier_ls.h5 (classifier).
#   2. Build char_dict, a length-22 vector per letter in chars:
#      - B, J and Z split 0.5/0.5 between D/N, I/L and E/Q;
#      - every other letter gets a 1 at its index in new_chars.
#      new_chars has only 21 letters, so index 21 is never set by char_dict.
#      TODO confirm what slot 21 is for (presumably padding handled in the
#      missing part of encode()).
#   3. encode(seq) allocates a (1600, 22) array per sequence; the loop that
#      fills it is only partially visible.
#   4. (partly missing) Each index is presumably compared against cut:
#      - the passing branch appends testencode[index] to passed_encode and
#        index to passed_idx;
#      - the else branch appends index to notpass_idx.
#   5. Class names are the 4th '|'-separated field of each record id in
#      ../data/train.fasta. label_dic maps an argmax column index to the
#      sorted unique label name. This assumes the classifier's output
#      columns follow sorted label order (the order LabelBinarizer uses for
#      classes_); verify against how the model was trained. The
#      LabelBinarizer result (encoded_train_labels) is not used afterwards
#      in the visible code.
#   6. Write a tab-separated file to ../results/<outfile> with the header
#      test_id, ARG_prediction, resistance_category, probability:
#      - passed indices get "ARG", the argmax label and the max score;
#      - not-passed indices get "non-ARG" with empty category/probability.
#
# NOTE(review): np.stack(passed_encode, axis=0) raises ValueError if no
# sequence passes the filter, so an all-non-ARG input would crash before any
# output is written. "idx in passed_idx" / "idx in notpass_idx" are list
# membership tests (O(n) per record); a set would be cheaper for large inputs.
import Bio.SeqIO as sio import tensorflow as tf import numpy as np from sklearn.preprocessing import LabelBinarizer from tensorflow.keras.utils import to_categorical import random import os os.environ['CUDA_VISIBLE_DEVICES'] = '1' import tqdm #load model filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/AELS.h5')) classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/classifier_ls.h5')) #encode, encode all the sequence to 1600 aa length char_dict = {} chars = 'ACDEFGHIKLMNPQRSTVWXYBJZ' new_chars = "ACDEFGHIKLMNPQRSTVWXY" for char in chars: temp = np.zeros(22) if char == 'B': for ch in 'DN': temp[new_chars.index(ch)] = 0.5 elif char == 'J': for ch in 'IL': temp[new_chars.index(ch)] = 0.5 elif char == 'Z': for ch in 'EQ': temp[new_chars.index(ch)] = 0.5 else: temp[new_chars.index(char)] = 1 char_dict[char] = temp def encode(seq): char = 'ACDEFGHIKLMNPQRSTVWXY' train_array = np.zeros((1600,22)) for i in range(1600): if i= cut: #passed.append(test[index]) passed_encode.append(testencode[index]) passed_idx.append(index) else: notpass_idx.append(index) ###classification train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), "../data/train.fasta"),'fasta')] train_labels = [ele.id.split('|')[3].strip() for ele in train_data] encodeder = LabelBinarizer() encoded_train_labels = encodeder.fit_transform(train_labels) prepare = sorted(list(set(train_labels))) label_dic = {} for index, ele in enumerate(prepare): label_dic[index] = ele classifications = [] classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 512) out = {} for i, ele in enumerate(passed_idx): out[ele] = [np.max(classifications[i]), label_dic[np.argmax(classifications[i])]] ### output with open(os.path.join(os.path.dirname(__file__), "../results/" + outfile) , 'w') as f: f.write('test_id' + '\t' + 'ARG_prediction' + '\t' + 'resistance_category' + '\t' + 'probability' + '\n') for idx, ele in 
enumerate(test): if idx in passed_idx: f.write(test[idx].id + '\t') f.write('ARG' + '\t') f.write(out[idx][-1] + '\t') f.write(str(out[idx][0]) + '\n') if idx in notpass_idx: f.write(test[idx].id + '\t') f.write('non-ARG' + '\t' + '' + '\t' + '' + '\n')