Spaces:
Running
Running
import json | |
import os | |
import glob | |
import time | |
from fastapi import FastAPI | |
import hashlib | |
import asyncio | |
# Number of seconds passed to ``asyncio.sleep`` between two refresh passes
# of the call statistics.
REFRESH_INTERVAL_SEC = 60

# Directories that are scanned for ``*.json`` log files. Empty here, so no
# log is read unless entries are added.
LOG_DIR_LIST = []
# LOG_DIR = "/home/vicuna/tmp/test_env"
class Monitor:
    """Monitor the number of calls to each model.

    Call records are read from line-delimited JSON log files found in
    ``log_dir_list``. Only records whose ``"type"`` is ``"chat"`` are counted;
    each one contributes its ``"model"``, ``"ip"`` (used as the user id) and
    ``"tstamp"`` fields.

    Attributes:
        log_dir_list: Directories scanned for ``*.json`` log files.
        model_call: ``{model: [{"tstamp": ..., "user_id": ...}, ...]}``.
        user_call: ``{user_id: [{"tstamp": ..., "model": ...}, ...]}``.
        model_call_stats_hour / model_call_stats_day: ``{model: count}`` for
            the last 60 minutes / 24 hours, as of the latest refresh.
        user_call_stats_hour / user_call_stats_day: per-user statistics (see
            ``get_user_call_stats``) for the same two windows.
        model_call_limit_global: Per-model limit compared against the hourly
            model statistics.
        model_call_day_limit_per_user: Per-model limit compared against each
            user's daily statistics.
    """

    def __init__(self, log_dir_list: list):
        self.log_dir_list = log_dir_list
        self.model_call = {}
        self.user_call = {}
        # Start with empty snapshots so the limit checks are usable before the
        # first refresh has run (they used to raise AttributeError).
        self.model_call_stats_hour = {}
        self.model_call_stats_day = {}
        self.user_call_stats_hour = {}
        self.user_call_stats_day = {}
        self.model_call_limit_global = {
            "gpt-4-1106-preview": 100,
            "gpt-4-0125-preview": 100,
        }
        self.model_call_day_limit_per_user = {
            "gpt-4-1106-preview": 5,
            "gpt-4-0125-preview": 5,
        }

    async def update_stats(self, num_file=1) -> None:
        """Refresh the statistics forever.

        Runs ``refresh_stats(num_file)`` and then sleeps for
        ``REFRESH_INTERVAL_SEC`` seconds, in an endless loop. The refresh
        itself is synchronous: the log files are read with blocking I/O.

        Args:
            num_file: Number of most recent log files to read per directory.
        """
        while True:
            self.refresh_stats(num_file)
            await asyncio.sleep(REFRESH_INTERVAL_SEC)

    def refresh_stats(self, num_file=1) -> None:
        """Run a single refresh pass over the log directories.

        For every directory in ``log_dir_list`` the ``num_file`` newest
        ``*.json`` files (ordered by ``os.path.getctime``) are parsed, then the
        raw call tables and the hourly/daily snapshots are replaced.

        Args:
            num_file: Number of most recent log files to read per directory.
        """
        json_files = []
        for log_dir in self.log_dir_list:
            per_server = glob.glob(os.path.join(log_dir, "*.json"))
            per_server.sort(key=os.path.getctime, reverse=True)
            json_files += per_server[:num_file]

        model_call, user_call = self._load_calls(json_files)

        self.model_call = model_call
        self.model_call_stats_hour = self.get_model_call_stats(top_k=None)
        self.model_call_stats_day = self.get_model_call_stats(
            top_k=None, most_recent_min=24 * 60
        )
        self.user_call = user_call
        self.user_call_stats_hour = self.get_user_call_stats(top_k=None)
        self.user_call_stats_day = self.get_user_call_stats(
            top_k=None, most_recent_min=24 * 60
        )

    @staticmethod
    def _load_calls(json_files) -> tuple:
        """Parse chat records from ``json_files`` into per-model/per-user tables.

        Returns:
            ``(model_call, user_call)`` in the shapes described on the class.
        """
        model_call = {}
        user_call = {}
        for json_file in json_files:
            # ``with`` closes the handle; iterating a bare ``open()`` leaked it.
            with open(json_file, "r", encoding="utf-8") as log:
                for line in log:
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        # Blank or partially written line: skip it instead of
                        # aborting the whole refresh loop.
                        continue
                    if not isinstance(obj, dict) or obj.get("type") != "chat":
                        continue
                    try:
                        model = obj["model"]
                        user_id = obj["ip"]
                        tstamp = obj["tstamp"]
                    except KeyError:
                        # Incomplete chat record: nothing to attribute it to.
                        continue
                    model_call.setdefault(model, []).append(
                        {"tstamp": tstamp, "user_id": user_id}
                    )
                    user_call.setdefault(user_id, []).append(
                        {"tstamp": tstamp, "model": model}
                    )
        return model_call, user_call

    def get_model_call_limit(self, model: str) -> int:
        """Return the global limit of ``model``, or ``-1`` if it has none."""
        return self.model_call_limit_global.get(model, -1)

    def update_model_call_limit(self, model: str, limit: int) -> bool:
        """Set the global limit of an already-limited ``model``.

        Returns:
            ``True`` if the limit was updated, ``False`` if ``model`` has no
            entry in ``model_call_limit_global`` (no entry is created).
        """
        if model not in self.model_call_limit_global:
            return False
        self.model_call_limit_global[model] = limit
        return True

    def is_model_limit_reached(self, model: str) -> bool:
        """Return whether the hourly count of ``model`` has reached its limit.

        Models without a configured limit, or absent from the hourly snapshot,
        are never considered limited.
        """
        if model not in self.model_call_limit_global:
            return False
        if model not in self.model_call_stats_hour:
            return False
        return self.model_call_stats_hour[model] >= self.model_call_limit_global[model]

    def is_user_limit_reached(self, model: str, user_id: str) -> bool:
        """Return whether ``user_id``'s daily count for ``model`` reached its limit.

        Returns ``False`` when the model has no per-user limit, or when the
        daily snapshot holds no calls of that user to that model.
        """
        if model not in self.model_call_day_limit_per_user:
            return False
        if user_id not in self.user_call_stats_day:
            return False
        call_dict = self.user_call_stats_day[user_id]["call_dict"]
        if model not in call_dict:
            return False
        return call_dict[model] >= self.model_call_day_limit_per_user[model]

    def get_model_call_stats(
        self, target_model=None, most_recent_min: int = 60, top_k: int = 20
    ) -> dict:
        """Count the recent calls of each model.

        Args:
            target_model: If not ``None``, only this model is reported.
            most_recent_min: Size of the look-back window, in minutes.
            top_k: Keep only the ``top_k`` most called models; ``None`` keeps
                all of them.

        Returns:
            ``{model: number of calls inside the window}``. A model whose
            calls all fall outside the window is reported with a count of 0.
        """
        # One cutoff per query keeps the window identical for every request.
        cutoff = time.time() - most_recent_min * 60
        model_call_stats = {}
        for model, reqs in self.model_call.items():
            if target_model is not None and model != target_model:
                continue
            model_call_stats[model] = sum(1 for req in reqs if req["tstamp"] >= cutoff)
        if top_k is not None:
            top_k_model = sorted(
                model_call_stats, key=model_call_stats.get, reverse=True
            )[:top_k]
            model_call_stats = {model: model_call_stats[model] for model in top_k_model}
        return model_call_stats

    def get_user_call_stats(
        self, target_model=None, most_recent_min: int = 60, top_k: int = 20
    ) -> dict:
        """Count the recent calls of each user, broken down by model.

        Args:
            target_model: If not ``None``, only calls to this model count.
            most_recent_min: Size of the look-back window, in minutes.
            top_k: Keep only the ``top_k`` users with the most calls; ``None``
                keeps all of them.

        Returns:
            ``{user_id: {"call_dict": {model: count}, "total_calls": int}}``.
            Users without any matching call in the window are omitted.
        """
        # One cutoff per query keeps the window identical for every request.
        cutoff = time.time() - most_recent_min * 60
        user_call_stats = {}
        for user_id, reqs in self.user_call.items():
            call_dict = {}
            for req in reqs:
                if req["tstamp"] < cutoff:
                    continue
                if target_model is not None and req["model"] != target_model:
                    continue
                call_dict[req["model"]] = call_dict.get(req["model"], 0) + 1
            total_calls = sum(call_dict.values())
            if total_calls > 0:
                user_call_stats[user_id] = {
                    "call_dict": call_dict,
                    "total_calls": total_calls,
                }
        if top_k is not None:
            top_k_user = sorted(
                user_call_stats,
                key=lambda x: user_call_stats[x]["total_calls"],
                reverse=True,
            )[:top_k]
            user_call_stats = {
                user_id: user_call_stats[user_id] for user_id in top_k_user
            }
        return user_call_stats

    def get_num_users(self, most_recent_min: int = 60) -> int:
        """Return how many users made at least one call inside the window."""
        user_call_stats = self.get_user_call_stats(
            most_recent_min=most_recent_min, top_k=None
        )
        return len(user_call_stats)
