crystina-z commited on
Commit
c777cc8
·
1 Parent(s): 048d322
Files changed (6) hide show
  1. app.py +148 -0
  2. dl19-gpt-3.5.pt +3 -0
  3. dl19-gpt-4.pt +3 -0
  4. dl20-gpt-3.5.pt +3 -0
  5. dl20-gpt-4.pt +3 -0
  6. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import numpy as np
4
+
5
+ import streamlit as st
6
+ from pathlib import Path
7
+ from collections import defaultdict
8
+
9
+ import sys
10
+ path_root = Path("./")
11
+ sys.path.append(str(path_root))
12
+
13
+
14
+ st.set_page_config(page_title="PSC Runtime",
15
+ page_icon='🌸', layout="centered")
16
+
17
+
18
+
19
+ name = st.selectbox(
20
+ "Choose a dataset",
21
+ ["dl19", "dl20"],
22
+ index=None,
23
+ placeholder="Choose a dataset..."
24
+ )
25
+
26
+ model_name = st.selectbox(
27
+ "Choose a model",
28
+ ["gpt-3.5", "gpt-4"],
29
+ index=None,
30
+ placeholder="Choose a model..."
31
+ )
32
+
33
+
34
+ if name and model_name:
35
+ import torch
36
+ # fn = f"dl19-gpt-3.5.pt"
37
+ fn = f"{name}-{model_name}.pt"
38
+ object = torch.load(fn)
39
+
40
+ outputs = object[2]
41
+ query2outputs = {}
42
+ for output in outputs:
43
+ all_queries = {x['query'] for x in output}
44
+ assert len(all_queries) == 1
45
+ query = list(all_queries)[0]
46
+ query2outputs[query] = [x['hits'] for x in output]
47
+
48
+ search_query = st.selectbox(
49
+ "Choose a query from the list",
50
+ sorted(query2outputs),
51
+ # index=None,
52
+ # placeholder="Choose a query from the list..."
53
+ )
54
+
55
+ def preferences_from_hits(list_of_hits):
56
+ docid2id = {}
57
+ id2doc = {}
58
+ preferences = []
59
+
60
+ for result in list_of_hits:
61
+ for doc in result:
62
+ if doc["docid"] not in docid2id:
63
+ id = len(docid2id)
64
+ docid2id[doc["docid"]] = id
65
+ id2doc[id] = doc
66
+ print([doc["docid"] for doc in result])
67
+ print([docid2id[doc["docid"]] for doc in result])
68
+ preferences.append([docid2id[doc["docid"]] for doc in result])
69
+
70
+ # = {v: k for k, v in docid2id.items()}
71
+ return np.array(preferences), id2doc
72
+
73
+
74
+ def load_qrels(name):
75
+ import ir_datasets
76
+ if name == "dl19":
77
+ ds_name = "msmarco-passage/trec-dl-2019/judged"
78
+ elif name == "dl20":
79
+ ds_name = "msmarco-passage/trec-dl-2020/judged"
80
+ else:
81
+ raise ValueError(name)
82
+
83
+ dataset = ir_datasets.load(ds_name)
84
+ qrels = defaultdict(dict)
85
+ for qrel in dataset.qrels_iter():
86
+ qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
87
+ return qrels
88
+
89
+
90
+ def aggregate(list_of_hits):
91
+ import numpy as np
92
+ from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences
93
+ from permsc import BordaRankAggregator
94
+
95
+ preferences, id2doc = preferences_from_hits(list_of_hits)
96
+ y_optimal = KemenyOptimalAggregator().aggregate(preferences)
97
+ # y_optimal = BordaRankAggregator().aggregate(preferences)
98
+
99
+ return [id2doc[id] for id in y_optimal]
100
+
101
+
102
+ def write_ranking(search_results, text):
103
+ st.write(f'<p align=\"right\" style=\"color:grey;\"> {text} ms</p>', unsafe_allow_html=True)
104
+
105
+ qid = {result["qid"] for result in search_results}
106
+ assert len(qid) == 1
107
+ qid = list(qid)[0]
108
+
109
+ for i, result in enumerate(search_results):
110
+ result_id = result["docid"]
111
+ contents = result["content"]
112
+
113
+ label = qrels[str(qid)].get(str(result_id), 0)
114
+ if label == 3:
115
+ style = "style=\"color:rgb(231, 95, 43);\""
116
+ elif label == 2:
117
+ style = "style=\"color:rgb(238, 147, 49);\""
118
+ elif label == 1:
119
+ style = "style=\"color:rgb(241, 177, 118);\""
120
+ else:
121
+ style = "style=\"color:grey;\""
122
+
123
+ print(qid, result_id, label, style)
124
+ # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
125
+ output = f'<div class="row" {style}> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
126
+
127
+ try:
128
+ st.write(output, unsafe_allow_html=True)
129
+ st.write(
130
+ f'<div class="row" {style}>{contents}</div>', unsafe_allow_html=True)
131
+
132
+ except:
133
+ pass
134
+ st.write('---')
135
+
136
+
137
+ aggregated_ranking = aggregate(query2outputs[search_query])
138
+ qrels = load_qrels(name)
139
+ col1, col2 = st.columns([5, 5])
140
+
141
+ if search_query:
142
+ with col1:
143
+ if search_query or button_clicked:
144
+ write_ranking(search_results=query2outputs[search_query][0], "w/o PSC")
145
+
146
+ with col2:
147
+ if search_query or button_clicked:
148
+ write_ranking(search_results=aggregated_ranking, "w/ PSC")
dl19-gpt-3.5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40a500ac421cca105758a5beb649e71ce2fd9c0cac5577d327a8680ffab9f710
3
+ size 2121849
dl19-gpt-4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b5862926a1b302f07a22f2c9071f5914ebc0679e27579d99e253d22ee99605
3
+ size 2121845
dl20-gpt-3.5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95b630e2ad744825409548e2557ff08a109370a4fcbdc0d4ddf58a128094530e
3
+ size 2645817
dl20-gpt-4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeac0101b4e846fd7b3a25e96722882669373f8b35369b57cab2212f1b4770cb
3
+ size 2645813
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ tqdm
3
+ fastrank
4
+ permsc
5
+ ir_datasets