Spaces:
Running
Running
ndhieunguyen
commited on
Commit
·
6a53dd4
1
Parent(s):
92d2d45
feat: first commit
Browse files- app.py +152 -0
- inference.sh +0 -0
- requirements.txt +97 -0
- src/__pycache__/modeling_t5.cpython-39.pyc +0 -0
- src/data.py +101 -0
- src/explainer.py +32 -0
- src/model.py +58 -0
- src/modeling_t5.py +0 -0
- src/old_modeling_t5.py +0 -0
- src/utils.py +43 -0
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.modeling_t5 import T5ForSequenceClassification
|
2 |
+
import selfies as sf
|
3 |
+
import pandas as pd
|
4 |
+
from transformers import AutoTokenizer, pipeline
|
5 |
+
from chemistry_adapters.amino_acids import AminoAcidAdapter
|
6 |
+
from tqdm import tqdm
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
|
10 |
+
class xBitterT5_predictor:
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
xBitterT5_640_ckpt="cbbl-skku-org/xBitterT5-640",
|
14 |
+
xBitterT5_720_ckpt="cbbl-skku-org/xBitterT5-720",
|
15 |
+
device="cpu",
|
16 |
+
):
|
17 |
+
self.xBitterT5_640_ckpt = xBitterT5_640_ckpt
|
18 |
+
self.xBitterT5_720_ckpt = xBitterT5_720_ckpt
|
19 |
+
self.device = device
|
20 |
+
|
21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(xBitterT5_640_ckpt)
|
22 |
+
self.xBitterT5_640 = self.load_model(xBitterT5_640_ckpt)
|
23 |
+
self.xBitterT5_720 = self.load_model(xBitterT5_720_ckpt)
|
24 |
+
|
25 |
+
self.classifier_640 = pipeline(
|
26 |
+
"text-classification",
|
27 |
+
model=self.xBitterT5_640,
|
28 |
+
tokenizer=self.tokenizer,
|
29 |
+
device=self.device,
|
30 |
+
)
|
31 |
+
self.classifier_720 = pipeline(
|
32 |
+
"text-classification",
|
33 |
+
model=self.xBitterT5_720,
|
34 |
+
tokenizer=self.tokenizer,
|
35 |
+
device=self.device,
|
36 |
+
)
|
37 |
+
|
38 |
+
def load_model(self, ckpt):
|
39 |
+
model = T5ForSequenceClassification.from_pretrained(ckpt)
|
40 |
+
model.eval()
|
41 |
+
model.to(self.device)
|
42 |
+
return model
|
43 |
+
|
44 |
+
def convert_sequence_to_smiles(self, sequence):
|
45 |
+
adapter = AminoAcidAdapter()
|
46 |
+
return adapter.convert_amino_acid_sequence_to_smiles(sequence)
|
47 |
+
|
48 |
+
def conver_smiles_to_selfies(self, smiles):
|
49 |
+
return sf.encoder(smiles)
|
50 |
+
|
51 |
+
def predict(
|
52 |
+
self,
|
53 |
+
input_dict,
|
54 |
+
model_type="xBitterT5-720",
|
55 |
+
batch_size=4,
|
56 |
+
):
|
57 |
+
assert model_type in ["xBitterT5-640", "xBitterT5-720"]
|
58 |
+
df = pd.DataFrame(
|
59 |
+
{"id": list(input_dict.keys()), "sequence": list(input_dict.values())}
|
60 |
+
)
|
61 |
+
|
62 |
+
df["smiles"] = df.apply(
|
63 |
+
lambda row: self.convert_sequence_to_smiles(row["sequence"]),
|
64 |
+
axis=1,
|
65 |
+
)
|
66 |
+
df["selfies"] = df.apply(
|
67 |
+
lambda row: self.conver_smiles_to_selfies(row["smiles"]),
|
68 |
+
axis=1,
|
69 |
+
)
|
70 |
+
|
71 |
+
df["sequence"] = df.apply(
|
72 |
+
lambda row: "<bop>"
|
73 |
+
+ "".join("<p>" + aa for aa in row["sequence"])
|
74 |
+
+ "<eop>",
|
75 |
+
axis=1,
|
76 |
+
)
|
77 |
+
df["selfies"] = df.apply(lambda row: "<bom>" + row["selfies"] + "<eom>", axis=1)
|
78 |
+
df["text"] = df["sequence"] + df["selfies"]
|
79 |
+
|
80 |
+
text_inputs = df["text"].tolist()
|
81 |
+
|
82 |
+
if model_type == "xBitterT5-640":
|
83 |
+
classifier = self.classifier_640
|
84 |
+
else:
|
85 |
+
classifier = self.classifier_720
|
86 |
+
|
87 |
+
result = []
|
88 |
+
for i in tqdm(range(0, len(text_inputs), batch_size)):
|
89 |
+
batch = text_inputs[i : i + batch_size]
|
90 |
+
result.extend(classifier(batch))
|
91 |
+
|
92 |
+
y_pred, y_prob = [], []
|
93 |
+
for pred in result:
|
94 |
+
if pred["label"] == "bitter":
|
95 |
+
y_prob.append(pred["score"])
|
96 |
+
y_pred.append(1)
|
97 |
+
else:
|
98 |
+
y_prob.append(1 - pred["score"])
|
99 |
+
y_pred.append(0)
|
100 |
+
|
101 |
+
return {i: [y_prob[j], y_pred[j]] for j, i in enumerate(df["id"].tolist())}
|
102 |
+
|
103 |
+
|
104 |
+
predictor = xBitterT5_predictor()
|
105 |
+
|
106 |
+
|
107 |
+
def process_fasta(fasta_text):
|
108 |
+
"""
|
109 |
+
Processes the input FASTA format text into a dictionary {id: sequence}.
|
110 |
+
"""
|
111 |
+
fasta_dict = {}
|
112 |
+
current_id = None
|
113 |
+
current_sequence = []
|
114 |
+
|
115 |
+
for line in fasta_text.strip().split("\n"):
|
116 |
+
line = line.strip()
|
117 |
+
if line.startswith(">"): # Header line
|
118 |
+
if current_id:
|
119 |
+
fasta_dict[current_id] = "".join(current_sequence)
|
120 |
+
current_id = line[1:] # Remove '>'
|
121 |
+
current_sequence = []
|
122 |
+
else:
|
123 |
+
current_sequence.append(line)
|
124 |
+
|
125 |
+
# Add the last sequence
|
126 |
+
if current_id:
|
127 |
+
fasta_dict[current_id] = "".join(current_sequence)
|
128 |
+
|
129 |
+
return fasta_dict
|
130 |
+
|
131 |
+
|
132 |
+
# Create a Gradio interface
|
133 |
+
def gradio_process_fasta(fasta_text):
|
134 |
+
"""
|
135 |
+
Wrapper for Gradio to process the FASTA text.
|
136 |
+
"""
|
137 |
+
fasta_dict = process_fasta(fasta_text)
|
138 |
+
result = predictor.predict(fasta_dict)
|
139 |
+
return result
|
140 |
+
|
141 |
+
|
142 |
+
interface = gr.Interface(
|
143 |
+
fn=gradio_process_fasta,
|
144 |
+
inputs=gr.Textbox(
|
145 |
+
label="Enter FASTA format text", lines=10, placeholder=">id1\nATGC\n>id2\nCGTA"
|
146 |
+
),
|
147 |
+
outputs=gr.JSON(label="Processed FASTA Dictionary with Probabilities and Classes"),
|
148 |
+
title="FASTA to Dictionary with Probabilities and Classes",
|
149 |
+
description=("Enter a FASTA-formatted text"),
|
150 |
+
)
|
151 |
+
# Launch the Gradio app
|
152 |
+
interface.launch()
|
inference.sh
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohappyeyeballs==2.4.4
|
2 |
+
aiohttp==3.11.11
|
3 |
+
aiosignal==1.3.2
|
4 |
+
async-timeout==5.0.1
|
5 |
+
attrs==24.3.0
|
6 |
+
backcall==0.2.0
|
7 |
+
captum==0.7.0
|
8 |
+
certifi==2024.12.14
|
9 |
+
charset-normalizer==3.4.1
|
10 |
+
chemistry_adapters==0.0.2
|
11 |
+
contourpy==1.3.0
|
12 |
+
cycler==0.12.1
|
13 |
+
datasets==3.2.0
|
14 |
+
decorator==5.1.1
|
15 |
+
dill==0.3.8
|
16 |
+
evaluate==0.4.3
|
17 |
+
filelock==3.16.1
|
18 |
+
fonttools==4.55.3
|
19 |
+
frozenlist==1.5.0
|
20 |
+
fsspec==2024.9.0
|
21 |
+
huggingface-hub==0.27.1
|
22 |
+
idna==3.10
|
23 |
+
importlib_resources==6.5.2
|
24 |
+
ipython==7.34.0
|
25 |
+
jedi==0.19.2
|
26 |
+
Jinja2==3.1.5
|
27 |
+
joblib==1.4.2
|
28 |
+
kiwisolver==1.4.7
|
29 |
+
MarkupSafe==3.0.2
|
30 |
+
matplotlib==3.9.4
|
31 |
+
matplotlib-inline==0.1.7
|
32 |
+
mpmath==1.3.0
|
33 |
+
multidict==6.1.0
|
34 |
+
multiprocess==0.70.16
|
35 |
+
networkx==3.2.1
|
36 |
+
numpy==2.0.2
|
37 |
+
nvidia-cublas-cu11==11.11.3.6
|
38 |
+
nvidia-cublas-cu12==12.1.3.1
|
39 |
+
nvidia-cuda-cupti-cu11==11.8.87
|
40 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
41 |
+
nvidia-cuda-nvrtc-cu11==11.8.89
|
42 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
43 |
+
nvidia-cuda-runtime-cu11==11.8.89
|
44 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
45 |
+
nvidia-cudnn-cu11==9.1.0.70
|
46 |
+
nvidia-cudnn-cu12==9.1.0.70
|
47 |
+
nvidia-cufft-cu11==10.9.0.58
|
48 |
+
nvidia-cufft-cu12==11.0.2.54
|
49 |
+
nvidia-curand-cu11==10.3.0.86
|
50 |
+
nvidia-curand-cu12==10.3.2.106
|
51 |
+
nvidia-cusolver-cu11==11.4.1.48
|
52 |
+
nvidia-cusolver-cu12==11.4.5.107
|
53 |
+
nvidia-cusparse-cu11==11.7.5.86
|
54 |
+
nvidia-cusparse-cu12==12.1.0.106
|
55 |
+
nvidia-nccl-cu11==2.21.5
|
56 |
+
nvidia-nccl-cu12==2.21.5
|
57 |
+
nvidia-nvjitlink-cu12==12.4.127
|
58 |
+
nvidia-nvtx-cu11==11.8.86
|
59 |
+
nvidia-nvtx-cu12==12.1.105
|
60 |
+
packaging==24.2
|
61 |
+
pandas==2.2.3
|
62 |
+
parso==0.8.4
|
63 |
+
pexpect==4.9.0
|
64 |
+
pickleshare==0.7.5
|
65 |
+
pillow==11.1.0
|
66 |
+
prompt_toolkit==3.0.50
|
67 |
+
propcache==0.2.1
|
68 |
+
ptyprocess==0.7.0
|
69 |
+
pyarrow==19.0.0
|
70 |
+
Pygments==2.19.1
|
71 |
+
pyparsing==3.2.1
|
72 |
+
python-dateutil==2.9.0.post0
|
73 |
+
pytz==2024.2
|
74 |
+
PyYAML==6.0.2
|
75 |
+
regex==2024.11.6
|
76 |
+
requests==2.32.3
|
77 |
+
safetensors==0.5.2
|
78 |
+
scikit-learn==1.6.1
|
79 |
+
scipy==1.13.1
|
80 |
+
selfies==2.2.0
|
81 |
+
six==1.17.0
|
82 |
+
sympy==1.13.1
|
83 |
+
threadpoolctl==3.5.0
|
84 |
+
tokenizers==0.21.0
|
85 |
+
torch==2.5.1+cu121
|
86 |
+
tqdm==4.67.1
|
87 |
+
traitlets==5.14.3
|
88 |
+
transformers==4.48.1
|
89 |
+
transformers-interpret==0.10.0
|
90 |
+
triton==3.1.0
|
91 |
+
typing_extensions==4.12.2
|
92 |
+
tzdata==2024.2
|
93 |
+
urllib3==2.3.0
|
94 |
+
wcwidth==0.2.13
|
95 |
+
xxhash==3.5.0
|
96 |
+
yarl==1.18.3
|
97 |
+
zipp==3.21.0
|
src/__pycache__/modeling_t5.cpython-39.pyc
ADDED
Binary file (71.4 kB). View file
|
|
src/data.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import Dataset, DatasetDict
|
2 |
+
import pandas as pd
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def create_dataset_from_dataframe(
|
9 |
+
dataframe_path, pretrained_name, chosen_features=None
|
10 |
+
):
|
11 |
+
dataframe = pd.read_csv(dataframe_path, usecols=["label"] + chosen_features)
|
12 |
+
rows_with_nan = dataframe[chosen_features].isna().any(axis=1)
|
13 |
+
dataframe = dataframe[np.logical_not(rows_with_nan)]
|
14 |
+
if len(chosen_features) > 1:
|
15 |
+
for feature in chosen_features:
|
16 |
+
if feature == "selfies":
|
17 |
+
dataframe[feature] = dataframe.apply(
|
18 |
+
lambda row: "<bom>" + row[feature] + "<eom>", axis=1
|
19 |
+
)
|
20 |
+
elif feature == "sequence":
|
21 |
+
dataframe[feature] = dataframe.apply(
|
22 |
+
lambda row: "<bop>"
|
23 |
+
+ "".join("<p>" + aa for aa in row[feature])
|
24 |
+
+ "<eop>",
|
25 |
+
axis=1,
|
26 |
+
)
|
27 |
+
|
28 |
+
dataframe["text"] = dataframe.apply(
|
29 |
+
lambda row: "".join([f"{row[feature]}" for feature in chosen_features]),
|
30 |
+
axis=1,
|
31 |
+
)
|
32 |
+
|
33 |
+
elif len(chosen_features) == 1:
|
34 |
+
chosen_feature = chosen_features[0]
|
35 |
+
if chosen_feature == "selfies":
|
36 |
+
dataframe["text"] = dataframe.apply(
|
37 |
+
lambda row: "<bom>" + row[chosen_feature] + "<eom>", axis=1
|
38 |
+
)
|
39 |
+
elif chosen_feature == "smiles":
|
40 |
+
dataframe["text"] = dataframe[chosen_feature]
|
41 |
+
elif chosen_feature == "sequence":
|
42 |
+
if "biot5" in pretrained_name:
|
43 |
+
dataframe["text"] = dataframe.apply(
|
44 |
+
lambda row: "<bop>"
|
45 |
+
+ "".join("<p>" + aa for aa in row[chosen_feature])
|
46 |
+
+ "<eop>",
|
47 |
+
axis=1,
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
dataframe["text"] = dataframe.apply(
|
51 |
+
lambda row: " ".join(row[chosen_feature]),
|
52 |
+
axis=1,
|
53 |
+
)
|
54 |
+
dataframe.drop(columns=chosen_features, inplace=True)
|
55 |
+
dataset = Dataset.from_pandas(dataframe)
|
56 |
+
return dataset
|
57 |
+
|
58 |
+
|
59 |
+
def create_and_save_datadict(train, val, test, save_path):
|
60 |
+
if val is None:
|
61 |
+
dataset_dict = DatasetDict({"train": train, "test": test})
|
62 |
+
dataset_dict.save_to_disk(save_path)
|
63 |
+
return dataset_dict
|
64 |
+
dataset_dict = DatasetDict({"train": train, "val": val, "test": test})
|
65 |
+
dataset_dict.save_to_disk(save_path)
|
66 |
+
return dataset_dict
|
67 |
+
|
68 |
+
|
69 |
+
def prepare_dataset(args):
|
70 |
+
fold_folders = glob.glob(args.data_folder + "/fold_*/")
|
71 |
+
for fold_folder in fold_folders:
|
72 |
+
train_path = os.path.join(fold_folder, "train.csv")
|
73 |
+
val_path = os.path.join(fold_folder, "val.csv")
|
74 |
+
test_path = os.path.join(fold_folder, "test.csv")
|
75 |
+
|
76 |
+
train = create_dataset_from_dataframe(
|
77 |
+
train_path, args.pretrained_name, args.chosen_features
|
78 |
+
)
|
79 |
+
val = create_dataset_from_dataframe(
|
80 |
+
val_path, args.pretrained_name, args.chosen_features
|
81 |
+
)
|
82 |
+
test = create_dataset_from_dataframe(
|
83 |
+
test_path, args.pretrained_name, args.chosen_features
|
84 |
+
)
|
85 |
+
folder_name = f"dataset_{'_'.join(args.chosen_features)}_{args.pretrained_name.split('/')[-1].replace('-', '_')}"
|
86 |
+
save_path = os.path.join(fold_folder, folder_name)
|
87 |
+
create_and_save_datadict(train, val, test, save_path)
|
88 |
+
|
89 |
+
train_path = os.path.join(args.data_folder, "train.csv")
|
90 |
+
test_path = os.path.join(args.data_folder, "test.csv")
|
91 |
+
train = create_dataset_from_dataframe(
|
92 |
+
train_path, args.pretrained_name, args.chosen_features
|
93 |
+
)
|
94 |
+
test = create_dataset_from_dataframe(
|
95 |
+
test_path, args.pretrained_name, args.chosen_features
|
96 |
+
)
|
97 |
+
save_path = os.path.join(
|
98 |
+
args.data_folder,
|
99 |
+
f"dataset_{'_'.join(args.chosen_features)}_{args.pretrained_name.split('/')[-1].replace('-', '_')}",
|
100 |
+
)
|
101 |
+
create_and_save_datadict(train, None, test, save_path)
|
src/explainer.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers_interpret import SequenceClassificationExplainer
|
2 |
+
from typing import List, Tuple, Union
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class xBitterT5_explainer(SequenceClassificationExplainer):
|
7 |
+
def _make_input_reference_pair(
|
8 |
+
self, text: Union[List, str]
|
9 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
10 |
+
if isinstance(text, list):
|
11 |
+
raise NotImplementedError("Lists of text are not currently supported.")
|
12 |
+
|
13 |
+
text_ids = self.encode(text)
|
14 |
+
input_ids = self.tokenizer.encode(text, add_special_tokens=True)
|
15 |
+
|
16 |
+
# if no special tokens were added
|
17 |
+
if len(text_ids) == len(input_ids):
|
18 |
+
ref_input_ids = [self.ref_token_id] * len(text_ids)
|
19 |
+
else:
|
20 |
+
ref_input_ids = (
|
21 |
+
[self.cls_token_id]
|
22 |
+
+ [self.ref_token_id] * len(text_ids)
|
23 |
+
+ [self.sep_token_id]
|
24 |
+
)
|
25 |
+
|
26 |
+
# Use this because pretrained BioT5 plus does not have cls_token_id
|
27 |
+
ref_input_ids = [self.ref_token_id] * len(text_ids) + [self.sep_token_id]
|
28 |
+
return (
|
29 |
+
torch.tensor([input_ids], device=self.device),
|
30 |
+
torch.tensor([ref_input_ids], device=self.device),
|
31 |
+
len(text_ids),
|
32 |
+
)
|
src/model.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
2 |
+
from transformers import AutoTokenizer, T5Tokenizer, AutoConfig
|
3 |
+
from src.modeling_t5 import T5ForSequenceClassification
|
4 |
+
|
5 |
+
|
6 |
+
def prepare_tokenizer(args):
|
7 |
+
try:
|
8 |
+
try:
|
9 |
+
return AutoTokenizer.from_pretrained(args.pretrained_name)
|
10 |
+
except Exception as e:
|
11 |
+
print(f"Error: {e}")
|
12 |
+
return T5Tokenizer.from_pretrained(
|
13 |
+
args.pretrained_name,
|
14 |
+
do_lower_case=False,
|
15 |
+
)
|
16 |
+
except Exception as e:
|
17 |
+
print(f"Error: {e}")
|
18 |
+
return T5Tokenizer.from_pretrained(args.pretrained_name)
|
19 |
+
|
20 |
+
|
21 |
+
def check_unfreeze_layer(name, trainable_layers):
|
22 |
+
flag = False
|
23 |
+
for layer in trainable_layers:
|
24 |
+
if name.startswith(f"transformer.decoder.block.{layer}"):
|
25 |
+
flag = True
|
26 |
+
break
|
27 |
+
return flag
|
28 |
+
|
29 |
+
|
30 |
+
def prepare_model(args):
|
31 |
+
id2lable = {0: "non-bitter", 1: "bitter"}
|
32 |
+
label2id = {"non-bitter": 0, "bitter": 1}
|
33 |
+
config = AutoConfig.from_pretrained(
|
34 |
+
args.pretrained_name,
|
35 |
+
cache_dir=args.cache_dir,
|
36 |
+
num_labels=2,
|
37 |
+
id2label=id2lable,
|
38 |
+
label2id=label2id,
|
39 |
+
)
|
40 |
+
config.dropout_rate = args.dropout
|
41 |
+
config.classifier_dropout = args.dropout
|
42 |
+
config.problem_type = "single_label_classification"
|
43 |
+
|
44 |
+
model = T5ForSequenceClassification.from_pretrained(
|
45 |
+
args.pretrained_name,
|
46 |
+
cache_dir=args.cache_dir,
|
47 |
+
config=config,
|
48 |
+
)
|
49 |
+
model.to(args.accelerator)
|
50 |
+
for name, param in model.named_parameters():
|
51 |
+
if name.startswith("classification_head") or check_unfreeze_layer(
|
52 |
+
name, args.trainable_layers
|
53 |
+
):
|
54 |
+
param.requires_grad = True
|
55 |
+
else:
|
56 |
+
param.requires_grad = False
|
57 |
+
|
58 |
+
return model
|
src/modeling_t5.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/old_modeling_t5.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
import numpy as np
|
3 |
+
from datetime import datetime
|
4 |
+
from zoneinfo import ZoneInfo
|
5 |
+
from torch.nn.functional import softmax
|
6 |
+
from torch import tensor
|
7 |
+
from sklearn.metrics import confusion_matrix, roc_curve, auc
|
8 |
+
|
9 |
+
|
10 |
+
bitter_metrics = evaluate.combine(
|
11 |
+
["accuracy", "f1", "precision", "recall", "matthews_correlation"]
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def compute_metrics(eval_pred):
|
16 |
+
predictions, labels = eval_pred
|
17 |
+
preds = np.argmax(predictions[0], axis=1)
|
18 |
+
prediction_scores = softmax(tensor(predictions[0]), dim=-1)
|
19 |
+
prediction_scores = prediction_scores[:, 1].cpu().numpy()
|
20 |
+
|
21 |
+
metrics = bitter_metrics.compute(predictions=preds, references=labels)
|
22 |
+
tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
|
23 |
+
specificity = tn / (tn + fp)
|
24 |
+
metrics.update(
|
25 |
+
{
|
26 |
+
"eval_specificity": specificity,
|
27 |
+
"eval_tn": tn,
|
28 |
+
"eval_fp": fp,
|
29 |
+
"eval_fn": fn,
|
30 |
+
"eval_tp": tp,
|
31 |
+
}
|
32 |
+
)
|
33 |
+
|
34 |
+
fpr2, tpr2, _ = roc_curve(labels, prediction_scores, pos_label=1)
|
35 |
+
auc2 = auc(fpr2, tpr2)
|
36 |
+
metrics.update({"eval_auc": auc2})
|
37 |
+
|
38 |
+
metrics = dict(sorted(metrics.items()))
|
39 |
+
return metrics
|
40 |
+
|
41 |
+
|
42 |
+
def get_time_string():
|
43 |
+
return datetime.now(tz=ZoneInfo("Asia/Seoul")).strftime("%Y_%m_%d__%H_%M_%S")
|