Spaces:
Sleeping
Sleeping
import pandas as pd | |
import math | |
import plotly.subplots as sp | |
import plotly.graph_objects as go | |
import os | |
def plot_pie_chart(df):
    """Build a one-row figure with two pie charts summarising predictions.

    The left pie counts the rows per value of ``df["ARG_prediction"]``; the
    right pie counts the rows per value of ``df["resistance_category"]``.

    Args:
        df: Prediction table containing ``ARG_prediction`` and
            ``resistance_category`` columns. NaN values are not counted,
            because ``Series.value_counts`` drops them by default.

    Returns:
        A plotly ``Figure`` with two pie subplots side by side.
    """
    arg_counts = df["ARG_prediction"].value_counts()
    category_counts = df["resistance_category"].value_counts()

    palette = ["#f9b4ab", "#fdebd3", "#264e70", "#679186", "#bbd4ce"]
    # Tile the palette so there is at least one colour per resistance
    # category (colours repeat once the palette is exhausted). The repeat
    # count is derived from len(palette) rather than a hard-coded 5, so the
    # palette can be edited without breaking this arithmetic.
    full_colors = palette * math.ceil(len(category_counts) / len(palette))

    fig = sp.make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("ARG/non-ARG", "Resistance category"),
        # Both cells are "domain"-type subplots, which host the Pie traces.
        specs=[[{"type": "domain"}, {"type": "domain"}]],
    )
    fig.add_trace(
        go.Pie(
            labels=arg_counts.index.tolist(),
            values=arg_counts.tolist(),
            legendgroup="1",
            title="ARG/non-ARG",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Pie(
            labels=category_counts.index.tolist(),
            values=category_counts.tolist(),
            legendgroup="2",
            title="Resistance category",
        ),
        row=1,
        col=2,
    )
    # The legend is hidden; slice labels are drawn on the slices instead
    # (textinfo="label" below).
    fig.update_layout(showlegend=False, margin=dict(l=200, r=200, t=100, b=100))
    # No selector is passed, so these settings apply to both pie traces.
    fig.update_traces(
        textposition="inside",
        hoverinfo="label+percent",
        textinfo="label",
        marker=dict(colors=full_colors, line=dict(color="#38496e", width=1)),
    )
    return fig
def view_stat(output_name):
    """Load a tab-separated results file, chart it, then delete the file.

    Args:
        output_name: Name of the results file inside the ``results/``
            directory.

    Returns:
        A two-item list ``[table, figure]``.
        - ``table`` is the loaded DataFrame with its leading columns
          relabelled for display.
        - ``figure`` is built by ``plot_pie_chart`` from the DataFrame
          under its original column names.
    """
    result_path = f"results/{output_name}"
    results = pd.read_csv(result_path, sep="\t")
    figure = plot_pie_chart(results)

    # Delete the results file now that it has been read and plotted.
    if os.path.exists(result_path):
        os.remove(result_path)

    # Positional relabel: the i-th column receives the i-th display header.
    # zip() stops at the shorter sequence, so any extra columns keep their
    # original names.
    display_headers = ("Test ID", "ARG Prediction", "Resistance Category", "Probability")
    relabel = {old: new for old, new in zip(results.columns, display_headers)}
    renamed = results.rename(columns=relabel)
    return [renamed, figure]
def run_argnet(input, output_name, sequence_type, sequence_length_type):
    """Run the ARGNet model variant selected by sequence and length type.

    The function works in these steps:
    1. Write the query text to ``input.txt`` in the current working directory.
    2. Remove any existing ``results/<output_name>``.
    3. Lazily import the selected model module and call it as
       ``<module>.<fn>("input.txt", output_name)``.

    ``input.txt`` is always removed afterwards, even if the model raises.

    Args:
        input: Query text written verbatim to ``input.txt`` (presumably
            FASTA; confirm against the model modules).
        output_name: Results file name passed through to the model.
        sequence_type: ``"aa"`` or ``"nt"`` (presumably amino-acid vs.
            nucleotide input, judging by the module names).
        sequence_length_type: ``"s"`` or ``"l"`` (presumably the short- vs.
            long-sequence model variants, judging by the module names).

    Raises:
        ValueError: If the ``(sequence_type, sequence_length_type)`` pair is
            not one of the four supported combinations. This is raised
            before any file is written or deleted, instead of silently doing
            nothing and leaving no results file for the caller to read.
    """
    supported = {("aa", "s"), ("nt", "s"), ("aa", "l"), ("nt", "l")}
    if (sequence_type, sequence_length_type) not in supported:
        raise ValueError(
            f"unsupported combination: sequence_type={sequence_type!r}, "
            f"sequence_length_type={sequence_length_type!r}"
        )

    with open("input.txt", "w") as f:
        f.write(input)
    try:
        # Clear any previous results file so that only this run's output
        # can appear at that path.
        if os.path.exists(f"results/{output_name}"):
            os.remove(f"results/{output_name}")
        # Model modules are imported lazily, only for the requested variant.
        if sequence_type == "aa" and sequence_length_type == "s":
            from . import argnet_ssaa_chunk as ssaa
            ssaa.argnet_ssaa("input.txt", output_name)
        elif sequence_type == "nt" and sequence_length_type == "s":
            from . import argnet_ssnt_new_chunk as ssnt
            ssnt.argnet_ssnt("input.txt", output_name)
        elif sequence_type == "aa" and sequence_length_type == "l":
            from . import argnet_lsaa_speed_sgpu as lsaa
            lsaa.argnet_lsaa("input.txt", output_name)
        else:  # ("nt", "l") -- the only remaining supported pair
            from . import argnet_lsnt as lsnt
            lsnt.argnet_lsnt("input.txt", output_name)
    finally:
        # Always remove the temporary query file, even if the model raised.
        if os.path.exists("input.txt"):
            os.remove("input.txt")