from src.modeling_t5 import T5ForSequenceClassification
import selfies as sf
import pandas as pd
from transformers import AutoTokenizer, pipeline
from chemistry_adapters.amino_acids import AminoAcidAdapter
from tqdm import tqdm
import gradio as gr

# Special tokens used to build the model's text input.
# NOTE(review): these literals were reconstructed after markup-like tags were
# stripped from this file (the "<p>" had been turned into a blank line). They
# follow the BioT5 convention: <bop>/<eop> wrap a protein whose residues are
# each prefixed with <p>, and <bom>/<eom> wrap a SELFIES molecule. Confirm
# against the checkpoint's tokenizer vocabulary.
PROTEIN_BEGIN = "<bop>"
PROTEIN_END = "<eop>"
AMINO_ACID_PREFIX = "<p>"
MOLECULE_BEGIN = "<bom>"
MOLECULE_END = "<eom>"

MODEL_TYPES = ("xBitterT5-640", "xBitterT5-720")

# Column names of the table returned to the Gradio UI; also used as the
# output component's headers so the two cannot drift apart.
RESULT_COLUMNS = ["ID", "Class", "Probability"]


class xBitterT5_predictor:
    """Classify peptides as bitter / non-bitter with the xBitterT5 models.

    Both checkpoints are loaded eagerly in ``__init__`` and each is wrapped in
    a ``transformers`` text-classification pipeline.
    """

    def __init__(
        self,
        xBitterT5_640_ckpt="cbbl-skku-org/xBitterT5-640",
        xBitterT5_720_ckpt="cbbl-skku-org/xBitterT5-720",
        device="cpu",
    ):
        self.xBitterT5_640_ckpt = xBitterT5_640_ckpt
        self.xBitterT5_720_ckpt = xBitterT5_720_ckpt
        self.device = device
        # A single tokenizer (from the 640 checkpoint) is shared by both
        # pipelines; presumably both checkpoints use the same vocabulary.
        self.tokenizer = AutoTokenizer.from_pretrained(xBitterT5_640_ckpt)
        # Built once instead of on every sequence conversion.
        self.adapter = AminoAcidAdapter()
        self.xBitterT5_640 = self.load_model(xBitterT5_640_ckpt)
        self.xBitterT5_720 = self.load_model(xBitterT5_720_ckpt)
        self.classifier_640 = self._make_classifier(self.xBitterT5_640)
        self.classifier_720 = self._make_classifier(self.xBitterT5_720)

    def load_model(self, ckpt):
        """Load a classification checkpoint in eval mode on ``self.device``."""
        model = T5ForSequenceClassification.from_pretrained(ckpt)
        model.eval()
        model.to(self.device)
        return model

    def _make_classifier(self, model):
        """Wrap *model* in a text-classification pipeline on ``self.device``."""
        return pipeline(
            "text-classification",
            model=model,
            tokenizer=self.tokenizer,
            device=self.device,
        )

    def convert_sequence_to_smiles(self, sequence):
        """Return the SMILES string for an amino-acid *sequence*."""
        return self.adapter.convert_amino_acid_sequence_to_smiles(sequence)

    def convert_smiles_to_selfies(self, smiles):
        """Return the SELFIES encoding of *smiles*."""
        return sf.encoder(smiles)

    # Backward-compatible alias for the original (misspelled) method name.
    conver_smiles_to_selfies = convert_smiles_to_selfies

    def _build_text_inputs(self, sequences):
        """Return one model input string per amino-acid sequence.

        Each input is the per-residue protein encoding followed by the
        peptide's SELFIES string, both wrapped in their begin/end tokens.
        """
        texts = []
        for sequence in sequences:
            selfies = self.convert_smiles_to_selfies(
                self.convert_sequence_to_smiles(sequence)
            )
            residues = "".join(AMINO_ACID_PREFIX + aa for aa in sequence)
            texts.append(
                PROTEIN_BEGIN + residues + PROTEIN_END
                + MOLECULE_BEGIN + selfies + MOLECULE_END
            )
        return texts

    def predict(
        self,
        input_dict,
        model_type="xBitterT5-720",
        batch_size=4,
    ):
        """Predict bitterness for every peptide in *input_dict*.

        Args:
            input_dict: mapping ``{id: amino-acid sequence}``.
            model_type: ``"xBitterT5-640"`` or ``"xBitterT5-720"``.
            batch_size: number of inputs passed to the pipeline per call.

        Returns:
            ``{id: [bitter_probability, label]}`` with ``label`` 1 for bitter
            and 0 otherwise, in the iteration order of *input_dict*. An empty
            *input_dict* yields ``{}``.

        Raises:
            ValueError: if *model_type* is unknown or *batch_size* < 1.
        """
        if model_type not in MODEL_TYPES:
            raise ValueError(
                f"model_type must be one of {MODEL_TYPES}, got {model_type!r}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        if model_type == "xBitterT5-640":
            classifier = self.classifier_640
        else:
            classifier = self.classifier_720

        ids = list(input_dict.keys())
        text_inputs = self._build_text_inputs(input_dict.values())

        result = []
        for start in tqdm(range(0, len(text_inputs), batch_size)):
            result.extend(classifier(text_inputs[start : start + batch_size]))

        output = {}
        for seq_id, pred in zip(ids, result):
            if pred["label"] == "bitter":
                output[seq_id] = [pred["score"], 1]
            else:
                # Assumes a two-label head: P(bitter) = 1 - P(other label).
                output[seq_id] = [1 - pred["score"], 0]
        return output


predictor = xBitterT5_predictor()


def process_fasta(fasta_text):
    """Parse FASTA-formatted text into a ``{id: sequence}`` dict.

    A line starting with ``>`` is a header; the rest of it (stripped) is the
    id. Following non-blank lines are concatenated into that id's sequence.
    Blank lines and lines before the first header are ignored. A repeated id
    keeps its last sequence.
    """
    fasta_dict = {}
    current_id = None
    current_sequence = []
    for raw_line in (fasta_text or "").splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # Flush the previous record; `is not None` keeps an empty id ("")
            # instead of silently dropping its sequence.
            if current_id is not None:
                fasta_dict[current_id] = "".join(current_sequence)
            current_id = line[1:].strip()
            current_sequence = []
        elif current_id is not None:
            current_sequence.append(line)
    if current_id is not None:
        fasta_dict[current_id] = "".join(current_sequence)
    return fasta_dict


def predict(choice, fasta_text):
    """Gradio callback: classify the FASTA records in *fasta_text*.

    Args:
        choice: model variant selected in the dropdown.
        fasta_text: peptide sequences in FASTA format.

    Returns:
        A DataFrame with columns ``RESULT_COLUMNS`` (ID, Class, Probability of
        the bitter class), one row per record; empty when no record is given.
    """
    fasta_dict = process_fasta(fasta_text)
    result = predictor.predict(fasta_dict, model_type=choice)
    rows = [
        [seq_id, "bitter" if label == 1 else "non-bitter", probability]
        for seq_id, (probability, label) in result.items()
    ]
    return pd.DataFrame(rows, columns=RESULT_COLUMNS)


interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=list(MODEL_TYPES),
            label="Select xBitterT5 variant",
            value="xBitterT5-720",
        ),
        gr.Textbox(
            label="Enter peptide sequences in FASTA format",
            lines=10,
            # Every record needs a '>' header; previously 'id4' lacked one and
            # was merged into id3's sequence.
            placeholder=">id1\nVAPFPE\n>id2\nRRPP\n>id3\nGH\n>id4\nGVDTK",
        ),
    ],
    outputs=gr.Dataframe(
        headers=RESULT_COLUMNS,
    ),
    title="xBitterT5",
    description="Prediction of bitter peptides using xBitterT5.",
    flagging_mode="never",
)

# Launch the Gradio app (module-level, as before, so `python app.py` serves it).
interface.launch()