ndhieunguyen committed on
Commit
6a53dd4
·
1 Parent(s): 92d2d45

feat: first commit

Browse files
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.modeling_t5 import T5ForSequenceClassification
2
+ import selfies as sf
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer, pipeline
5
+ from chemistry_adapters.amino_acids import AminoAcidAdapter
6
+ from tqdm import tqdm
7
+ import gradio as gr
8
+
9
+
10
class xBitterT5_predictor:
    """Bitter-peptide classifier built on the xBitterT5-640 and xBitterT5-720 checkpoints.

    Both checkpoints are loaded in the constructor and wrapped in
    ``text-classification`` pipelines that share the tokenizer loaded from the
    640 checkpoint. ``predict`` converts each amino-acid sequence into the
    combined ``<bop>...<eop><bom>...<eom>`` text the pipelines are fed.
    """

    # Values accepted by ``predict`` for ``model_type``.
    MODEL_TYPES = ("xBitterT5-640", "xBitterT5-720")

    def __init__(
        self,
        xBitterT5_640_ckpt="cbbl-skku-org/xBitterT5-640",
        xBitterT5_720_ckpt="cbbl-skku-org/xBitterT5-720",
        device="cpu",
    ):
        """Load the tokenizer, both models and their classification pipelines.

        Args:
            xBitterT5_640_ckpt: Checkpoint name or path of the 640 model; the
                tokenizer is also loaded from it.
            xBitterT5_720_ckpt: Checkpoint name or path of the 720 model.
            device: Device the models are moved to and the pipelines run on.
        """
        self.xBitterT5_640_ckpt = xBitterT5_640_ckpt
        self.xBitterT5_720_ckpt = xBitterT5_720_ckpt
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(xBitterT5_640_ckpt)
        self.xBitterT5_640 = self.load_model(xBitterT5_640_ckpt)
        self.xBitterT5_720 = self.load_model(xBitterT5_720_ckpt)

        self.classifier_640 = pipeline(
            "text-classification",
            model=self.xBitterT5_640,
            tokenizer=self.tokenizer,
            device=self.device,
        )
        self.classifier_720 = pipeline(
            "text-classification",
            model=self.xBitterT5_720,
            tokenizer=self.tokenizer,
            device=self.device,
        )

    def load_model(self, ckpt):
        """Load ``ckpt``, switch it to eval mode and move it to ``self.device``."""
        model = T5ForSequenceClassification.from_pretrained(ckpt)
        model.eval()
        model.to(self.device)
        return model

    def convert_sequence_to_smiles(self, sequence):
        """Convert an amino-acid sequence to SMILES via ``AminoAcidAdapter``."""
        adapter = AminoAcidAdapter()
        return adapter.convert_amino_acid_sequence_to_smiles(sequence)

    def conver_smiles_to_selfies(self, smiles):
        """Encode a SMILES string as SELFIES.

        The misspelled name is kept because it is part of the public interface.
        """
        return sf.encoder(smiles)

    def _build_model_input(self, sequence):
        """Return the tagged sequence followed by its tagged SELFIES encoding."""
        smiles = self.convert_sequence_to_smiles(sequence)
        selfies = self.conver_smiles_to_selfies(smiles)
        tagged_sequence = "<bop>" + "".join("<p>" + aa for aa in sequence) + "<eop>"
        return tagged_sequence + "<bom>" + selfies + "<eom>"

    def predict(
        self,
        input_dict,
        model_type="xBitterT5-720",
        batch_size=4,
    ):
        """Classify every sequence in ``input_dict``.

        Args:
            input_dict: Mapping of identifier to amino-acid sequence.
            model_type: ``"xBitterT5-640"`` or ``"xBitterT5-720"``.
            batch_size: Number of inputs passed to the pipeline per call.

        Returns:
            Dict mapping each identifier to ``[probability, label]`` where
            ``label`` is 1 when the pipeline label is ``"bitter"`` and 0
            otherwise, and ``probability`` is the score of the ``"bitter"``
            class (``1 - score`` when the other label was predicted).

        Raises:
            ValueError: If ``model_type`` is not one of ``MODEL_TYPES``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if model_type not in self.MODEL_TYPES:
            raise ValueError(
                f"model_type must be one of {self.MODEL_TYPES}, got {model_type!r}"
            )
        # Nothing to classify (e.g. an empty FASTA submission): skip all the work.
        if not input_dict:
            return {}

        ids = list(input_dict.keys())
        text_inputs = [
            self._build_model_input(sequence) for sequence in input_dict.values()
        ]

        if model_type == "xBitterT5-640":
            classifier = self.classifier_640
        else:
            classifier = self.classifier_720

        result = []
        for start in tqdm(range(0, len(text_inputs), batch_size)):
            result.extend(classifier(text_inputs[start : start + batch_size]))

        predictions = {}
        for seq_id, pred in zip(ids, result):
            if pred["label"] == "bitter":
                predictions[seq_id] = [pred["score"], 1]
            else:
                # The score belongs to the predicted (other) label, so flip it.
                predictions[seq_id] = [1 - pred["score"], 0]
        return predictions
102
+
103
+
104
# Module-level predictor used by the Gradio callback; constructing it loads both checkpoints at import time.
+ predictor = xBitterT5_predictor()
105
+
106
+
107
def process_fasta(fasta_text):
    """
    Parse FASTA-formatted text into a dictionary {id: sequence}.

    A line starting with ">" opens a new record; its ID is the rest of the
    line with surrounding whitespace removed. The following lines, up to the
    next header, are concatenated into that record's sequence. Blank lines
    and any lines before the first header are ignored. When an ID occurs more
    than once, the last record wins.

    Args:
        fasta_text: The FASTA text to parse.

    Returns:
        Dict mapping each record ID to its sequence string.
    """
    fasta_dict = {}
    current_id = None
    current_sequence = []

    for raw_line in fasta_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):  # Header line
            # Compare against None so a record with an empty ID is not lost.
            if current_id is not None:
                fasta_dict[current_id] = "".join(current_sequence)
            current_id = line[1:].strip()  # Remove '>' and padding
            current_sequence = []
        elif current_id is not None:
            current_sequence.append(line)

    # Add the last sequence
    if current_id is not None:
        fasta_dict[current_id] = "".join(current_sequence)

    return fasta_dict
130
+
131
+
132
# Callback wired into the Gradio interface
def gradio_process_fasta(fasta_text):
    """
    Parse the submitted FASTA text and return the predictor's output for it.
    """
    return predictor.predict(process_fasta(fasta_text))
140
+
141
+
142
# Input and output components of the Gradio UI
fasta_textbox = gr.Textbox(
    label="Enter FASTA format text",
    lines=10,
    placeholder=">id1\nATGC\n>id2\nCGTA",
)
prediction_view = gr.JSON(
    label="Processed FASTA Dictionary with Probabilities and Classes"
)

interface = gr.Interface(
    fn=gradio_process_fasta,
    inputs=fasta_textbox,
    outputs=prediction_view,
    title="FASTA to Dictionary with Probabilities and Classes",
    description="Enter a FASTA-formatted text",
)
# Launch the Gradio app
interface.launch()
inference.sh ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.4
2
+ aiohttp==3.11.11
3
+ aiosignal==1.3.2
4
+ async-timeout==5.0.1
5
+ attrs==24.3.0
6
+ backcall==0.2.0
7
+ captum==0.7.0
8
+ certifi==2024.12.14
9
+ charset-normalizer==3.4.1
10
+ chemistry_adapters==0.0.2
11
+ contourpy==1.3.0
12
+ cycler==0.12.1
13
+ datasets==3.2.0
14
+ decorator==5.1.1
15
+ dill==0.3.8
16
+ evaluate==0.4.3
17
+ filelock==3.16.1
18
+ fonttools==4.55.3
19
+ frozenlist==1.5.0
20
+ fsspec==2024.9.0
21
+ huggingface-hub==0.27.1
22
+ idna==3.10
23
+ importlib_resources==6.5.2
24
+ ipython==7.34.0
25
+ jedi==0.19.2
26
+ Jinja2==3.1.5
27
+ joblib==1.4.2
28
+ kiwisolver==1.4.7
29
+ MarkupSafe==3.0.2
30
+ matplotlib==3.9.4
31
+ matplotlib-inline==0.1.7
32
+ mpmath==1.3.0
33
+ multidict==6.1.0
34
+ multiprocess==0.70.16
35
+ networkx==3.2.1
36
+ numpy==2.0.2
37
+ nvidia-cublas-cu11==11.11.3.6
38
+ nvidia-cublas-cu12==12.1.3.1
39
+ nvidia-cuda-cupti-cu11==11.8.87
40
+ nvidia-cuda-cupti-cu12==12.1.105
41
+ nvidia-cuda-nvrtc-cu11==11.8.89
42
+ nvidia-cuda-nvrtc-cu12==12.1.105
43
+ nvidia-cuda-runtime-cu11==11.8.89
44
+ nvidia-cuda-runtime-cu12==12.1.105
45
+ nvidia-cudnn-cu11==9.1.0.70
46
+ nvidia-cudnn-cu12==9.1.0.70
47
+ nvidia-cufft-cu11==10.9.0.58
48
+ nvidia-cufft-cu12==11.0.2.54
49
+ nvidia-curand-cu11==10.3.0.86
50
+ nvidia-curand-cu12==10.3.2.106
51
+ nvidia-cusolver-cu11==11.4.1.48
52
+ nvidia-cusolver-cu12==11.4.5.107
53
+ nvidia-cusparse-cu11==11.7.5.86
54
+ nvidia-cusparse-cu12==12.1.0.106
55
+ nvidia-nccl-cu11==2.21.5
56
+ nvidia-nccl-cu12==2.21.5
57
+ nvidia-nvjitlink-cu12==12.4.127
58
+ nvidia-nvtx-cu11==11.8.86
59
+ nvidia-nvtx-cu12==12.1.105
60
+ packaging==24.2
61
+ pandas==2.2.3
62
+ parso==0.8.4
63
+ pexpect==4.9.0
64
+ pickleshare==0.7.5
65
+ pillow==11.1.0
66
+ prompt_toolkit==3.0.50
67
+ propcache==0.2.1
68
+ ptyprocess==0.7.0
69
+ pyarrow==19.0.0
70
+ Pygments==2.19.1
71
+ pyparsing==3.2.1
72
+ python-dateutil==2.9.0.post0
73
+ pytz==2024.2
74
+ PyYAML==6.0.2
75
+ regex==2024.11.6
76
+ requests==2.32.3
77
+ safetensors==0.5.2
78
+ scikit-learn==1.6.1
79
+ scipy==1.13.1
80
+ selfies==2.2.0
81
+ six==1.17.0
82
+ sympy==1.13.1
83
+ threadpoolctl==3.5.0
84
+ tokenizers==0.21.0
85
+ torch==2.5.1+cu121
86
+ tqdm==4.67.1
87
+ traitlets==5.14.3
88
+ transformers==4.48.1
89
+ transformers-interpret==0.10.0
90
+ triton==3.1.0
91
+ typing_extensions==4.12.2
92
+ tzdata==2024.2
93
+ urllib3==2.3.0
94
+ wcwidth==0.2.13
95
+ xxhash==3.5.0
96
+ yarl==1.18.3
97
+ zipp==3.21.0
src/__pycache__/modeling_t5.cpython-39.pyc ADDED
Binary file (71.4 kB). View file
 
src/data.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset, DatasetDict
2
+ import pandas as pd
3
+ import glob
4
+ import os
5
+ import numpy as np
6
+
7
+
8
def _tag_selfies(selfies):
    """Wrap a SELFIES string in the ``<bom>``/``<eom>`` markers."""
    return "<bom>" + selfies + "<eom>"


def _tag_sequence(sequence):
    """Prefix every residue with ``<p>`` and wrap the result in ``<bop>``/``<eop>``."""
    return "<bop>" + "".join("<p>" + aa for aa in sequence) + "<eop>"


def create_dataset_from_dataframe(
    dataframe_path, pretrained_name, chosen_features=None
):
    """Build a ``Dataset`` with ``label`` and ``text`` columns from a CSV file.

    Rows with a missing value in any chosen feature are dropped. With several
    features, ``selfies`` and ``sequence`` columns are tagged and all chosen
    columns are concatenated, in order, into ``text``. With a single feature,
    ``text`` is the tagged SELFIES, the raw SMILES, or the sequence (tagged
    when ``"biot5"`` occurs in ``pretrained_name``, space-separated residues
    otherwise); any other single feature name produces no ``text`` column.
    The chosen feature columns themselves are removed from the result.

    Args:
        dataframe_path: Path of the CSV file; it must contain a ``label``
            column and every chosen feature column.
        pretrained_name: Name of the pretrained model, used only to pick the
            sequence formatting in the single-feature case.
        chosen_features: Column names to build ``text`` from. Required
            despite the ``None`` default.

    Returns:
        The ``Dataset`` created from the processed dataframe.

    Raises:
        TypeError: If ``chosen_features`` is ``None``.
    """
    if chosen_features is None:
        raise TypeError("chosen_features must be a list of column names, got None")
    chosen_features = list(chosen_features)

    dataframe = pd.read_csv(dataframe_path, usecols=["label", *chosen_features])
    rows_with_nan = dataframe[chosen_features].isna().any(axis=1)
    # Copy so the column assignments below write to an independent frame
    # rather than to a filtered view of the one that was read.
    dataframe = dataframe[~rows_with_nan].copy()

    if len(chosen_features) > 1:
        for feature in chosen_features:
            if feature == "selfies":
                dataframe[feature] = dataframe[feature].map(_tag_selfies)
            elif feature == "sequence":
                dataframe[feature] = dataframe[feature].map(_tag_sequence)

        columns = [dataframe[feature].map(str) for feature in chosen_features]
        dataframe["text"] = ["".join(values) for values in zip(*columns)]

    elif len(chosen_features) == 1:
        chosen_feature = chosen_features[0]
        values = dataframe[chosen_feature]
        if chosen_feature == "selfies":
            dataframe["text"] = values.map(_tag_selfies)
        elif chosen_feature == "smiles":
            dataframe["text"] = values
        elif chosen_feature == "sequence":
            if "biot5" in pretrained_name:
                dataframe["text"] = values.map(_tag_sequence)
            else:
                dataframe["text"] = values.map(" ".join)

    dataframe = dataframe.drop(columns=chosen_features)
    return Dataset.from_pandas(dataframe)
57
+
58
+
59
def create_and_save_datadict(train, val, test, save_path):
    """Bundle the splits into a ``DatasetDict``, save it to disk and return it.

    The ``val`` split is included only when it is not ``None``.
    """
    splits = {"train": train}
    if val is not None:
        splits["val"] = val
    splits["test"] = test

    bundle = DatasetDict(splits)
    bundle.save_to_disk(save_path)
    return bundle
67
+
68
+
69
def prepare_dataset(args):
    """Build and save the datasets for every fold and for the full train/test split.

    For each ``fold_*`` directory under ``args.data_folder`` the train, val and
    test CSV files are converted and saved inside that fold directory. The
    top-level ``train.csv`` and ``test.csv`` are then converted and saved
    directly under ``args.data_folder`` without a validation split.
    """
    features = args.chosen_features
    pretrained_name = args.pretrained_name

    def load_split(folder, split):
        # One CSV file -> one Dataset, using the shared feature configuration.
        csv_path = os.path.join(folder, f"{split}.csv")
        return create_dataset_from_dataframe(csv_path, pretrained_name, features)

    def output_dirname():
        return f"dataset_{'_'.join(features)}_{pretrained_name.split('/')[-1].replace('-', '_')}"

    for fold_folder in glob.glob(args.data_folder + "/fold_*/"):
        fold_train = load_split(fold_folder, "train")
        fold_val = load_split(fold_folder, "val")
        fold_test = load_split(fold_folder, "test")
        fold_target = os.path.join(fold_folder, output_dirname())
        create_and_save_datadict(fold_train, fold_val, fold_test, fold_target)

    full_train = load_split(args.data_folder, "train")
    full_test = load_split(args.data_folder, "test")
    full_target = os.path.join(args.data_folder, output_dirname())
    create_and_save_datadict(full_train, None, full_test, full_target)
src/explainer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers_interpret import SequenceClassificationExplainer
2
+ from typing import List, Tuple, Union
3
+ import torch
4
+
5
+
6
class xBitterT5_explainer(SequenceClassificationExplainer):
    """``SequenceClassificationExplainer`` variant whose reference ids carry no CLS token."""

    def _make_input_reference_pair(
        self, text: Union[List, str]
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Build the input ids and the matching reference ids for ``text``.

        Args:
            text: The text to explain. Lists are not supported.

        Returns:
            A tuple ``(input_ids, ref_input_ids, length)``: both id lists as
            tensors with a leading batch dimension of one on ``self.device``,
            and the number of ids produced by ``self.encode(text)``.

        Raises:
            NotImplementedError: If ``text`` is a list.
        """
        if isinstance(text, list):
            raise NotImplementedError("Lists of text are not currently supported.")

        text_ids = self.encode(text)
        input_ids = self.tokenizer.encode(text, add_special_tokens=True)

        # Use this because pretrained BioT5 plus does not have cls_token_id:
        # the reference is one ref token per text id followed by the sep token
        # only. (The earlier CLS-based construction was always overwritten by
        # this assignment, so it has been removed.)
        ref_input_ids = [self.ref_token_id] * len(text_ids) + [self.sep_token_id]
        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(text_ids),
        )
