# ARGnet-UI / scripts/script.py
# Last commit by tracywong117 (819e2d9): fix file upload handle, update readme
import pandas as pd
import math
import plotly.subplots as sp
import plotly.graph_objects as go
import os
def plot_pie_chart(df):
    """Build a side-by-side pair of pie charts summarising ARGNet predictions.

    The left pie shows how often each value of ``df["ARG_prediction"]``
    occurs; the right pie does the same for ``df["resistance_category"]``.
    Missing values are not counted (``Series.value_counts`` drops NaN by
    default).

    Args:
        df: DataFrame containing at least the ``ARG_prediction`` and
            ``resistance_category`` columns.

    Returns:
        A plotly ``Figure`` with two domain-type (pie) subplots, legends
        hidden and slice labels drawn inside the slices.
    """
    arg_counts = df["ARG_prediction"].value_counts()
    category_counts = df["resistance_category"].value_counts()

    palette = [
        "#f9b4ab",
        "#fdebd3",
        "#264e70",
        "#679186",
        "#bbd4ce",
    ]
    # update_traces() below applies one marker colour list to BOTH pies, so
    # size it for whichever pie has more slices, repeating the palette.
    n_slices = max(len(arg_counts), len(category_counts))
    full_colors = palette * math.ceil(n_slices / len(palette))

    fig = sp.make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("ARG/non-ARG", "Resistance category"),
        specs=[[{"type": "domain"}, {"type": "domain"}]],
    )
    fig.add_trace(
        go.Pie(
            labels=arg_counts.index,
            values=arg_counts.values,
            legendgroup="1",
            title="ARG/non-ARG",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Pie(
            labels=category_counts.index,
            values=category_counts.values,
            legendgroup="2",
            title="Resistance category",
        ),
        row=1,
        col=2,
    )
    fig.update_layout(showlegend=False, margin=dict(l=200, r=200, t=100, b=100))
    fig.update_traces(
        textposition="inside",
        hoverinfo="label+percent",
        textinfo="label",
        marker=dict(colors=full_colors, line=dict(color="#38496e", width=1)),
    )
    return fig
def view_stat(output_name):
    """Load an ARGNet results table, chart it, and return it with display headers.

    Reads the tab-separated file ``results/<output_name>``, builds the pie
    chart figure from the table's original column names, then removes the
    file from disk.

    Args:
        output_name: File name of the results table inside ``results/``.

    Returns:
        ``[table, figure]``, where *table* is the loaded DataFrame with its
        leading columns renamed (positionally) to "Test ID",
        "ARG Prediction", "Resistance Category" and "Probability", and
        *figure* is the chart returned by ``plot_pie_chart``.
    """
    result_path = f"results/{output_name}"
    table = pd.read_csv(result_path, delimiter="\t")
    display_headers = ["Test ID", "ARG Prediction", "Resistance Category", "Probability"]

    # Chart before renaming: plot_pie_chart looks up the raw column names.
    figure = plot_pie_chart(table)

    # The results file is consumed once; clear it after reading.
    if os.path.exists(result_path):
        os.remove(result_path)

    # Positional rename: zip stops at the shorter sequence, so any extra
    # columns keep their original names.
    header_map = {old: new for old, new in zip(table.columns, display_headers)}
    return [table.rename(columns=header_map), figure]
def run_argnet(input, output_name, sequence_type, sequence_length_type):
    """Run the ARGNet model variant selected by sequence type and length.

    Writes *input* to ``input.txt`` in the current working directory, removes
    any stale ``results/<output_name>``, and calls the matching ARGNet module
    with the input path and *output_name*. ``input.txt`` is always removed
    afterwards, even if the model call raises.

    Args:
        input: Raw sequence text to classify (written verbatim to the file).
        output_name: Output file name passed through to the ARGNet module;
            presumably written under ``results/`` (that is where stale
            copies are deleted from) — confirm against the model modules.
        sequence_type: ``"aa"`` or ``"nt"`` (module names suggest amino acid
            vs nucleotide input).
        sequence_length_type: ``"s"`` or ``"l"`` (module names suggest short
            vs long sequences).

    Raises:
        ValueError: If the (sequence_type, sequence_length_type) pair is not
            one of the four supported combinations. Raised before any file
            is touched.
    """
    mode = (sequence_type, sequence_length_type)
    if mode not in {("aa", "s"), ("nt", "s"), ("aa", "l"), ("nt", "l")}:
        raise ValueError(
            f"Unsupported sequence_type/sequence_length_type combination: {mode!r}"
        )

    # NOTE(review): a fixed file name in the CWD is shared by every caller;
    # concurrent runs would overwrite each other's input.
    input_path = "input.txt"
    output_path = f"results/{output_name}"

    with open(input_path, "w") as f:
        f.write(input)
    try:
        # delete the output file left over from a previous run
        if os.path.exists(output_path):
            os.remove(output_path)
        # Model modules are imported lazily so only the selected one is loaded.
        if mode == ("aa", "s"):
            from . import argnet_ssaa_chunk as ssaa
            ssaa.argnet_ssaa(input_path, output_name)
        elif mode == ("nt", "s"):
            from . import argnet_ssnt_new_chunk as ssnt
            ssnt.argnet_ssnt(input_path, output_name)
        elif mode == ("aa", "l"):
            from . import argnet_lsaa_speed_sgpu as lsaa
            lsaa.argnet_lsaa(input_path, output_name)
        else:  # ("nt", "l") — the only remaining validated mode
            from . import argnet_lsnt as lsnt
            lsnt.argnet_lsnt(input_path, output_name)
    finally:
        if os.path.exists(input_path):
            os.remove(input_path)