# FusionDTI / app.py
import os
import sys
import argparse
import tempfile

import torch
from torch.utils.data import DataLoader
from transformers import EsmForMaskedLM, AutoModel, EsmTokenizer
from bertviz import head_view
from flask import Flask, request, render_template_string

# Silence tokenizer fork warnings and add the parent directory to the import path
# before pulling in the local utils package.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append("../")

from utils.drug_tokenizer import DrugTokenizer
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI

app = Flask(__name__)

def parse_config():
    parser = argparse.ArgumentParser()
    # Placeholder flag so parsing does not fail when a launcher injects an extra "-f" argument (e.g. notebook kernels).
    parser.add_argument('-f')
    parser.add_argument("--prot_encoder_path", type=str, default="westlake-repl/SaProt_650M_AF2", help="path/name of the protein encoder model")
    parser.add_argument("--drug_encoder_path", type=str, default="HUBioDataLab/SELFormer", help="path/name of the pre-trained drug language model")
    parser.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    parser.add_argument("--fusion", default="CAN", type=str, help="{CAN|BAN}")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--group_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--test", type=int, default=0)
    parser.add_argument("--use_pooled", action="store_true", default=True)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--save_path_prefix", type=str, default="save_model_ckp/", help="directory in which model checkpoints are saved")
    parser.add_argument("--save_name", default="fine_tune", type=str, help="name of the saved checkpoint file")
    parser.add_argument("--dataset", type=str, default="Human", help="name of the dataset to use (e.g. 'BindingDB', 'Human', 'Biosnap')")
    return parser.parse_args()
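
# Example invocation (a sketch; every flag is optional and defined in parse_config above):
#   python app.py --dataset BindingDB --fusion CAN --device cuda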

args = parse_config()
device = args.device

# Tokenizers and pre-trained encoders for protein and drug sequences.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
drug_tokenizer = DrugTokenizer()
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)

# Pre_encoded wraps both encoders to produce protein and drug token embeddings in one forward pass.
encoding = Pre_encoded(prot_model, drug_model, args).to(device)

def get_case_feature(model, dataloader, device):
    """Encode a single (protein, drug) pair and return its embeddings, token ids, masks and label."""
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask, label = batch
            prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask = \
                prot_input_ids.to(device), prot_attention_mask.to(device), drug_input_ids.to(device), drug_attention_mask.to(device)
            prot_embed, drug_embed = model.encoding(prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask)
            # Move everything back to the CPU so the features can be held outside the GPU context.
            prot_embed, drug_embed = prot_embed.cpu(), drug_embed.cpu()
            prot_input_ids, drug_input_ids = prot_input_ids.cpu(), drug_input_ids.cpu()
            prot_attention_mask, drug_attention_mask = prot_attention_mask.cpu(), drug_attention_mask.cpu()
            label = label.cpu()
            # Only the first batch is needed for a single visualisation case.
            return [(prot_embed, drug_embed, prot_input_ids, drug_input_ids, prot_attention_mask, drug_attention_mask, label)]

def visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer):
    """Run the fusion model on pre-encoded features and render its attention map as an HTML file."""
    model.eval()
    with torch.no_grad():
        for batch in case_features:
            prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, label = batch
            prot, drug = prot.to(device), drug.to(device)
            prot_mask, drug_mask = prot_mask.to(device), drug_mask.to(device)
            output, attention_weights = model(prot, drug, prot_mask, drug_mask)

            # Decode token ids back to strings; protein tokens come first, drug tokens second.
            prot_tokens = [prot_tokenizer.decode([pid.item()], skip_special_tokens=True) for pid in prot_ids.squeeze()]
            drug_tokens = [drug_tokenizer.decode([did.item()], skip_special_tokens=True) for did in drug_ids.squeeze()]
            tokens = prot_tokens + drug_tokens

            # Add a dimension so the attention weights match the layout head_view expects.
            attention_weights = attention_weights.unsqueeze(1)

            # Generate HTML content using head_view with html_action='return'.
            # The protein occupies the first 512 positions, so the drug segment starts at index 512.
            html_head_view = head_view(attention_weights, tokens, sentence_b_start=512, html_action='return')

            # Rename bertviz's default sentence labels to protein/drug.
            html_content = html_head_view.data
            html_content = html_content.replace("Sentence A -> Sentence A", "Protein -> Protein")
            html_content = html_content.replace("Sentence B -> Sentence B", "Drug -> Drug")
            html_content = html_content.replace("Sentence A -> Sentence B", "Protein -> Drug")
            html_content = html_content.replace("Sentence B -> Sentence A", "Drug -> Protein")

            # Save the modified HTML content to a temporary file and hand its path back to the caller.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
                f.write(html_content.encode('utf-8'))
                temp_file_path = f.name
            return temp_file_path

@app.route('/', methods=['GET', 'POST'])
def index():
    protein_sequence = ""
    drug_sequence = ""
    result = None
    if request.method == 'POST':
        if 'clear' in request.form:
            protein_sequence = ""
            drug_sequence = ""
        else:
            protein_sequence = request.form['protein_sequence']
            drug_sequence = request.form['drug_sequence']

            # Wrap the submitted pair as a one-item dataset and run it through the encoders.
            dataset = [(protein_sequence, drug_sequence, 1)]
            dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn_batch_encoding)
            case_features = get_case_feature(encoding, dataloader, device)

            # Build the fusion model and load the best checkpoint for the selected dataset/fusion module.
            model = FusionDTI(446, 768, args).to(device)
            best_model_dir = f"{args.save_path_prefix}{args.dataset}_{args.fusion}"
            checkpoint_path = os.path.join(best_model_dir, 'best_model.ckpt')
            if os.path.exists(checkpoint_path):
                model.load_state_dict(torch.load(checkpoint_path, map_location=device))

            # Render the attention map to HTML and embed it in the page.
            html_file_path = visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer)
            with open(html_file_path, 'r') as f:
                result = f.read()
    return render_template_string('''
    <html>
    <head>
        <title>Drug Target Interaction Visualization</title>
        <style>
            body { font-family: 'Times New Roman', Times, serif; margin: 40px; }
            h2 { color: #333; }
            .container { display: flex; }
            .left { flex: 1; padding-right: 20px; }
            .right { flex: 1; }
            textarea {
                width: 100%;
                padding: 12px 20px;
                margin: 8px 0;
                display: inline-block;
                border: 1px solid #ccc;
                border-radius: 4px;
                box-sizing: border-box;
                font-size: 16px;
                font-family: 'Times New Roman', Times, serif;
            }
            .button-container {
                display: flex;
                justify-content: space-between;
            }
            input[type="submit"], .button {
                width: 48%;
                color: white;
                padding: 14px 20px;
                margin: 8px 0;
                border: none;
                border-radius: 4px;
                cursor: pointer;
                font-size: 16px;
                font-family: 'Times New Roman', Times, serif;
            }
            .submit {
                background-color: #FFA500;
            }
            .submit:hover {
                background-color: #FF8C00;
            }
            .clear {
                background-color: #D3D3D3;
            }
            .clear:hover {
                background-color: #A9A9A9;
            }
            .result {
                font-size: 18px;
            }
        </style>
    </head>
    <body>
        <h2 style="text-align: center;">Drug Target Interaction Visualization</h2>
        <div class="container">
            <div class="left">
                <form method="post">
                    <label for="protein_sequence">Protein Sequence:</label>
                    <textarea id="protein_sequence" name="protein_sequence" rows="4" placeholder="Enter protein sequence here..." required>{{ protein_sequence }}</textarea><br>
                    <label for="drug_sequence">Drug Sequence:</label>
                    <textarea id="drug_sequence" name="drug_sequence" rows="4" placeholder="Enter drug sequence here..." required>{{ drug_sequence }}</textarea><br>
                    <div class="button-container">
                        <input type="submit" name="submit" class="button submit" value="Submit">
                        <input type="submit" name="clear" class="button clear" value="Clear">
                    </div>
                </form>
            </div>
            <div class="right" style="display: flex; justify-content: center; align-items: center;">
                {% if result %}
                    <div class="result">
                        {{ result|safe }}
                    </div>
                {% endif %}
            </div>
        </div>
    </body>
    </html>
    ''', protein_sequence=protein_sequence, drug_sequence=drug_sequence, result=result)

def collate_fn_batch_encoding(batch):
    """Tokenise a batch of (protein, drug, label) triples into fixed-length (512-token) tensors."""
    query1, query2, scores = zip(*batch)
    query_encodings1 = prot_tokenizer.batch_encode_plus(
        list(query1),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    query_encodings2 = drug_tokenizer.batch_encode_plus(
        list(query2),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    scores = torch.tensor(list(scores))
    attention_mask1 = query_encodings1["attention_mask"].bool()
    attention_mask2 = query_encodings2["attention_mask"].bool()
    return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
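
# Minimal offline sketch of how the pieces above fit together outside Flask.
# This helper is not called anywhere in the app; the sequences and checkpoint
# path are caller-supplied placeholders, not values shipped with this repo.
def _demo_attention_export(protein_sequence, drug_sequence, checkpoint_path):
    """Encode one protein/drug pair, run FusionDTI, and return the path of the rendered HTML."""
    loader = DataLoader([(protein_sequence, drug_sequence, 1)],
                        batch_size=1, collate_fn=collate_fn_batch_encoding)
    features = get_case_feature(encoding, loader, device)
    fusion_model = FusionDTI(446, 768, args).to(device)
    if os.path.exists(checkpoint_path):
        fusion_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return visualize_attention(fusion_model, features, device, prot_tokenizer, drug_tokenizer)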

if __name__ == '__main__':
    app.run(debug=True, host="0.0.0.0", port=7860)