src/model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
2
+ from transformers import AutoTokenizer, T5Tokenizer, AutoConfig
3
+ from src.modeling_t5 import T5ForSequenceClassification
4
+
5
+
6
def prepare_tokenizer(args):
    """Load the tokenizer for ``args.pretrained_name``, falling back step by step.

    ``AutoTokenizer`` is tried first, then ``T5Tokenizer`` with
    ``do_lower_case=False``, and finally ``T5Tokenizer`` with default
    arguments. The error of each failed attempt is printed; a failure of the
    last attempt propagates to the caller.
    """
    name = args.pretrained_name

    try:
        return AutoTokenizer.from_pretrained(name)
    except Exception as auto_error:
        print(f"Error: {auto_error}")

    try:
        return T5Tokenizer.from_pretrained(
            name,
            do_lower_case=False,
        )
    except Exception as t5_error:
        print(f"Error: {t5_error}")

    return T5Tokenizer.from_pretrained(name)
19
+
20
+
21
def check_unfreeze_layer(name, trainable_layers):
    """Return whether parameter ``name`` belongs to one of the trainable decoder blocks.

    A parameter matches layer ``L`` when its name is exactly
    ``transformer.decoder.block.L`` or starts with that prefix followed by a
    dot. Requiring the dot keeps layer ``1`` from also matching blocks
    ``10``, ``11`` and so on.

    Args:
        name: Parameter name as yielded by ``named_parameters()``.
        trainable_layers: Iterable of decoder block identifiers; ``None`` is
            treated as empty.

    Returns:
        ``True`` if ``name`` lies in any of the given blocks, else ``False``.
    """
    for layer in trainable_layers or ():
        prefix = f"transformer.decoder.block.{layer}"
        if name == prefix or name.startswith(prefix + "."):
            return True
    return False
