crystina-z committed on
Commit 37046f4 · 1 Parent(s): 4726074

Update app.py

Files changed (1)
  1. app.py  +18 -22
app.py CHANGED
@@ -17,29 +17,26 @@ st.set_page_config(page_title="PSC Runtime",
 
 
 name = st.selectbox(
-    "",
+    "Choose a dataset",
     ["dl19", "dl20"],
     index=None,
     placeholder="Choose a dataset..."
 )
 
 model_name = st.selectbox(
-    "",
+    "Choose a model",
     ["gpt-3.5", "gpt-4"],
     index=None,
     placeholder="Choose a model..."
 )
 
-# "dl19"
 
 if name and model_name:
-
     import torch
     # fn = f"dl19-gpt-3.5.pt"
     fn = f"{name}-{model_name}.pt"
     object = torch.load(fn)
-
-
+
     outputs = object[2]
     query2outputs = {}
     for output in outputs:
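
The UI change in this hunk gives each selectbox a visible label instead of an empty string (recent Streamlit versions warn about empty widget labels), while keeping index=None and the placeholder so nothing is pre-selected. A minimal standalone sketch of that pattern, assuming a Streamlit version that supports index=None and placeholder on st.selectbox:

import streamlit as st

# With index=None no option is pre-selected; the placeholder text is shown
# until the user picks a value, and the label makes the widget self-describing.
name = st.selectbox(
    "Choose a dataset",
    ["dl19", "dl20"],
    index=None,
    placeholder="Choose a dataset...",
)

if name is None:
    st.stop()  # stop this rerun until the user has made a choice
st.write(f"Selected dataset: {name}")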
@@ -47,20 +44,19 @@ if name and model_name:
         assert len(all_queries) == 1
         query = list(all_queries)[0]
         query2outputs[query] = [x['hits'] for x in output]
-
-
+
     search_query = st.selectbox(
         "",
         sorted(query2outputs),
-        index=None,
-        placeholder="Choose a query from the list..."
+        # index=None,
+        # placeholder="Choose a query from the list..."
     )
 
     def preferences_from_hits(list_of_hits):
         docid2id = {}
         id2doc = {}
         preferences = []
-
+
         for result in list_of_hits:
             for doc in result:
                 if doc["docid"] not in docid2id:
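
preferences_from_hits is only partially visible here; the loop body that fills preferences lies outside the hunk. Presumably it assigns each unseen docid a dense integer id and records every ranked list as a row of those ids, which is what the later np.array(preferences) return implies. A hedged sketch of that idea (illustrative, not the app's exact code):

import numpy as np

def preferences_from_hits(list_of_hits):
    docid2id = {}     # raw docid -> dense integer id
    id2doc = {}       # dense integer id -> hit dict (for mapping back later)
    preferences = []  # one row per ranking, entries are dense ids

    for result in list_of_hits:        # one ranked list per LLM call
        row = []
        for doc in result:
            if doc["docid"] not in docid2id:
                new_id = len(docid2id)
                docid2id[doc["docid"]] = new_id
                id2doc[new_id] = doc
            row.append(docid2id[doc["docid"]])
        preferences.append(row)

    # Assumes every ranking covers the same documents, so the array is rectangular.
    return np.array(preferences), id2doc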
@@ -73,8 +69,8 @@ if name and model_name:
 
         # = {v: k for k, v in docid2id.items()}
         return np.array(preferences), id2doc
-
-
+
+
     def load_qrels(name):
         import ir_datasets
         if name == "dl19":
@@ -89,8 +85,8 @@ if name and model_name:
         for qrel in dataset.qrels_iter():
             qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
         return qrels
-
-
+
+
     def aggregate(list_of_hits):
         import numpy as np
         from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences
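
Only the tail of load_qrels appears in these hunks; the branch that picks the ir_datasets dataset for dl19/dl20 is outside the diff. A hedged sketch of what such a loader typically looks like; the dataset ids below are assumptions, not taken from this commit:

import ir_datasets
from collections import defaultdict

def load_qrels(name):
    # Assumed ir_datasets ids for the judged TREC DL 2019/2020 passage queries.
    dataset_id = {
        "dl19": "msmarco-passage/trec-dl-2019/judged",
        "dl20": "msmarco-passage/trec-dl-2020/judged",
    }[name]
    dataset = ir_datasets.load(dataset_id)

    qrels = defaultdict(dict)  # query_id -> {doc_id: relevance}
    for qrel in dataset.qrels_iter():
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
    return qrels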
@@ -101,12 +97,12 @@ if name and model_name:
         # y_optimal = BordaRankAggregator().aggregate(preferences)
 
         return [id2doc[id] for id in y_optimal]
-
+
 
     def write_ranking(search_results):
         # st.write(
         # f'<p align=\"right\" style=\"color:grey;\"> Before aggregation for query [{search_query}] ms</p>', unsafe_allow_html=True)
-
+
         qid = {result["qid"] for result in search_results}
         assert len(qid) == 1
         qid = list(qid)[0]
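
The aggregation step is the core of the demo: the per-call rankings are collapsed into a single Kemeny-optimal ordering, with Borda aggregation left as a commented-out alternative. The middle of aggregate is outside the hunk; a hedged sketch of how the pieces plausibly fit together, assuming KemenyOptimalAggregator.aggregate takes the preference matrix from preferences_from_hits and returns the aggregated ordering as internal ids:

from permsc import KemenyOptimalAggregator

def aggregate(list_of_hits):
    # One row of internal doc ids per ranking; preferences_from_hits is assumed
    # in scope (it is defined earlier in app.py, sketched above).
    preferences, id2doc = preferences_from_hits(list_of_hits)

    # Kemeny-optimal ordering: minimizes total Kendall tau distance to all rankings.
    y_optimal = KemenyOptimalAggregator().aggregate(preferences)
    # y_optimal = BordaRankAggregator().aggregate(preferences)  # cheaper alternative

    # Map the aggregated internal ids back to the original hit dicts.
    return [id2doc[id] for id in y_optimal]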
@@ -114,17 +110,17 @@ if name and model_name:
         for i, result in enumerate(search_results):
             result_id = result["docid"]
             contents = result["content"]
-
+
             label = qrels[str(qid)].get(str(result_id), 0)
             if label == 3:
-                style = "style=\"color:blue;\""
+                style = "style=\"color:rgb(231, 95, 43);\""
             elif label == 2:
-                style = "style=\"color:green;\""
+                style = "style=\"color:rgb(238, 147, 49);\""
             elif label == 1:
-                style = "style=\"color:red;\""
+                style = "style=\"color:rgb(241, 177, 118);\""
             else:
                 style = "style=\"color:grey;\""
-
+
             print(qid, result_id, label, style)
             # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
             output = f'<div class="row" {style}> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
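
The last hunk only swaps the hard-coded relevance colors (blue/green/red) for a graded orange palette, with grey kept for unjudged or non-relevant documents. The same mapping could be written as a lookup table rather than an if/elif chain; a small sketch, not the app's actual code:

# Relevance label -> inline CSS, from most to least relevant; anything else stays grey.
LABEL_STYLE = {
    3: 'style="color:rgb(231, 95, 43);"',
    2: 'style="color:rgb(238, 147, 49);"',
    1: 'style="color:rgb(241, 177, 118);"',
}

def style_for(label: int) -> str:
    return LABEL_STYLE.get(label, 'style="color:grey;"')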
 