Kang Suhyun commited on
Commit
5352a13
·
unverified ·
1 Parent(s): d5ea1a7

[#37] Store ELO ratings in DB after calculation (#112)

Browse files

* [#37] Store ELO ratings in DB after calculation

This change adds logic to store ELO ratings in the database after they are calculated.

Previously, we calculated and loaded the ratings without storing them. Now, we store them for future use.

This change doesn't affect current operations as we are not using the stored data yet. However, it sets the groundwork for future optimizations where we will use the stored ratings to avoid recalculating ELO scores for already calculated battles.

* update

* update

* review

* fix

* fix

* review

Files changed (5) hide show
  1. README.md +3 -0
  2. app.py +2 -1
  3. db.py +114 -0
  4. leaderboard.py +43 -75
  5. response.py +2 -2
README.md CHANGED
@@ -49,6 +49,9 @@ Get Involved: [Discuss and contribute on GitHub](https://github.com/yanolja/aren
49
 
50
  ```shell
51
  CREDENTIALS_PATH=<your crednetials path> \
 
 
 
52
  OPENAI_API_KEY=<your key> \
53
  ANTHROPIC_API_KEY=<your key> \
54
  MISTRAL_API_KEY=<your key> \
 
49
 
50
  ```shell
51
  CREDENTIALS_PATH=<your crednetials path> \
52
+ RATINGS_COLLECTION=<your collection> \
53
+ SUMMARIZATIONS_COLLECTION=<your collection> \
54
+ TRANSLATIONS_COLLECTION=<your collection> \
55
  OPENAI_API_KEY=<your key> \
56
  ANTHROPIC_API_KEY=<your key> \
57
  MISTRAL_API_KEY=<your key> \
app.py CHANGED
@@ -8,8 +8,8 @@ from firebase_admin import firestore
8
  import gradio as gr
9
  import lingua
10
 
 
11
  from leaderboard import build_leaderboard
12
- from leaderboard import db
13
  from leaderboard import SUPPORTED_LANGUAGES
14
  from model import check_models
15
  from model import supported_models
@@ -50,6 +50,7 @@ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
50
  language_a = detector.detect_language_of(response_a)
51
  language_b = detector.detect_language_of(response_b)
52
 
 
53
  doc_ref = db.collection("arena-summarizations").document(doc_id)
54
  doc["model_a_response_language"] = language_a.name.lower()
55
  doc["model_b_response_language"] = language_b.name.lower()
 
8
  import gradio as gr
9
  import lingua
10
 
11
+ from db import db
12
  from leaderboard import build_leaderboard
 
13
  from leaderboard import SUPPORTED_LANGUAGES
14
  from model import check_models
15
  from model import supported_models
 
50
  language_a = detector.detect_language_of(response_a)
51
  language_b = detector.detect_language_of(response_b)
52
 
53
+ # TODO(#37): Move DB operations to db.py.
54
  doc_ref = db.collection("arena-summarizations").document(doc_id)
55
  doc["model_a_response_language"] = language_a.name.lower()
56
  doc["model_b_response_language"] = language_b.name.lower()
db.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module handles the management of the database.
3
+ """
4
+ from dataclasses import dataclass
5
+ import enum
6
+ import os
7
+ from typing import List
8
+
9
+ import firebase_admin
10
+ from firebase_admin import credentials
11
+ from firebase_admin import firestore
12
+ from google.cloud.firestore_v1 import base_query
13
+ import gradio as gr
14
+
15
+ from credentials import get_credentials_json
16
+
17
+
18
+ def get_required_env(name: str) -> str:
19
+ value = os.getenv(name)
20
+ if value is None:
21
+ raise ValueError(f"Environment variable {name} is not set")
22
+ return value
23
+
24
+
25
+ RATINGS_COLLECTION = get_required_env("RATINGS_COLLECTION")
26
+ SUMMARIZATIONS_COLLECTION = get_required_env("SUMMARIZATIONS_COLLECTION")
27
+ TRANSLATIONS_COLLECTION = get_required_env("TRANSLATIONS_COLLECTION")
28
+
29
+ if gr.NO_RELOAD:
30
+ firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))
31
+ db = firestore.client()
32
+
33
+
34
+ class Category(enum.Enum):
35
+ SUMMARIZATION = "summarization"
36
+ TRANSLATION = "translation"
37
+
38
+
39
+ @dataclass
40
+ class Rating:
41
+ model: str
42
+ rating: int
43
+
44
+
45
+ def get_ratings(category: Category, source_lang: str | None,
46
+ target_lang: str | None) -> List[Rating] | None:
47
+ doc_id = "#".join([category.value] +
48
+ [lang for lang in (source_lang, target_lang) if lang])
49
+ # TODO(#37): Make it more clear what fields are in the document.
50
+ doc_dict = db.collection(RATINGS_COLLECTION).document(doc_id).get().to_dict()
51
+ if doc_dict is None:
52
+ return None
53
+
54
+ # TODO(#37): Return the timestamp as well.
55
+ doc_dict.pop("timestamp")
56
+
57
+ return [Rating(model, rating) for model, rating in doc_dict.items()]
58
+
59
+
60
+ def set_ratings(category: Category, ratings: List[Rating], source_lang: str,
61
+ target_lang: str | None):
62
+ source_lang_lowercase = source_lang.lower()
63
+ target_lang_lowercase = target_lang.lower() if target_lang else None
64
+
65
+ doc_id = "#".join([category.value, source_lang_lowercase] +
66
+ ([target_lang_lowercase] if target_lang_lowercase else []))
67
+ doc_ref = db.collection(RATINGS_COLLECTION).document(doc_id)
68
+
69
+ new_ratings = {rating.model: rating.rating for rating in ratings}
70
+ new_ratings["timestamp"] = firestore.SERVER_TIMESTAMP
71
+ doc_ref.set(new_ratings, merge=True)
72
+
73
+
74
+ @dataclass
75
+ class Battle:
76
+ model_a: str
77
+ model_b: str
78
+ winner: str
79
+
80
+
81
+ def get_battles(category: Category, source_lang: str | None,
82
+ target_lang: str | None) -> List[Battle]:
83
+ source_lang_lowercase = source_lang.lower() if source_lang else None
84
+ target_lang_lowercase = target_lang.lower() if target_lang else None
85
+
86
+ if category == Category.SUMMARIZATION:
87
+ collection = db.collection(SUMMARIZATIONS_COLLECTION).order_by("timestamp")
88
+
89
+ if source_lang_lowercase:
90
+ collection = collection.where(filter=base_query.FieldFilter(
91
+ "model_a_response_language", "==", source_lang_lowercase)).where(
92
+ filter=base_query.FieldFilter("model_b_response_language", "==",
93
+ source_lang_lowercase))
94
+
95
+ elif category == Category.TRANSLATION:
96
+ collection = db.collection(TRANSLATIONS_COLLECTION).order_by("timestamp")
97
+
98
+ if source_lang_lowercase:
99
+ collection = collection.where(filter=base_query.FieldFilter(
100
+ "source_language", "==", source_lang_lowercase))
101
+
102
+ if target_lang_lowercase:
103
+ collection = collection.where(filter=base_query.FieldFilter(
104
+ "target_language", "==", target_lang_lowercase))
105
+
106
+ else:
107
+ raise ValueError(f"Invalid category: {category}")
108
+
109
+ docs = collection.stream()
110
+ battles = []
111
+ for doc in docs:
112
+ data = doc.to_dict()
113
+ battles.append(Battle(data["model_a"], data["model_b"], data["winner"]))
114
+ return battles
leaderboard.py CHANGED
@@ -5,21 +5,13 @@ It provides a leaderboard component.
5
  from collections import defaultdict
6
  import enum
7
  import math
8
- from typing import Tuple
9
 
10
- import firebase_admin
11
- from firebase_admin import credentials
12
- from firebase_admin import firestore
13
- from google.cloud.firestore_v1 import base_query
14
  import gradio as gr
15
  import lingua
16
- import pandas as pd
17
 
18
- from credentials import get_credentials_json
19
-
20
- if gr.NO_RELOAD:
21
- firebase_admin.initialize_app(credentials.Certificate(get_credentials_json()))
22
- db = firestore.client()
23
 
24
  SUPPORTED_LANGUAGES = [
25
  language.name.capitalize() for language in lingua.Language.all()
@@ -34,11 +26,16 @@ class LeaderboardTab(enum.Enum):
34
 
35
 
36
  # Ref: https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing#scrollTo=QLGc6DwxyvQc pylint: disable=line-too-long
37
- def compute_elo(battles, k=4, scale=400, base=10, initial_rating=1000):
 
 
 
 
38
  rating = defaultdict(lambda: initial_rating)
39
 
40
- for model_a, model_b, winner in battles[["model_a", "model_b",
41
- "winner"]].itertuples(index=False):
 
42
  rating_a = rating[model_a]
43
  rating_b = rating[model_b]
44
 
@@ -50,71 +47,41 @@ def compute_elo(battles, k=4, scale=400, base=10, initial_rating=1000):
50
  rating[model_a] += k * (scored_point_a - expected_score_a)
51
  rating[model_b] += k * (1 - scored_point_a - expected_score_b)
52
 
53
- return rating
54
-
55
-
56
- def get_docs(tab: str,
57
- summary_lang: str = None,
58
- source_lang: str = None,
59
- target_lang: str = None):
60
- if tab == LeaderboardTab.SUMMARIZATION:
61
- collection = db.collection("arena-summarizations").order_by("timestamp")
62
-
63
- if summary_lang and (not summary_lang == ANY_LANGUAGE):
64
- collection = collection.where(filter=base_query.FieldFilter(
65
- "model_a_response_language", "==", summary_lang.lower())).where(
66
- filter=base_query.FieldFilter("model_b_response_language", "==",
67
- summary_lang.lower()))
68
-
69
- return collection.stream()
70
-
71
- if tab == LeaderboardTab.TRANSLATION:
72
- collection = db.collection("arena-translations").order_by("timestamp")
73
 
74
- if source_lang and (not source_lang == ANY_LANGUAGE):
75
- collection = collection.where(filter=base_query.FieldFilter(
76
- "source_language", "==", source_lang.lower()))
77
 
78
- if target_lang and (not target_lang == ANY_LANGUAGE):
79
- collection = collection.where(filter=base_query.FieldFilter(
80
- "target_language", "==", target_lang.lower()))
81
 
82
- return collection.stream()
83
-
84
-
85
- def load_elo_ratings(tab,
86
- summary_lang: str = None,
87
- source_lang: str = None,
88
- target_lang: str = None):
89
- docs = get_docs(tab, summary_lang, source_lang, target_lang)
90
-
91
- battles = []
92
- for doc in docs:
93
- data = doc.to_dict()
94
- battles.append({
95
- "model_a": data["model_a"],
96
- "model_b": data["model_b"],
97
- "winner": data["winner"]
98
- })
99
 
 
 
 
100
  if not battles:
101
  return
102
 
103
- battles = pd.DataFrame(battles)
104
- ratings = compute_elo(battles)
 
 
 
 
105
 
106
- sorted_ratings = sorted(ratings.items(), key=lambda x: x[1], reverse=True)
 
 
 
107
 
108
  rank = 0
109
  last_rating = None
110
  rating_rows = []
111
  for index, (model, rating) in enumerate(sorted_ratings):
112
- int_rating = math.floor(rating + 0.5)
113
- if int_rating != last_rating:
114
  rank = index + 1
115
 
116
- rating_rows.append([rank, model, int_rating])
117
- last_rating = int_rating
118
 
119
  return rating_rows
120
 
@@ -123,9 +90,9 @@ LEADERBOARD_UPDATE_INTERVAL = 600 # 10 minutes
123
  LEADERBOARD_INFO = "The leaderboard is updated every 10 minutes."
124
 
125
 
126
- def update_filtered_leaderboard(tab, summary_lang: str, source_lang: str,
127
- target_lang: str):
128
- new_value = load_elo_ratings(tab, summary_lang, source_lang, target_lang)
129
  return gr.update(value=new_value)
130
 
131
 
@@ -149,14 +116,15 @@ def build_leaderboard():
149
  headers=["Rank", "Model", "Elo rating"],
150
  datatype=["number", "str", "number"],
151
  value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
152
- ANY_LANGUAGE),
153
  elem_classes="leaderboard",
154
  visible=False)
155
 
156
  original_summarization = gr.Dataframe(
157
  headers=["Rank", "Model", "Elo rating"],
158
  datatype=["number", "str", "number"],
159
- value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION),
 
160
  every=LEADERBOARD_UPDATE_INTERVAL,
161
  elem_classes="leaderboard")
162
  gr.Markdown(LEADERBOARD_INFO)
@@ -165,7 +133,6 @@ def build_leaderboard():
165
  fn=update_filtered_leaderboard,
166
  inputs=[
167
  gr.State(LeaderboardTab.SUMMARIZATION), summary_language,
168
- gr.State(None),
169
  gr.State(None)
170
  ],
171
  outputs=filtered_summarization).then(
@@ -197,7 +164,8 @@ def build_leaderboard():
197
  original_translation = gr.Dataframe(
198
  headers=["Rank", "Model", "Elo rating"],
199
  datatype=["number", "str", "number"],
200
- value=lambda: load_elo_ratings(LeaderboardTab.TRANSLATION),
 
201
  every=LEADERBOARD_UPDATE_INTERVAL,
202
  elem_classes="leaderboard")
203
  gr.Markdown(LEADERBOARD_INFO)
@@ -205,8 +173,8 @@ def build_leaderboard():
205
  source_language.change(
206
  fn=update_filtered_leaderboard,
207
  inputs=[
208
- gr.State(LeaderboardTab.TRANSLATION),
209
- gr.State(None), source_language, target_language
210
  ],
211
  outputs=filtered_translation).then(
212
  fn=toggle_leaderboard,
@@ -215,8 +183,8 @@ def build_leaderboard():
215
  target_language.change(
216
  fn=update_filtered_leaderboard,
217
  inputs=[
218
- gr.State(LeaderboardTab.TRANSLATION),
219
- gr.State(None), source_language, target_language
220
  ],
221
  outputs=filtered_translation).then(
222
  fn=toggle_leaderboard,
 
5
  from collections import defaultdict
6
  import enum
7
  import math
8
+ from typing import Dict, List, Tuple
9
 
 
 
 
 
10
  import gradio as gr
11
  import lingua
 
12
 
13
+ import db
14
+ from db import get_battles
 
 
 
15
 
16
  SUPPORTED_LANGUAGES = [
17
  language.name.capitalize() for language in lingua.Language.all()
 
26
 
27
 
28
  # Ref: https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing#scrollTo=QLGc6DwxyvQc pylint: disable=line-too-long
29
+ def compute_elo(battles: List[db.Battle],
30
+ k=4,
31
+ scale=400,
32
+ base=10,
33
+ initial_rating=1000) -> Dict[str, int]:
34
  rating = defaultdict(lambda: initial_rating)
35
 
36
+ for battle in battles:
37
+ model_a, model_b, winner = battle.model_a, battle.model_b, battle.winner
38
+
39
  rating_a = rating[model_a]
40
  rating_b = rating[model_b]
41
 
 
47
  rating[model_a] += k * (scored_point_a - expected_score_a)
48
  rating[model_b] += k * (1 - scored_point_a - expected_score_b)
49
 
50
+ return {model: math.floor(rating + 0.5) for model, rating in rating.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
 
52
 
53
+ def load_elo_ratings(tab, source_lang: str, target_lang: str | None):
54
+ category = db.Category.SUMMARIZATION if tab == LeaderboardTab.SUMMARIZATION else db.Category.TRANSLATION
 
55
 
56
+ # TODO(#37): Call db.get_ratings and return the ratings if exists.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ battles = get_battles(category,
59
+ None if source_lang == ANY_LANGUAGE else source_lang,
60
+ None if target_lang == ANY_LANGUAGE else target_lang)
61
  if not battles:
62
  return
63
 
64
+ computed_ratings = compute_elo(battles)
65
+
66
+ db.set_ratings(
67
+ category,
68
+ [db.Rating(model, rating) for model, rating in computed_ratings.items()],
69
+ source_lang, target_lang)
70
 
71
+ sorted_ratings = sorted(
72
+ computed_ratings.items(),
73
+ key=lambda x: x[1], # rating
74
+ reverse=True)
75
 
76
  rank = 0
77
  last_rating = None
78
  rating_rows = []
79
  for index, (model, rating) in enumerate(sorted_ratings):
80
+ if rating != last_rating:
 
81
  rank = index + 1
82
 
83
+ rating_rows.append([rank, model, rating])
84
+ last_rating = rating
85
 
86
  return rating_rows
87
 
 
90
  LEADERBOARD_INFO = "The leaderboard is updated every 10 minutes."
91
 
92
 
93
+ def update_filtered_leaderboard(tab: str, source_lang: str,
94
+ target_lang: str | None):
95
+ new_value = load_elo_ratings(tab, source_lang, target_lang)
96
  return gr.update(value=new_value)
97
 
98
 
 
116
  headers=["Rank", "Model", "Elo rating"],
117
  datatype=["number", "str", "number"],
118
  value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
119
+ ANY_LANGUAGE, None),
120
  elem_classes="leaderboard",
121
  visible=False)
122
 
123
  original_summarization = gr.Dataframe(
124
  headers=["Rank", "Model", "Elo rating"],
125
  datatype=["number", "str", "number"],
126
+ value=lambda: load_elo_ratings(LeaderboardTab.SUMMARIZATION,
127
+ ANY_LANGUAGE, None),
128
  every=LEADERBOARD_UPDATE_INTERVAL,
129
  elem_classes="leaderboard")
130
  gr.Markdown(LEADERBOARD_INFO)
 
133
  fn=update_filtered_leaderboard,
134
  inputs=[
135
  gr.State(LeaderboardTab.SUMMARIZATION), summary_language,
 
136
  gr.State(None)
137
  ],
138
  outputs=filtered_summarization).then(
 
164
  original_translation = gr.Dataframe(
165
  headers=["Rank", "Model", "Elo rating"],
166
  datatype=["number", "str", "number"],
167
+ value=lambda: load_elo_ratings(LeaderboardTab.TRANSLATION,
168
+ ANY_LANGUAGE, ANY_LANGUAGE),
169
  every=LEADERBOARD_UPDATE_INTERVAL,
170
  elem_classes="leaderboard")
171
  gr.Markdown(LEADERBOARD_INFO)
 
173
  source_language.change(
174
  fn=update_filtered_leaderboard,
175
  inputs=[
176
+ gr.State(LeaderboardTab.TRANSLATION), source_language,
177
+ target_language
178
  ],
179
  outputs=filtered_translation).then(
180
  fn=toggle_leaderboard,
 
183
  target_language.change(
184
  fn=update_filtered_leaderboard,
185
  inputs=[
186
+ gr.State(LeaderboardTab.TRANSLATION), source_language,
187
+ target_language
188
  ],
189
  outputs=filtered_translation).then(
190
  fn=toggle_leaderboard,
response.py CHANGED
@@ -11,7 +11,7 @@ from uuid import uuid4
11
  from firebase_admin import firestore
12
  import gradio as gr
13
 
14
- from leaderboard import db
15
  from model import ContextWindowExceededError
16
  from model import Model
17
  from model import supported_models
@@ -22,7 +22,7 @@ logging.basicConfig()
22
  logger = logging.getLogger(__name__)
23
  logger.setLevel(logging.INFO)
24
 
25
-
26
  def get_history_collection(category: str):
27
  if category == Category.SUMMARIZE.value:
28
  return db.collection("arena-summarization-history")
 
11
  from firebase_admin import firestore
12
  import gradio as gr
13
 
14
+ from db import db
15
  from model import ContextWindowExceededError
16
  from model import Model
17
  from model import supported_models
 
22
  logger = logging.getLogger(__name__)
23
  logger.setLevel(logging.INFO)
24
 
25
+ # TODO(#37): Move DB operations to db.py.
26
  def get_history_collection(category: str):
27
  if category == Category.SUMMARIZE.value:
28
  return db.collection("arena-summarization-history")