28
+
29
+
30
def prepare_model(args):
    """Load the T5 sequence classifier and freeze everything except the chosen parts.

    The config is loaded with two labels (``non-bitter``/``bitter``) and the
    dropout values taken from ``args.dropout``. After loading, only parameters
    of the classification head and of the decoder blocks selected by
    ``args.trainable_layers`` keep ``requires_grad=True``.
    """
    id_to_label = {0: "non-bitter", 1: "bitter"}
    label_to_id = {"non-bitter": 0, "bitter": 1}

    config = AutoConfig.from_pretrained(
        args.pretrained_name,
        cache_dir=args.cache_dir,
        num_labels=2,
        id2label=id_to_label,
        label2id=label_to_id,
    )
    config.dropout_rate = args.dropout
    config.classifier_dropout = args.dropout
    config.problem_type = "single_label_classification"

    model = T5ForSequenceClassification.from_pretrained(
        args.pretrained_name,
        cache_dir=args.cache_dir,
        config=config,
    )
    model.to(args.accelerator)

    for param_name, param in model.named_parameters():
        is_head = param_name.startswith("classification_head")
        param.requires_grad = is_head or check_unfreeze_layer(
            param_name, args.trainable_layers
        )

    return model
src/modeling_t5.py ADDED
The diff for this file is too large to render. See raw diff
 