# Module-level singletons used by the functions below: the call monitor
# reading from LOG_DIR_LIST, and the FastAPI application object.
monitor = Monitor(LOG_DIR_LIST)
app = FastAPI()
async def app_startup():
    """Start the background task that keeps the monitor statistics fresh.

    Schedules ``monitor.update_stats(2)``, i.e. the two most recent log files
    of every log directory are read on each refresh.

    NOTE(review): no startup-hook decorator is visible on this function in
    this view; confirm how it is registered with ``app``.
    """
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    app.state.stats_refresh_task = asyncio.create_task(monitor.update_stats(2))
async def get_model_call_limit(model: str):
    """Report the global call limit configured for ``model``.

    Returns:
        ``{"model_call_limit": {model: limit}}`` where ``limit`` is whatever
        ``monitor.get_model_call_limit(model)`` returns.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    limit = monitor.get_model_call_limit(model)
    return {"model_call_limit": {model: limit}}
async def update_model_call_limit(model: str, limit: int):
    """Change the global call limit of ``model`` through the monitor.

    Returns:
        ``{"success": True}`` when ``monitor.update_model_call_limit``
        reports success, ``{"success": False}`` otherwise.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    updated = monitor.update_model_call_limit(model, limit)
    return {"success": bool(updated)}
async def is_limit_reached(model: str, user_id: str):
    """Tell whether a call by ``user_id`` to ``model`` hits a limit.

    The model-wide limit is checked first, then the per-user limit.

    Returns:
        ``{"is_limit_reached": False}`` when neither limit is reached,
        otherwise ``{"is_limit_reached": True, "reason": ...}`` where the
        reason names the limit that was hit and its configured value.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    reason = None
    if monitor.is_model_limit_reached(model):
        reason = f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}"
    elif monitor.is_user_limit_reached(model, user_id):
        # Safe lookup: the per-user check only succeeds for configured models.
        reason = f"USER_DAILY_LIMIT ({model}): {monitor.model_call_day_limit_per_user[model]}"

    if reason is None:
        return {"is_limit_reached": False}
    return {"is_limit_reached": True, "reason": reason}
async def get_num_users():
    """Report the number of users in the monitor's hourly snapshot.

    Returns:
        ``{"num_users": n}`` with ``n`` the number of entries in
        ``monitor.user_call_stats_hour``.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    hourly_users = monitor.user_call_stats_hour
    return {"num_users": len(hourly_users)}
