tracywong117 commited on
Commit
819e2d9
·
1 Parent(s): d9a04ad

fix file upload handle, update readme

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +11 -2
  3. scripts/script.py +59 -20
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: ARGnet UI
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.14.0
8
  app_file: app.py
 
1
  ---
2
  title: ARGnet UI
3
+ emoji: 🧬
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.14.0
8
  app_file: app.py
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  from helper import *
8
  import scripts.script as script
9
 
 
10
  def process_data(input_type, input_text, input_file):
11
  print(input_type)
12
  if input_type == "Text":
@@ -26,6 +27,7 @@ def process_data(input_type, input_text, input_file):
26
  else:
27
  if input_file:
28
  sequence = []
 
29
  with open(input_file.name, "r") as f:
30
  for line in f:
31
  if line.startswith(">"):
@@ -70,7 +72,7 @@ with gr.Blocks() as whole_block:
70
  """
71
  )
72
  input_textbox = gr.Textbox(label="Sequence")
73
- input_textbox_2 = gr.Textbox(label="Sequence",visible=False)
74
  gr.Examples(
75
  examples=[
76
  ["Amino Acid Long Sequence (>51aa)"],
@@ -97,7 +99,14 @@ with gr.Blocks() as whole_block:
97
  ## Output
98
  """
99
  )
100
- table = gr.Dataframe(headers=["Test ID", "ARG Prediction", "Resistance Category", "Probability"])
 
 
 
 
 
 
 
101
  pie_chart = gr.Plot(container=True)
102
 
103
  text_tab.select(lambda: "Text", inputs=None, outputs=tab_selected)
 
7
  from helper import *
8
  import scripts.script as script
9
 
10
+
11
  def process_data(input_type, input_text, input_file):
12
  print(input_type)
13
  if input_type == "Text":
 
27
  else:
28
  if input_file:
29
  sequence = []
30
+ input_text = open(input_file.name, "r").read()
31
  with open(input_file.name, "r") as f:
32
  for line in f:
33
  if line.startswith(">"):
 
72
  """
73
  )
74
  input_textbox = gr.Textbox(label="Sequence")
75
+ input_textbox_2 = gr.Textbox(label="Sequence", visible=False)
76
  gr.Examples(
77
  examples=[
78
  ["Amino Acid Long Sequence (>51aa)"],
 
99
  ## Output
100
  """
101
  )
102
+ table = gr.Dataframe(
103
+ headers=[
104
+ "Test ID",
105
+ "ARG Prediction",
106
+ "Resistance Category",
107
+ "Probability",
108
+ ]
109
+ )
110
  pie_chart = gr.Plot(container=True)
111
 
112
  text_tab.select(lambda: "Text", inputs=None, outputs=tab_selected)
scripts/script.py CHANGED
@@ -5,50 +5,85 @@ import plotly.graph_objects as go
5
 
6
  import os
7
 
 
8
  def plot_pie_chart(df):
9
  ARG_prediction_counts = dict(df["ARG_prediction"].value_counts())
10
- ARG_prediction_df = pd.DataFrame.from_dict(ARG_prediction_counts, orient='index', columns=['count'])
 
 
11
  resistance_category_counts = dict(df["resistance_category"].value_counts())
12
- resistance_category_df = pd.DataFrame.from_dict(resistance_category_counts, orient='index', columns=['count'])
 
 
13
 
14
  number_of_catgeory = len(df["resistance_category"].value_counts())
15
 
16
  colors = [
17
- '#f9b4ab',
18
- '#fdebd3',
19
- '#264e70',
20
- '#679186',
21
- '#bbd4ce',
22
  ]
23
  full_colors = []
24
- for i in range(math.ceil(number_of_catgeory/5)):
25
  full_colors += colors
26
 
27
  # colors = ['gold', 'mediumturquoise', 'darkorange', 'lightgreen']
28
 
