import gradio as gr
import numpy as np
import os
import ray
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#from transformers import pipeline as pl
from transformers import GPT2LMHeadModel , GPT2Tokenizer
from GPUtil import showUtilization as gpu_usage
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import plotly.graph_objects as go
import torch
import gc
import jax
from numba import cuda
import math
# Startup diagnostics: whether PyTorch sees a CUDA device, and the launch dir.
print('GPU available', torch.cuda.is_available())
print(os.getcwd())

# Make the local AlphaFold checkout importable for the `alphafold.*` imports
# that follow; only add it when it is not already on the path.
if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
def mk_mock_template(query_sequence):
    """Build blank template features for ``query_sequence``.

    The "template" is an all-gap sequence of the same length as the query,
    with zeroed atom coordinates and masks. Every array is given a leading
    axis of size 1 via ``[None]`` (a single template).

    Returns:
        dict with keys ``template_all_atom_positions`` (shape
        ``(1, len(query_sequence), atom_type_num, 3)``),
        ``template_all_atom_masks`` (shape
        ``(1, len(query_sequence), atom_type_num)``), ``template_aatype``
        (one-hot encoding of the gap sequence, with a leading axis of 1) and
        ``template_domain_names`` (``[b"none"]``).
    """
    consts = templates.residue_constants
    n_res = len(query_sequence)
    n_atoms = consts.atom_type_num

    # One-hot encode a sequence consisting only of gap characters.
    gap_onehot = consts.sequence_to_onehot("-" * n_res, consts.HHBLITS_AA_TO_ID)

    return {
        "template_all_atom_positions": np.zeros((n_res, n_atoms, 3))[None],
        "template_all_atom_masks": np.zeros((n_res, n_atoms))[None],
        "template_aatype": np.array(gap_onehot)[None],
        "template_domain_names": [b"none"],
    }
def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Run every AlphaFold model runner on ``feature_dict`` and save PDBs.

    For each ``(name, runner)`` pair in ``model_runners``:
    1. process the features with ``random_seed``;
    2. run the model;
    3. write the unrelaxed structure to
       ``/home/user/app/{prefix}_unrelaxed_{name}.pdb``.
    The B-factors handed to ``protein.from_prediction`` are the pLDDT values
    multiplied by the final atom mask. The mean pLDDT of each model is
    printed.

    Returns:
        dict mapping each model name to its ``prediction_result["plddt"]``.
    """
    scores = {}
    for name, runner in model_runners.items():
        features = runner.process_features(feature_dict, random_seed=random_seed)
        result = runner.predict(features)

        # Broadcast the per-residue confidence across each residue's atom mask
        # so it can be stored as B-factors in the PDB output.
        confidence = result["plddt"]
        atom_mask = result["structure_module"]["final_atom_mask"]
        structure = protein.from_prediction(
            features, result, confidence[:, None] * atom_mask
        )

        pdb_path = f"/home/user/app/{prefix}_unrelaxed_{name}.pdb"
        scores[name] = confidence
        print(f"{name} {scores[name].mean()}")
        with open(pdb_path, "w") as handle:
            handle.write(protein.to_pdb(structure))
    return scores
def compute_perplexity(model, tokenizer, sequence):
    """Return the perplexity of ``sequence`` under a causal language model.

    ``sequence`` is tokenised with ``tokenizer.encode``. The token ids are fed
    to ``model`` both as inputs and as labels, and the returned loss is
    exponentiated.

    Args:
        model: Callable language model (e.g. ``GPT2LMHeadModel``) that accepts
            ``(input_ids, labels=...)`` and returns the loss as element 0.
        tokenizer: Object with an ``encode(text) -> list[int]`` method.
        sequence: Text to score.

    Returns:
        float: ``exp(loss)``.

    Raises:
        ValueError: if ``sequence`` encodes to no tokens (the model cannot
            score an empty input).
    """
    token_ids = tokenizer.encode(sequence)
    if not token_ids:
        raise ValueError("cannot compute perplexity of an empty sequence")
    # Shape (1, num_tokens): a batch holding just this sequence.
    # NOTE(review): the tensor is created on the CPU, so `model` must be too.
    input_ids = torch.tensor(token_ids).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    # Element 0 is the loss for both transformers ModelOutput objects and
    # plain tuples (return_dict=False).
    loss = outputs[0]
    return math.exp(float(loss))
@ray.remote(num_gpus=1, max_calls=1)
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Sample protein sequences from ProtGPT2 and keep the best-scoring ones.

    Runs as a Ray task that reserves one GPU; ``max_calls=1`` makes the
    worker exit after a single call (presumably to release GPU memory).
    ``max_seqs * 10`` samples are drawn, filtered, scored with
    ``compute_perplexity`` and ranked.

    Args:
        startsequence: Prompt text passed to the tokenizer.
        length: ``max_length`` for ``model.generate`` (in tokens).
        repetitionPenalty: ``repetition_penalty`` for sampling.
        top_k_poolsize: ``top_k`` for sampling.
        max_seqs: Maximum number of sequences to return.

    Returns:
        A list of at most ``max_seqs`` ``(sequence, score)`` tuples sorted by
        ascending score. The score is the perplexity divided by the character
        length of the sequence; lower is treated as better. Fewer tuples are
        returned, with a printed warning, when too few samples pass the
        filters.
    """
    end_token = '<|endoftext|>'

    print("running protgpt2")
    gpu_usage()  # GPUtil.showUtilization prints its own table (returns None)

    # Oversample 10x so only the best-scoring tenth is kept.
    seqs_to_sample = max_seqs * 10
    # NOTE(review): the model is never moved to the GPU reserved above, so
    # generation runs on the CPU.
    model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
    tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
    input_ids = tokenizer.encode(startsequence, return_tensors='pt')
    sequences = model.generate(
        input_ids,
        max_length=length,
        do_sample=True,
        top_k=top_k_poolsize,
        repetition_penalty=repetitionPenalty,
        num_return_sequences=seqs_to_sample,
        eos_token_id=0,
    )

    scored = []
    for sequence in sequences:
        decoded_seq = tokenizer.decode(sequence)
        # No newline in the first line, and at least two end-of-text markers
        # (intended to reject samples truncated by max_length).
        if '\n' in decoded_seq[0:60] or decoded_seq.count(end_token) < 2:
            continue
        clean_seq = decoded_seq.split(end_token)[0]
        if not clean_seq:
            # Nothing precedes the first marker; an empty sequence cannot be
            # scored (and would divide by zero below).
            continue
        ppl = compute_perplexity(model, tokenizer, clean_seq)
        scored.append((clean_seq, ppl / len(clean_seq)))

    selected_sequences = sorted(scored, key=lambda item: item[1])[:max_seqs]
    if len(selected_sequences) < max_seqs:
        print(
            f"Warning: only {len(selected_sequences)} of {max_seqs} requested "
            f"sequences passed the filters"
        )

    print("Cleaning up after protGPT2")
    return selected_sequences
