KurtDu commited on
Commit
8efa040
·
verified ·
1 Parent(s): e18ae7e

Upload elo_rank.py

Browse files
Files changed (1) hide show
  1. elo_rank.py +133 -0
elo_rank.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import json
3
+
4
+ class EloRank:
5
+ def __init__(self, initial_rating=1000, k_factor=32):
6
+ """
7
+ Initialize the EloRank class.
8
+ :param initial_rating: Initial ELO rating for each model.
9
+ :param k_factor: The K-factor that determines the sensitivity of rating changes.
10
+ """
11
+ self.ratings = {}
12
+ self.initial_rating = initial_rating
13
+ self.k_factor = k_factor
14
+ self.wins = {}
15
+
16
+ def add_model(self, model_id):
17
+ """
18
+ Add a new model with the initial rating.
19
+ :param model_id: Unique identifier for the model.
20
+ """
21
+ self.ratings[model_id] = self.initial_rating
22
+ self.wins[model_id] = 0
23
+
24
+ def record_match(self, winner, loser):
25
+ """
26
+ Update the ratings based on a match result.
27
+ :param winner: Model ID of the winner.
28
+ :param loser: Model ID of the loser.
29
+ """
30
+ rating_winner = self.ratings[winner]
31
+ rating_loser = self.ratings[loser]
32
+
33
+ expected_winner = self.expected_score(rating_winner, rating_loser)
34
+ expected_loser = self.expected_score(rating_loser, rating_winner)
35
+
36
+ self.ratings[winner] += self.k_factor * (1 - expected_winner)
37
+ self.ratings[loser] += self.k_factor * (0 - expected_loser)
38
+
39
+ # Update win count
40
+ self.wins[winner] += 1
41
+
42
+ def expected_score(self, rating_a, rating_b):
43
+ """
44
+ Calculate the expected score for a model.
45
+ :param rating_a: Rating of model A.
46
+ :param rating_b: Rating of model B.
47
+ :return: Expected score.
48
+ """
49
+ return 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
50
+
51
+ def get_rating(self, model_id):
52
+ """
53
+ Get the current rating of a model.
54
+ :param model_id: Unique identifier for the model.
55
+ :return: Current rating of the model.
56
+ """
57
+ return self.ratings.get(model_id, None)
58
+
59
+ def get_wins(self, model_id):
60
+ """
61
+ Get the number of wins of a model.
62
+ :param model_id: Unique identifier for the model.
63
+ :return: Number of wins of the model.
64
+ """
65
+ return self.wins.get(model_id, 0)
66
+
67
+ def get_top_models(self, n=2):
68
+ """
69
+ Get the top N models by rating.
70
+ :param n: Number of top models to retrieve.
71
+ :return: List of model IDs of the top models.
72
+ """
73
+ return sorted(self.ratings, key=self.ratings.get, reverse=True)[:n]
74
+
75
+ def sample_next_match(self):
76
+ """
77
+ Sample the next match based on the probability proportional to the current rating.
78
+ This approach helps accelerate the convergence of ranking.
79
+ :return: Tuple of two model IDs for the next match.
80
+ """
81
+ model_ids = list(self.ratings.keys())
82
+ probabilities = [self.ratings[model_id] for model_id in model_ids]
83
+ total_rating = sum(probabilities)
84
+ probabilities = [rating / total_rating for rating in probabilities]
85
+
86
+ # Sample two different models for the next match
87
+ next_match = random.choices(model_ids, probabilities, k=2)
88
+ while next_match[0] == next_match[1]:
89
+ next_match = random.choices(model_ids, probabilities, k=2)
90
+
91
+ return tuple(next_match)
92
+
93
+ def process_match_records(self, file_path):
94
+ """
95
+ Process match records from a JSON file and update ratings and win counts accordingly.
96
+ :param file_path: Path to the JSON file containing match records.
97
+ """
98
+ with open(file_path, 'r') as file:
99
+ match_records = json.load(file)
100
+
101
+ for record in match_records:
102
+ winner = record['winner']
103
+ model_1 = record['model_1']
104
+ model_2 = record['model_2']
105
+
106
+ # Add models if they are not already added
107
+ if model_1 not in self.ratings:
108
+ self.add_model(model_1)
109
+ if model_2 not in self.ratings:
110
+ self.add_model(model_2)
111
+
112
+ # Record the match result
113
+ if winner == model_1:
114
+ self.record_match(model_1, model_2)
115
+ elif winner == model_2:
116
+ self.record_match(model_2, model_1)
117
+
118
+ # # Example Usage
119
+ # e = EloRank()
120
+ # e.add_model('model_A')
121
+ # e.add_model('model_B')
122
+ # e.add_model('model_C')
123
+
124
+ # e.record_match('model_A', 'model_B')
125
+ # print(e.get_rating('model_A')) # Should be greater than the initial rating
126
+ # print(e.get_rating('model_B')) # Should be less than the initial rating
127
+
128
+ # print(e.get_top_models(2)) # Get the top 2 models
129
+ # print(e.sample_next_match()) # Sample the next match based on ratings
130
+
131
+ # # Process match records from a JSON file
132
+ # e.process_match_records('match_records.json')
133
+ # print(e.get_wins('model_A')) # Get the number of wins for model_A