from src.modeling_t5 import T5ForSequenceClassification
import selfies as sf
import pandas as pd
from transformers import AutoTokenizer, pipeline
from chemistry_adapters.amino_acids import AminoAcidAdapter
from tqdm import tqdm
import gradio as gr
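

# xBitterT5_predictor wraps the xBitterT5-640 and xBitterT5-720 checkpoints behind a
# single interface: each peptide sequence is converted to SMILES and then SELFIES, and
# the residue view plus the molecular view are concatenated as input to the T5 classifier.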
class xBitterT5_predictor:
    def __init__(
        self,
        xBitterT5_640_ckpt="cbbl-skku-org/xBitterT5-640",
        xBitterT5_720_ckpt="cbbl-skku-org/xBitterT5-720",
        device="cpu",
    ):
        self.xBitterT5_640_ckpt = xBitterT5_640_ckpt
        self.xBitterT5_720_ckpt = xBitterT5_720_ckpt
        self.device = device
        # A single tokenizer (loaded from the 640 checkpoint) is reused for both pipelines.
        self.tokenizer = AutoTokenizer.from_pretrained(xBitterT5_640_ckpt)
        self.xBitterT5_640 = self.load_model(xBitterT5_640_ckpt)
        self.xBitterT5_720 = self.load_model(xBitterT5_720_ckpt)
        self.classifier_640 = pipeline(
            "text-classification",
            model=self.xBitterT5_640,
            tokenizer=self.tokenizer,
            device=self.device,
        )
        self.classifier_720 = pipeline(
            "text-classification",
            model=self.xBitterT5_720,
            tokenizer=self.tokenizer,
            device=self.device,
        )

    def load_model(self, ckpt):
        model = T5ForSequenceClassification.from_pretrained(ckpt)
        model.eval()
        model.to(self.device)
        return model

    def convert_sequence_to_smiles(self, sequence):
        adapter = AminoAcidAdapter()
        return adapter.convert_amino_acid_sequence_to_smiles(sequence)

    def convert_smiles_to_selfies(self, smiles):
        return sf.encoder(smiles)

    def predict(
        self,
        input_dict,
        model_type="xBitterT5-720",
        batch_size=4,
    ):
        assert model_type in ["xBitterT5-640", "xBitterT5-720"]
        df = pd.DataFrame(
            {"id": list(input_dict.keys()), "sequence": list(input_dict.values())}
        )
        # Build the molecular view: amino-acid sequence -> SMILES -> SELFIES.
        df["smiles"] = df.apply(
            lambda row: self.convert_sequence_to_smiles(row["sequence"]),
            axis=1,
        )
        df["selfies"] = df.apply(
            lambda row: self.convert_smiles_to_selfies(row["smiles"]),
            axis=1,
        )
        # Wrap the residue view in <bop>/<eop>, prefixing each residue with <p>.
        df["sequence"] = df.apply(
            lambda row: "<bop>"
            + "".join("<p>" + aa for aa in row["sequence"])
            + "<eop>",
            axis=1,
        )
        # Wrap the molecular view in <bom>/<eom> and concatenate the two views.
        df["selfies"] = df.apply(lambda row: "<bom>" + row["selfies"] + "<eom>", axis=1)
        df["text"] = df["sequence"] + df["selfies"]
        text_inputs = df["text"].tolist()
        if model_type == "xBitterT5-640":
            classifier = self.classifier_640
        else:
            classifier = self.classifier_720
        # Run the selected classifier in mini-batches.
        result = []
        for i in tqdm(range(0, len(text_inputs), batch_size)):
            batch = text_inputs[i : i + batch_size]
            result.extend(classifier(batch))
        # Convert pipeline outputs into the probability of "bitter" and a 0/1 class.
        y_pred, y_prob = [], []
        for pred in result:
            if pred["label"] == "bitter":
                y_prob.append(pred["score"])
                y_pred.append(1)
            else:
                y_prob.append(1 - pred["score"])
                y_pred.append(0)
        return {
            seq_id: [y_prob[idx], y_pred[idx]]
            for idx, seq_id in enumerate(df["id"].tolist())
        }


# Instantiate the predictor once at start-up so both models are loaded a single time.
predictor = xBitterT5_predictor()
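
# Hypothetical usage sketch (ids and sequences are placeholders, not real data):
#   predictor.predict({"pep1": "GLFD"})
#   -> {"pep1": [probability_of_bitter, predicted_class]}  # class 1 = bitter, 0 = non-bitter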


def process_fasta(fasta_text):
    """
    Processes the input FASTA format text into a dictionary {id: sequence}.
    """
    fasta_dict = {}
    current_id = None
    current_sequence = []
    for line in fasta_text.strip().split("\n"):
        line = line.strip()
        if line.startswith(">"):  # Header line
            if current_id:
                fasta_dict[current_id] = "".join(current_sequence)
            current_id = line[1:]  # Remove '>'
            current_sequence = []
        else:
            current_sequence.append(line)
    # Add the last sequence
    if current_id:
        fasta_dict[current_id] = "".join(current_sequence)
    return fasta_dict
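
# For example, the text ">pep1\nGAFLK\n>pep2\nPKL" parses to
# {"pep1": "GAFLK", "pep2": "PKL"} (the sequences here are illustrative placeholders).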


def gradio_process_fasta(fasta_text):
    """
    Wrapper for Gradio: parses the FASTA text and runs bitterness prediction.
    """
    fasta_dict = process_fasta(fasta_text)
    # Uses the default model_type ("xBitterT5-720") defined in predict().
    result = predictor.predict(fasta_dict)
    return result


# Create the Gradio interface
interface = gr.Interface(
    fn=gradio_process_fasta,
    inputs=gr.Textbox(
        label="Enter FASTA format text",
        lines=10,
        placeholder=">id1\nATGC\n>id2\nCGTA",
    ),
    outputs=gr.JSON(label="Processed FASTA Dictionary with Probabilities and Classes"),
    title="FASTA to Dictionary with Probabilities and Classes",
    description="Enter peptide sequences in FASTA format.",
)

# Launch the Gradio app
interface.launch()