@ray.remote(num_gpus=1, max_calls=1)
def run_alphafold(startsequence):
    """Predict a structure for ``startsequence`` with AlphaFold ``model_1``.

    Runs as a Ray task that reserves one GPU. It builds single-sequence
    features: the query as the only MSA row plus a blank template from
    ``mk_mock_template``. It then runs ``predict_structure`` with prefix
    ``"test"``, which writes
    ``/home/user/app/test_unrelaxed_model_1.pdb``.

    Args:
        startsequence: Amino-acid sequence; all whitespace (line breaks,
            spaces, tabs, ``\\r``) is removed before featurisation.

    Returns:
        The pLDDT array of ``model_1``.
    """
    gpu_usage()  # GPUtil.showUtilization prints its own table (returns None)

    model_runners = {}
    models = ["model_1"]  # add "model_2" ... "model_5" to run more models
    for model_name in models:
        model_config = config.model_config(model_name)
        model_config.data.eval.num_ensemble = 1
        model_params = data.get_model_haiku_params(
            model_name=model_name, data_dir="/home/user/app/"
        )
        model_runners[model_name] = model.RunModel(model_config, model_params)

    # Strip every kind of whitespace, not only "\n": leftovers such as "\r"
    # from pasted text would otherwise be counted in num_res below.
    query_sequence = "".join(startsequence.split())
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=query_sequence, description="none", num_res=len(query_sequence)
        ),
        **pipeline.make_msa_features(
            msas=[[query_sequence]], deletion_matrices=[[[0] * len(query_sequence)]]
        ),
        **mk_mock_template(query_sequence),
    }
    plddts = predict_structure("test", feature_dict, model_runners)
    print("AF2 done")
    return plddts["model_1"]
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Generate protein sequences with ProtGPT2 and format them as FASTA text.

    Args:
        inp: Prompt passed to ``run_protgpt2`` as the start sequence.
        length: ``max_length`` for generation (in tokens).
        repetitionPenalty: Repetition penalty used while sampling.
        top_k_poolsize: ``top_k`` used while sampling.
        max_seqs: Maximum number of sequences to return.

    Returns:
        A string of records ``>seq{i}, {n} residues`` followed by the sequence
        wrapped at 70 characters per line. The string is empty when no
        sequences came back.
    """
    generated_seqs = ray.get(
        run_protgpt2.remote(inp, length, repetitionPenalty, top_k_poolsize, max_seqs)
    )
    sel_seqs = [_generated_text(item) for item in generated_seqs]
    print(sel_seqs)

    records = []
    for idx, seq in enumerate(sel_seqs):
        residues = seq.replace("\n", "")
        wrapped = "\n".join(
            residues[pos:pos + 70] for pos in range(0, len(residues), 70)
        )
        records.append(f">seq{idx}, {len(residues)} residues \n{wrapped}\n\n")
    return "".join(records)


def _generated_text(item):
    """Extract the sequence string from one ProtGPT2 result entry.

    Accepts a plain string, a ``(sequence, score)`` tuple, or a
    transformers text-generation pipeline dict with a ``"generated_text"``
    key.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, dict):
        return item["generated_text"]
    return item[0]
def update(inp):
    """Run AlphaFold on ``inp`` and build the three Gradio outputs.

    Returns:
        A tuple of:
        - the result of ``molecule`` for the predicted structure file;
        - a Plotly figure of pLDDT per residue index;
        - a ``"mean ± std"`` pLDDT summary string.
    """
    print("Running AF on", inp)
    # AlphaFold runs in a Ray worker. It returns model_1's pLDDT values and,
    # via predict_structure, writes the PDB read back below.
    plddts = ray.get(run_alphafold.remote(inp))
    print(plddts)

    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            # Plotly hover text is HTML-like: <br> is the line break.
            hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}",
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )
    return (
        # Same absolute path predict_structure writes to, so the result does
        # not depend on the current working directory.
        molecule("/home/user/app/test_unrelaxed_model_1.pdb"),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
def read_mol(molpath):
    """Return the entire text content of the file at ``molpath``.

    ``molecule`` calls this with the path of a PDB file.
    """
    with open(molpath, "r") as fp:
        # One read() instead of concatenating lines with +=, which is
        # quadratic in the number of lines.
        return fp.read()
def molecule(pdb):
mol = read_mol(pdb)
x = (
"""