29
- fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("ARG/non-ARG", "Resistance category"), specs=[[{'type': 'domain'}, {'type': 'domain'}]])
30
-
31
- fig.add_trace(go.Pie(labels=ARG_prediction_df.index, values=ARG_prediction_df['count'], legendgroup = '1', title="ARG/non-ARG"), row=1, col=1)
32
- fig.add_trace(go.Pie(labels=resistance_category_df.index, values=resistance_category_df['count'], legendgroup = '2', title="Resistance category"), row=1, col=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  fig.update_layout(showlegend=False, margin=dict(l=200, r=200, t=100, b=100))
35
- fig.update_traces(textposition='inside',hoverinfo='label+percent', textinfo='label', marker=dict(colors=colors, line=dict(color='#38496e', width=1)))
36
-
 
 
 
 
 
37
  return fig
38
 
 
39
  def view_stat(output_name):
40
- df = pd.read_csv(f"results/{output_name}", delimiter='\t')
41
  # change df header
42
  new_headers = ["Test ID", "ARG Prediction", "Resistance Category", "Probability"]
43
  fig = plot_pie_chart(df)
44
-
45
  # delete the output file
46
  if os.path.exists(f"results/{output_name}"):
47
  os.remove(f"results/{output_name}")
48
 
49
  return [df.rename(columns=dict(zip(df.columns, new_headers))), fig]
50
 
51
- def run_argnet(input, output_name,sequence_type,sequence_length_type):
 
52
  with open("input.txt", "w") as f:
53
  f.write(input)
54
 
@@ -58,16 +93,20 @@ def run_argnet(input, output_name,sequence_type,sequence_length_type):
58
 
59
  if sequence_type == "aa" and sequence_length_type == "s":
60
  from . import argnet_ssaa_chunk as ssaa
 
61
  ssaa.argnet_ssaa("input.txt", output_name)
62
  elif sequence_type == "nt" and sequence_length_type == "s":
63
- from . import argnet_ssnt_new_chunk as ssnt
 
64
  ssnt.argnet_ssnt("input.txt", output_name)
65
  elif sequence_type == "aa" and sequence_length_type == "l":
66
  from . import argnet_lsaa_speed_sgpu as lsaa
 
67
  lsaa.argnet_lsaa("input.txt", output_name)
68
  elif sequence_type == "nt" and sequence_length_type == "l":
69
  from . import argnet_lsnt as lsnt
70
- lsnt.argnet_lsnt("input.txt", output_name)
71
-
72
 
 
73
 
 
 
 
5
 
6
  import os
7
 
8
+
9
  def plot_pie_chart(df):
10
  ARG_prediction_counts = dict(df["ARG_prediction"].value_counts())
11
+ ARG_prediction_df = pd.DataFrame.from_dict(
12
+ ARG_prediction_counts, orient="index", columns=["count"]
13
+ )
14
  resistance_category_counts = dict(df["resistance_category"].value_counts())
15
+ resistance_category_df = pd.DataFrame.from_dict(
16
+ resistance_category_counts, orient="index", columns=["count"]
17
+ )
18
 
19
  number_of_catgeory = len(df["resistance_category"].value_counts())
20
 
21
  colors = [
22
+ "#f9b4ab",
23
+ "#fdebd3",
24
+ "#264e70",
25
+ "#679186",
26
+ "#bbd4ce",
27
  ]
28
  full_colors = []
29
+ for i in range(math.ceil(number_of_catgeory / 5)):
30
  full_colors += colors
31
 
32
  # colors = ['gold', 'mediumturquoise', 'darkorange', 'lightgreen']
33
 
34
+ fig = sp.make_subplots(
35
+ rows=1,
36
+ cols=2,
37
+ subplot_titles=("ARG/non-ARG", "Resistance category"),
38
+ specs=[[{"type": "domain"}, {"type": "domain"}]],
39
+ )
40
+
41
+ fig.add_trace(
42
+ go.Pie(
43
+ labels=ARG_prediction_df.index,
44
+ values=ARG_prediction_df["count"],
45
+ legendgroup="1",
46
+ title="ARG/non-ARG",
47
+ ),
48
+ row=1,
49
+ col=1,
50
+ )
51
+ fig.add_trace(
52
+ go.Pie(
53
+ labels=resistance_category_df.index,
54
+ values=resistance_category_df["count"],
55
+ legendgroup="2",
56
+ title="Resistance category",
57
+ ),
58
+ row=1,
59
+ col=2,
60
+ )
61
 
62
  fig.update_layout(showlegend=False, margin=dict(l=200, r=200, t=100, b=100))
63
+ fig.update_traces(
64
+ textposition="inside",
65
+ hoverinfo="label+percent",
66
+ textinfo="label",
67
+ marker=dict(colors=full_colors, line=dict(color="#38496e", width=1)),
68
+ )
69
+
70
  return fig
71
 
72
+
73
  def view_stat(output_name):
74
+ df = pd.read_csv(f"results/{output_name}", delimiter="\t")
75
  # change df header
76
  new_headers = ["Test ID", "ARG Prediction", "Resistance Category", "Probability"]
77
  fig = plot_pie_chart(df)
78
+
79
  # delete the output file
80
  if os.path.exists(f"results/{output_name}"):
81
  os.remove(f"results/{output_name}")
82
 
83
  return [df.rename(columns=dict(zip(df.columns, new_headers))), fig]
84
 
85
+
86
+ def run_argnet(input, output_name, sequence_type, sequence_length_type):
87
  with open("input.txt", "w") as f:
88
  f.write(input)
89
 
 
93
 
94
  if sequence_type == "aa" and sequence_length_type == "s":
95
  from . import argnet_ssaa_chunk as ssaa
96
+
97
  ssaa.argnet_ssaa("input.txt", output_name)
98
  elif sequence_type == "nt" and sequence_length_type == "s":
99
+ from . import argnet_ssnt_new_chunk as ssnt
100
+
101
  ssnt.argnet_ssnt("input.txt", output_name)
102
  elif sequence_type == "aa" and sequence_length_type == "l":
103
  from . import argnet_lsaa_speed_sgpu as lsaa
104
+
105
  lsaa.argnet_lsaa("input.txt", output_name)
106
  elif sequence_type == "nt" and sequence_length_type == "l":
107
  from . import argnet_lsnt as lsnt
 
 
108
 
109
+ lsnt.argnet_lsnt("input.txt", output_name)
110
 
111
+ if os.path.exists("input.txt"):
112
+ os.remove("input.txt")