SwastikN's picture
Update app.py
5210135 verified
raw
history blame
5.94 kB
import gradio as gr
import os
import torch
import openai
import requests
import gradio as gr
import transformers
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel
from gradio_molgallery3d import MolGallery3D
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

# Secrets are taken from the process environment. A missing variable raises
# KeyError here, i.e. at startup rather than on the first request.
# The username/password pair is passed to `interface.launch(auth=...)`.
auth_username = os.environ["AUTH_USERNAME"]
auth_password = os.environ["AUTH_PASSWORD"]
# Access token handed to `from_pretrained(..., token=...)` for the model repo.
gated_access_token = os.environ["CAMBRIDGELTL_ACCESS_TOKEN"]

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def listToString(s):
    """Join the strings in *s* into one string, separated by a blank line."""
    blank_line = "\n\n"
    joined = blank_line.join(s)
    return joined
def correct_smiles(smiles):
    """Return the canonical form of a SMILES string, or ``None`` if unusable.

    Args:
        smiles: Candidate SMILES text (typically raw model output).

    Returns:
        The canonical SMILES produced by RDKit, or ``None`` when the input is
        not a string, is empty/whitespace-only, cannot be parsed, or RDKit
        raises while processing it.
    """
    if not isinstance(smiles, str):
        return None

    candidate = smiles.strip()
    if not candidate:
        # An empty string is not a usable ligand; RDKit would otherwise hand
        # back an atom-less molecule and this function would return "".
        return None

    try:
        mol = Chem.MolFromSmiles(candidate)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Validation is best-effort: any RDKit failure just means the
        # candidate is rejected, so report it the same way as a parse failure.
        return None
# Tokenizer and GPT-2 language model for ligand generation, both fetched from
# the same Hugging Face repository using the access token.
molecule_tokenizer = AutoTokenizer.from_pretrained(
    'SwastikN/sxc_chem_llm',
    token=gated_access_token,
)
molecule_model = GPT2LMHeadModel.from_pretrained(
    "SwastikN/sxc_chem_llm",
    token=gated_access_token,
)

# Run on the GPU when one is available; the model is only used for inference,
# so it is switched to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
molecule_model.to(device)
molecule_model.eval()
def generate_drug_seq(inp_protein_seq, batch_size, top_p, top_k):
    """Sample candidate ligands for a protein sequence and keep the valid ones.

    The prompt ``<|startoftext|><P>`` + sequence is fed to ``molecule_model``.
    Each sampled output is decoded, the text following the ``<L>`` marker is
    taken as the ligand, and it is kept only if ``correct_smiles`` accepts it.

    Args:
        inp_protein_seq: Protein amino-acid sequence entered by the user.
        batch_size: Number of sequences to sample. Coerced to an ``int`` of at
            least 1 because UI number widgets may deliver floats or ``None``.
        top_p: Nucleus-sampling probability passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate`` (coerced to ``int``).

    Returns:
        A 2-tuple ``(text, molecules)``: the valid canonical SMILES joined by
        blank lines, and the matching list of RDKit molecule objects. Both are
        empty when the input sequence is empty or no sample yields a valid
        ligand.
    """
    protein_seq = (inp_protein_seq or "").strip()
    if not protein_seq:
        # Nothing to condition on; skip the (expensive) sampling pass.
        return "", []

    # Gradio number inputs arrive as floats unless configured otherwise, but
    # the sampling arguments below must be integers.
    num_candidates = max(1, int(batch_size or 1))
    top_k = int(top_k)
    top_p = float(top_p)

    prompt = "<|startoftext|><P>" + protein_seq
    input_ids = torch.tensor(molecule_tokenizer.encode(prompt)).unsqueeze(0)
    input_ids = input_ids.to(device)

    sample_outputs = molecule_model.generate(
        input_ids,
        do_sample=True,
        top_k=top_k,
        max_length=1024,
        top_p=top_p,
        num_return_sequences=num_candidates,
    )

    valid_smiles = []
    rdkit_objects = []
    for sample_output in sample_outputs:
        decoded = molecule_tokenizer.decode(sample_output, skip_special_tokens=True)
        pieces = decoded.split('<L>')
        if len(pieces) < 2:
            # The model never emitted the ligand marker for this sample.
            continue
        canonical = correct_smiles(pieces[1])
        if not canonical:
            continue
        mol = Chem.MolFromSmiles(canonical)
        if mol is None:
            continue
        valid_smiles.append(canonical)
        rdkit_objects.append(mol)

    return listToString(valid_smiles), rdkit_objects
# Look and feel of the web UI: a monochrome base theme with custom hues,
# small corner radii and "Open Sans" as the preferred font.
_theme_fonts = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
_theme_radius = gr.themes.sizes.radius_sm
theme = gr.themes.Monochrome(
    primary_hue="slate",
    secondary_hue="red",
    neutral_hue="stone",
    radius_size=_theme_radius,
    font=_theme_fonts,
)
# Page layout: a header, then two side-by-side columns (inputs on the left,
# results on the right) and a button that triggers generation.
with gr.Blocks(theme=theme) as interface:
    with gr.Column():
        gr.Markdown(
            """<h1><center>SXCLigandLM : Drug Discovery</center></h1>
<p>
This is a demonstration of SXC Medical LLM for realisation of MCMS4451 - MSc. Project.
</p>
<p>The System is capable to generate synthetically valid ligand molecules for a given protein sequence. This protein sequence can belong to the RNA of any virus or pathogenic genes.</p>
<p> Developed by: Swastik N. under guidance of Prof. Dr. Anal Acharya at St. Xavier's College (Autonomous), Calcutta. </p>
"""
        )
        with gr.Row():
            # Left column: protein sequence and sampling parameters.
            with gr.Column(scale=2):
                gr.Markdown("<h2><center> Drug Generation System: </center></h2>")
                with gr.Row():
                    textboxInput = gr.Textbox(label="Input Protein Amino Acid Sequence", interactive=True,
                                              placeholder="MASV.....")
                gr.Markdown(value="Parameters")
                with gr.Row():
                    # precision=0 makes the widget deliver an int, which is
                    # what the handler needs for the number of samples;
                    # minimum=1 rules out zero or negative counts.
                    numberBatchSize = gr.Number(label="Generate Maximum Candidates", interactive=True, value=10,
                                                precision=0, minimum=1)
                    # The sampling sliders are hidden, but their default
                    # values are still passed to the handler.
                    sliderTopK = gr.Slider(minimum=1, maximum=32, step=1, label="top_k", value=9, interactive=True, visible=False)
                    sliderTopP = gr.Slider(minimum=0, maximum=1, step=0.05, label="top_p", value=0.9, interactive=True, visible=False)
            # Right column: generated SMILES text and the 3D structure viewer.
            with gr.Column(scale=2):
                with gr.Row():
                    textboxLog = gr.Textbox(lines = 10, label="Generated Candidates")
                with gr.Row():
                    gallery3d = MolGallery3D(label="3D Interactive Structures", automatic_rotation=False)
        with gr.Row():
            generate_molecule = gr.Button('Generate Drug Like Ligands', variant='primary')

    # Note the input order: top_p comes before top_k, matching the signature
    # generate_drug_seq(inp_protein_seq, batch_size, top_p, top_k).
    generate_molecule.click(generate_drug_seq, inputs=[textboxInput, numberBatchSize, sliderTopP, sliderTopK], outputs=[textboxLog, gallery3d])

# The app is protected by the username/password read from the environment.
interface.launch(auth=(auth_username, auth_password), share=True)