async def get_num_users_day():
    """Report the number of users in the monitor's daily snapshot.

    Returns:
        ``{"num_users": n}`` with ``n`` the number of entries in
        ``monitor.user_call_stats_day``.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    daily_users = monitor.user_call_stats_day
    return {"num_users": len(daily_users)}
async def get_user_call_stats(
    model: str = None, most_recent_min: int = 60, top_k: int = None
):
    """Return per-user call statistics computed on demand by the monitor.

    Args:
        model: If given, only calls to this model are counted.
        most_recent_min: Size of the look-back window, in minutes.
        top_k: If given, keep only the ``top_k`` users with the most calls.

    Returns:
        ``{"user_call_stats": ...}`` wrapping the monitor's result.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    stats = monitor.get_user_call_stats(
        target_model=model, most_recent_min=most_recent_min, top_k=top_k
    )
    return {"user_call_stats": stats}
async def get_model_call_stats(
    model: str = None, most_recent_min: int = 60, top_k: int = None
):
    """Return per-model call counts computed on demand by the monitor.

    Args:
        model: If given, only this model is reported.
        most_recent_min: Size of the look-back window, in minutes.
        top_k: If given, keep only the ``top_k`` most called models.

    Returns:
        ``{"model_call_stats": ...}`` wrapping the monitor's result.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is registered with ``app``.
    """
    stats = monitor.get_model_call_stats(
        target_model=model, most_recent_min=most_recent_min, top_k=top_k
    )
    return {"model_call_stats": stats}