Spaces:
Running
Running
Kang Suhyun
commited on
[#37] Store ELO ratings in DB after calculation (#112)
Browse files* [#37] Store ELO ratings in DB after calculation
This change adds logic to store ELO ratings in the database after they are calculated.
Previously, we calculated and loaded the ratings without storing them. Now, we store them for future use.
This change doesn't affect current operations as we are not using the stored data yet. However, it sets the groundwork for future optimizations where we will use the stored ratings to avoid recalculating ELO scores for already calculated battles.
* update
* update
* review
* fix
* fix
* review
- README.md +3 -0
- app.py +2 -1
- db.py +114 -0
- leaderboard.py +43 -75
- response.py +2 -2
README.md
CHANGED
@@ -49,6 +49,9 @@ Get Involved: [Discuss and contribute on GitHub](https://github.com/yanolja/aren
|
|
49 |
|
50 |
```shell
|
51 |
CREDENTIALS_PATH=<your crednetials path> \
|
|
|
|
|
|
|
52 |
OPENAI_API_KEY=<your key> \
|
53 |
ANTHROPIC_API_KEY=<your key> \
|
54 |
MISTRAL_API_KEY=<your key> \
|
|
|
49 |
|
50 |
```shell
|
51 |
CREDENTIALS_PATH=<your crednetials path> \
|
52 |
+
RATINGS_COLLECTION=<your collection> \
|
53 |
+
SUMMARIZATIONS_COLLECTION=<your collection> \
|
54 |
+
TRANSLATIONS_COLLECTION=<your collection> \
|
55 |
OPENAI_API_KEY=<your key> \
|
56 |
ANTHROPIC_API_KEY=<your key> \
|
57 |
MISTRAL_API_KEY=<your key> \
|
app.py
CHANGED
@@ -8,8 +8,8 @@ from firebase_admin import firestore
|
|
8 |
import gradio as gr
|
9 |
import lingua
|
10 |
|
|
|
11 |
from leaderboard import build_leaderboard
|
12 |
-
from leaderboard import db
|
13 |
from leaderboard import SUPPORTED_LANGUAGES
|
14 |
from model import check_models
|
15 |
from model import supported_models
|
@@ -50,6 +50,7 @@ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
|
|
50 |
language_a = detector.detect_language_of(response_a)
|
51 |
language_b = detector.detect_language_of(response_b)
|
52 |
|
|
|
53 |
doc_ref = db.collection("arena-summarizations").document(doc_id)
|
54 |
doc["model_a_response_language"] = language_a.name.lower()
|
55 |
doc["model_b_response_language"] = language_b.name.lower()
|
|
|
8 |
import gradio as gr
|
9 |
import lingua
|
10 |
|
11 |
+
from db import db
|
12 |
from leaderboard import build_leaderboard
|
|
|
13 |
from leaderboard import SUPPORTED_LANGUAGES
|
14 |
from model import check_models
|
15 |
from model import supported_models
|
|
|
50 |
language_a = detector.detect_language_of(response_a)
|
51 |
language_b = detector.detect_language_of(response_b)
|
52 |
|
53 |
+
# TODO(#37): Move DB operations to db.py.
|
54 |
doc_ref = db.collection("arena-summarizations").document(doc_id)
|
55 |
doc["model_a_response_language"] = language_a.name.lower()
|
56 |
doc["model_b_response_language"] = language_b.name.lower()
|
db.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module handles the management of the database.
|
3 |
+
"""
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import enum
|
6 |
+
import os
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
import firebase_admin
|
10 |
+
from firebase_admin import credentials
|
11 |
+
from firebase_admin import firestore
|
12 |
+
from google.cloud.firestore_v1 import base_query
|
13 |
+
import gradio as gr
|
14 |
+
|
15 |
+
from credentials import get_credentials_json
|
16 |
+
|
17 |
+
|
18 |
+
def get_required_env(name: str) -> str:
|
19 |
+
value = os.getenv(name)
|
20 |
+
if value is None:
|
21 |
+
raise ValueError(f"Environment variable {name} is not set")
|
22 |
+
return value
|
23 |
+
|
24 |
+
|
25 |
+
RATINGS_COLLECTION = get_required_env("RATINGS_COLLECTION")
|
26 |
+
SUMMARIZATIONS_COLLECTION = get_required_env("SUMMARIZATIONS_COLLECTION")
|
27 |
+
TRANSLATIONS_COLLECTION = get_required_env("TRANSLATIONS_COLLECTION")
|
28 |
+
|
29 |
+
if gr.NO_RELOAD:
|
30 |
+
firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))
|
31 |
+
db = firestore.client()
|
32 |
+
|
33 |
+
|
34 |
+
class Category(enum.Enum):
|
35 |
+
SUMMARIZATION = "summarization"
|
36 |
+
TRANSLATION = "translation"
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class Rating:
|
41 |
+
model: str
|
42 |
+
rating: int
|
43 |
+
|
44 |
+
|
45 |
+
def get_ratings(category: Category, source_lang: str | None,
|
46 |
+
target_lang: str | None) -> List[Rating] | None:
|
47 |
+
doc_id = "#".join([category.value] +
|
48 |
+
[lang for lang in (source_lang, target_lang) if lang])
|
49 |
+
# TODO(#37): Make it more clear what fields are in the document.
|
50 |
+
doc_dict = db.collection(RATINGS_COLLECTION).document(doc_id).get().to_dict()
|
51 |
+
if doc_dict is None:
|
52 |
+
return None
|
53 |
+
|
54 |
+
# TODO(#37): Return the timestamp as well.
|
55 |
+
doc_dict.pop("timestamp")
|
56 |
+
|
57 |
+
return [Rating(model, rating) for model, rating in doc_dict.items()]
|
58 |
+
|
59 |
+
|
60 |
+
def set_ratings(category: Category, ratings: List[Rating], source_lang: str,
|
61 |
+
target_lang: str | None):
|
62 |
+
source_lang_lowercase = source_lang.lower()
|
63 |
+
target_lang_lowercase = target_lang.lower() if target_lang else None
|
64 |
+
|
65 |
+
doc_id = "#".join([category.value, source_lang_lowercase] +
|
66 |
+
([target_lang_lowercase] if target_lang_lowercase else []))
|
67 |
+
doc_ref = db.collection(RATINGS_COLLECTION).document(doc_id)
|
68 |
+
|
69 |
+
new_ratings = {rating.model: rating.rating for rating in ratings}
|
70 |
+
new_ratings["timestamp"] = firestore.SERVER_TIMESTAMP
|
71 |
+
doc_ref.set(new_ratings, merge=True)
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass
|
75 |
+
class Battle:
|
76 |
+
model_a: str
|
77 |
+
model_b: str
|
78 |
+
winner: str
|
79 |
+
|
80 |
+
|
81 |
+
def get_battles(category: Category, source_lang: str | None,
|
82 |
+
target_lang: str | None) -> List[Battle]:
|
83 |
+
source_lang_lowercase = source_lang.lower() if source_lang else None
|
84 |
+
target_lang_lowercase = target_lang.lower() if target_lang else None
|
85 |
+
|
86 |
+
if category == Category.SUMMARIZATION:
|
87 |
+
collection = db.collection(SUMMARIZATIONS_COLLECTION).order_by("timestamp")
|
88 |
+
|
89 |
+
if source_lang_lowercase:
|
90 |
+
collection = collection.where(filter=base_query.FieldFilter(
|
91 |
+
"model_a_response_language", "==", source_lang_lowercase)).where(
|
92 |
+
filter=base_query.FieldFilter("model_b_response_language", "==",
|
93 |
+
source_lang_lowercase))
|
94 |
+
|
95 |
+
elif category == Category.TRANSLATION:
|
96 |
+
collection = db.collection(TRANSLATIONS_COLLECTION).order_by("timestamp")
|
97 |
+
|
98 |
+
if source_lang_lowercase:
|
99 |
+
collection = collection.where(filter=base_query.FieldFilter(
|
100 |
+
"source_language", "==", source_lang_lowercase))
|
101 |
+
|
102 |
+
if target_lang_lowercase:
|
103 |
+
collection = collection.where(filter=base_query.FieldFilter(
|
104 |
+
"target_language", "==", target_lang_lowercase))
|
105 |
+
|
106 |
+
else:
|
107 |
+
raise ValueError(f"Invalid category: {category}")
|
108 |
+
|
109 |
+
docs = collection.stream()
|
110 |
+
battles = []
|
111 |
+
for doc in docs:
|
112 |
+
data = doc.to_dict()
|
113 |
+
battles.append(Battle(data["model_a"], data["model_b"], data["winner"]))
|
114 |
+
return battles
|
leaderboard.py
CHANGED
@@ -5,21 +5,13 @@ It provides a leaderboard component.
|
|
5 |
from collections import defaultdict
|
6 |
import enum
|
7 |
import math
|
8 |
-
from typing import Tuple
|
9 |
|
10 |
-
import firebase_admin
|
11 |
-
from firebase_admin import credentials
|
12 |
-
from firebase_admin import firestore
|
13 |
-
from google.cloud.firestore_v1 import base_query
|
14 |
import gradio as gr
|
15 |
import lingua
|
16 |
-
import pandas as pd
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
if gr.NO_RELOAD:
|
21 |
-
firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))
|
22 |
-
db = firestore.client()
|
23 |
|
24 |
SUPPORTED_LANGUAGES = [
|
25 |
language.name.capitalize() for language in lingua.Language.all()
|
@@ -34,11 +26,16 @@ class LeaderboardTab(enum.Enum):
|
|
34 |
|
35 |
|
36 |
# Ref: https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing#scrollTo=QLGc6DwxyvQc pylint: disable=line-too-long
|
37 |
-
def compute_elo(battles
|
|
|
|
|
|
|
|
|
38 |
rating = defaultdict(lambda: initial_rating)
|
39 |
|
40 |
-
for
|
41 |
-
|
|
|
42 |
rating_a = rating[model_a]
|
43 |
rating_b = rating[model_b]
|
44 |
|
@@ -50,71 +47,41 @@ def compute_elo(battles, k=4, scale=400, base=10, initial_rating=1000):
|
|
50 |
rating[model_a] += k * (scored_point_a - expected_score_a)
|
51 |
rating[model_b] += k * (1 - scored_point_a - expected_score_b)
|
52 |
|
53 |
-
return rating
|
54 |
-
|
55 |
-
|
56 |
-
def get_docs(tab: str,
|
57 |
-
summary_lang: str = None,
|
58 |
-
source_lang: str = None,
|
59 |
-
target_lang: str = None):
|
60 |
-
if tab == LeaderboardTab.SUMMARIZATION:
|
61 |
-
collection = db.collection("arena-summarizations").order_by("timestamp")
|
62 |
-
|
63 |
-
if summary_lang and (not summary_lang == ANY_LANGUAGE):
|
64 |
-
collection = collection.where(filter=base_query.FieldFilter(
|
65 |
-
"model_a_response_language", "==", summary_lang.lower())).where(
|
66 |
-
filter=base_query.FieldFilter("model_b_response_language", "==",
|
67 |
-
summary_lang.lower()))
|
68 |
-
|
69 |
-
return collection.stream()
|
70 |
-
|
71 |
-
if tab == LeaderboardTab.TRANSLATION:
|
72 |
-
collection = db.collection("arena-translations").order_by("timestamp")
|
73 |
|
74 |
-
if source_lang and (not source_lang == ANY_LANGUAGE):
|
75 |
-
collection = collection.where(filter=base_query.FieldFilter(
|
76 |
-
"source_language", "==", source_lang.lower()))
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
"target_language", "==", target_lang.lower()))
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
def load_elo_ratings(tab,
|
86 |
-
summary_lang: str = None,
|
87 |
-
source_lang: str = None,
|
88 |
-
target_lang: str = None):
|
89 |
-
docs = get_docs(tab, summary_lang, source_lang, target_lang)
|
90 |
-
|
91 |
-
battles = []
|
92 |
-
for doc in docs:
|
93 |
-
data = doc.to_dict()
|
94 |
-
battles.append({
|
95 |
-
"model_a": data["model_a"],
|
96 |
-
"model_b": data["model_b"],
|
97 |
-
"winner": data["winner"]
|
98 |
-
})
|
99 |
|
|
|
|
|
|
|
100 |
if not battles:
|
101 |
return
|
102 |
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
105 |
|
106 |
-
sorted_ratings = sorted(
|
|
|
|
|
|
|
107 |
|
108 |
rank = 0
|
109 |
last_rating = None
|
110 |
rating_rows = []
|
111 |
for index, (model, rating) in enumerate(sorted_ratings):
|
112 |
-
|
113 |
-
if int_rating != last_rating:
|
114 |
rank = index + 1
|
115 |
|
116 |
-
rating_rows.append([rank, model,
|
117 |
-
last_rating =
|
118 |
|
119 |
return rating_rows
|
120 |
|
@@ -123,9 +90,9 @@ LEADERBOARD_UPDATE_INTERVAL = 600 # 10 minutes
|
|
123 |
LEADERBOARD_INFO = "The leaderboard is updated every 10 minutes."
|
124 |
|
125 |
|
126 |
-
def update_filtered_leaderboard(tab
|
127 |
-
target_lang: str):
|
128 |
-
new_value = load_elo_ratings(tab,
|
129 |
return gr.update(value=new_value)
|
130 |
|
131 |
|
@@ -149,14 +116,15 @@ def build_leaderboard():
|
|
149 |
headers=["Rank", "Model", "Elo rating"],
|
150 |
datatype=["number", "str", "number"],
|
151 |
value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
|
152 |
-
ANY_LANGUAGE),
|
153 |
elem_classes="leaderboard",
|
154 |
visible=False)
|
155 |
|
156 |
original_summarization = gr.Dataframe(
|
157 |
headers=["Rank", "Model", "Elo rating"],
|
158 |
datatype=["number", "str", "number"],
|
159 |
-
value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION
|
|
|
160 |
every=LEADERBOARD_UPDATE_INTERVAL,
|
161 |
elem_classes="leaderboard")
|
162 |
gr.Markdown(LEADERBOARD_INFO)
|
@@ -165,7 +133,6 @@ def build_leaderboard():
|
|
165 |
fn=update_filtered_leaderboard,
|
166 |
inputs=[
|
167 |
gr.State(LeaderboardTab.SUMMARIZATION), summary_language,
|
168 |
-
gr.State(None),
|
169 |
gr.State(None)
|
170 |
],
|
171 |
outputs=filtered_summarization).then(
|
@@ -197,7 +164,8 @@ def build_leaderboard():
|
|
197 |
original_translation = gr.Dataframe(
|
198 |
headers=["Rank", "Model", "Elo rating"],
|
199 |
datatype=["number", "str", "number"],
|
200 |
-
value=lambda: load_elo_ratings(LeaderboardTab.TRANSLATION
|
|
|
201 |
every=LEADERBOARD_UPDATE_INTERVAL,
|
202 |
elem_classes="leaderboard")
|
203 |
gr.Markdown(LEADERBOARD_INFO)
|
@@ -205,8 +173,8 @@ def build_leaderboard():
|
|
205 |
source_language.change(
|
206 |
fn=update_filtered_leaderboard,
|
207 |
inputs=[
|
208 |
-
gr.State(LeaderboardTab.TRANSLATION),
|
209 |
-
|
210 |
],
|
211 |
outputs=filtered_translation).then(
|
212 |
fn=toggle_leaderboard,
|
@@ -215,8 +183,8 @@ def build_leaderboard():
|
|
215 |
target_language.change(
|
216 |
fn=update_filtered_leaderboard,
|
217 |
inputs=[
|
218 |
-
gr.State(LeaderboardTab.TRANSLATION),
|
219 |
-
|
220 |
],
|
221 |
outputs=filtered_translation).then(
|
222 |
fn=toggle_leaderboard,
|
|
|
5 |
from collections import defaultdict
|
6 |
import enum
|
7 |
import math
|
8 |
+
from typing import Dict, List, Tuple
|
9 |
|
|
|
|
|
|
|
|
|
10 |
import gradio as gr
|
11 |
import lingua
|
|
|
12 |
|
13 |
+
import db
|
14 |
+
from db import get_battles
|
|
|
|
|
|
|
15 |
|
16 |
SUPPORTED_LANGUAGES = [
|
17 |
language.name.capitalize() for language in lingua.Language.all()
|
|
|
26 |
|
27 |
|
28 |
# Ref: https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing#scrollTo=QLGc6DwxyvQc pylint: disable=line-too-long
|
29 |
+
def compute_elo(battles: List[db.Battle],
|
30 |
+
k=4,
|
31 |
+
scale=400,
|
32 |
+
base=10,
|
33 |
+
initial_rating=1000) -> Dict[str, int]:
|
34 |
rating = defaultdict(lambda: initial_rating)
|
35 |
|
36 |
+
for battle in battles:
|
37 |
+
model_a, model_b, winner = battle.model_a, battle.model_b, battle.winner
|
38 |
+
|
39 |
rating_a = rating[model_a]
|
40 |
rating_b = rating[model_b]
|
41 |
|
|
|
47 |
rating[model_a] += k * (scored_point_a - expected_score_a)
|
48 |
rating[model_b] += k * (1 - scored_point_a - expected_score_b)
|
49 |
|
50 |
+
return {model: math.floor(rating + 0.5) for model, rating in rating.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
|
|
|
|
|
|
52 |
|
53 |
+
def load_elo_ratings(tab, source_lang: str, target_lang: str | None):
|
54 |
+
category = db.Category.SUMMARIZATION if tab == LeaderboardTab.SUMMARIZATION else db.Category.TRANSLATION
|
|
|
55 |
|
56 |
+
# TODO(#37): Call db.get_ratings and return the ratings if exists.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
battles = get_battles(category,
|
59 |
+
None if source_lang == ANY_LANGUAGE else source_lang,
|
60 |
+
None if target_lang == ANY_LANGUAGE else target_lang)
|
61 |
if not battles:
|
62 |
return
|
63 |
|
64 |
+
computed_ratings = compute_elo(battles)
|
65 |
+
|
66 |
+
db.set_ratings(
|
67 |
+
category,
|
68 |
+
[db.Rating(model, rating) for model, rating in computed_ratings.items()],
|
69 |
+
source_lang, target_lang)
|
70 |
|
71 |
+
sorted_ratings = sorted(
|
72 |
+
computed_ratings.items(),
|
73 |
+
key=lambda x: x[1], # rating
|
74 |
+
reverse=True)
|
75 |
|
76 |
rank = 0
|
77 |
last_rating = None
|
78 |
rating_rows = []
|
79 |
for index, (model, rating) in enumerate(sorted_ratings):
|
80 |
+
if rating != last_rating:
|
|
|
81 |
rank = index + 1
|
82 |
|
83 |
+
rating_rows.append([rank, model, rating])
|
84 |
+
last_rating = rating
|
85 |
|
86 |
return rating_rows
|
87 |
|
|
|
90 |
LEADERBOARD_INFO = "The leaderboard is updated every 10 minutes."
|
91 |
|
92 |
|
93 |
+
def update_filtered_leaderboard(tab: str, source_lang: str,
|
94 |
+
target_lang: str | None):
|
95 |
+
new_value = load_elo_ratings(tab, source_lang, target_lang)
|
96 |
return gr.update(value=new_value)
|
97 |
|
98 |
|
|
|
116 |
headers=["Rank", "Model", "Elo rating"],
|
117 |
datatype=["number", "str", "number"],
|
118 |
value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
|
119 |
+
ANY_LANGUAGE, None),
|
120 |
elem_classes="leaderboard",
|
121 |
visible=False)
|
122 |
|
123 |
original_summarization = gr.Dataframe(
|
124 |
headers=["Rank", "Model", "Elo rating"],
|
125 |
datatype=["number", "str", "number"],
|
126 |
+
value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
|
127 |
+
ANY_LANGUAGE, None),
|
128 |
every=LEADERBOARD_UPDATE_INTERVAL,
|
129 |
elem_classes="leaderboard")
|
130 |
gr.Markdown(LEADERBOARD_INFO)
|
|
|
133 |
fn=update_filtered_leaderboard,
|
134 |
inputs=[
|
135 |
gr.State(LeaderboardTab.SUMMARIZATION), summary_language,
|
|
|
136 |
gr.State(None)
|
137 |
],
|
138 |
outputs=filtered_summarization).then(
|
|
|
164 |
original_translation = gr.Dataframe(
|
165 |
headers=["Rank", "Model", "Elo rating"],
|
166 |
datatype=["number", "str", "number"],
|
167 |
+
value=lambda: load_elo_ratings(LeaderboardTab.TRANSLATION,
|
168 |
+
ANY_LANGUAGE, ANY_LANGUAGE),
|
169 |
every=LEADERBOARD_UPDATE_INTERVAL,
|
170 |
elem_classes="leaderboard")
|
171 |
gr.Markdown(LEADERBOARD_INFO)
|
|
|
173 |
source_language.change(
|
174 |
fn=update_filtered_leaderboard,
|
175 |
inputs=[
|
176 |
+
gr.State(LeaderboardTab.TRANSLATION), source_language,
|
177 |
+
target_language
|
178 |
],
|
179 |
outputs=filtered_translation).then(
|
180 |
fn=toggle_leaderboard,
|
|
|
183 |
target_language.change(
|
184 |
fn=update_filtered_leaderboard,
|
185 |
inputs=[
|
186 |
+
gr.State(LeaderboardTab.TRANSLATION), source_language,
|
187 |
+
target_language
|
188 |
],
|
189 |
outputs=filtered_translation).then(
|
190 |
fn=toggle_leaderboard,
|
response.py
CHANGED
@@ -11,7 +11,7 @@ from uuid import uuid4
|
|
11 |
from firebase_admin import firestore
|
12 |
import gradio as gr
|
13 |
|
14 |
-
from
|
15 |
from model import ContextWindowExceededError
|
16 |
from model import Model
|
17 |
from model import supported_models
|
@@ -22,7 +22,7 @@ logging.basicConfig()
|
|
22 |
logger = logging.getLogger(__name__)
|
23 |
logger.setLevel(logging.INFO)
|
24 |
|
25 |
-
|
26 |
def get_history_collection(category: str):
|
27 |
if category == Category.SUMMARIZE.value:
|
28 |
return db.collection("arena-summarization-history")
|
|
|
11 |
from firebase_admin import firestore
|
12 |
import gradio as gr
|
13 |
|
14 |
+
from db import db
|
15 |
from model import ContextWindowExceededError
|
16 |
from model import Model
|
17 |
from model import supported_models
|
|
|
22 |
logger = logging.getLogger(__name__)
|
23 |
logger.setLevel(logging.INFO)
|
24 |
|
25 |
+
# TODO(#37): Move DB operations to db.py.
|
26 |
def get_history_collection(category: str):
|
27 |
if category == Category.SUMMARIZE.value:
|
28 |
return db.collection("arena-summarization-history")
|