src/old_modeling_t5.py ADDED
The diff for this file is too large to render. See raw diff
 
src/utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ from datetime import datetime
4
+ from zoneinfo import ZoneInfo
5
+ from torch.nn.functional import softmax
6
+ from torch import tensor
7
+ from sklearn.metrics import confusion_matrix, roc_curve, auc
8
+
9
+
10
# Names of the metrics combined into ``bitter_metrics``.
_BITTER_METRIC_NAMES = [
    "accuracy",
    "f1",
    "precision",
    "recall",
    "matthews_correlation",
]
bitter_metrics = evaluate.combine(_BITTER_METRIC_NAMES)
13
+
14
+
15
def compute_metrics(eval_pred):
    """Compute the classification metrics for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions[0]`` is used
            as the per-class score array (presumably the logits; confirm
            against the model output), with class 1 as the positive class.

    Returns:
        Dict, sorted by key, with the combined ``bitter_metrics`` results plus
        ``eval_specificity``, ``eval_tn``, ``eval_fp``, ``eval_fn``,
        ``eval_tp`` and ``eval_auc``.
    """
    predictions, labels = eval_pred
    logits = predictions[0]
    preds = np.argmax(logits, axis=1)
    prediction_scores = softmax(tensor(logits), dim=-1)
    prediction_scores = prediction_scores[:, 1].cpu().numpy()

    metrics = bitter_metrics.compute(predictions=preds, references=labels)

    # Fix the label set so the matrix is always 2x2; otherwise a batch that
    # contains a single class yields a 1x1 matrix and the unpacking fails.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    negatives = tn + fp
    # No negative samples: specificity is undefined, report NaN explicitly.
    specificity = tn / negatives if negatives else float("nan")
    metrics.update(
        {
            "eval_specificity": specificity,
            "eval_tn": tn,
            "eval_fp": fp,
            "eval_fn": fn,
            "eval_tp": tp,
        }
    )

    fpr2, tpr2, _ = roc_curve(labels, prediction_scores, pos_label=1)
    auc2 = auc(fpr2, tpr2)
    metrics.update({"eval_auc": auc2})

    metrics = dict(sorted(metrics.items()))
    return metrics
40
+
41
+
42
def get_time_string():
    """Return the current Asia/Seoul time formatted as ``YYYY_MM_DD__HH_MM_SS``."""
    seoul = ZoneInfo("Asia/Seoul")
    now = datetime.now(tz=seoul)
    return now.strftime("%Y_%m_%d__%H_%M_%S")