"""
Clean chatbot arena battle log.
Usage:
python3 clean_battle_data.py --mode conv_release
"""
import argparse
import datetime
import json
import os
import sys
from pytz import timezone
import time
import PIL
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import tqdm
from .basic_stats import get_log_files, NUM_SERVERS, LOG_ROOT_DIR
from .utils import detect_language, get_time_stamp_from_date
# Log record types that count as a vote; read_file keeps only rows whose
# "type" is one of these.
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]

# Model/vendor names and error banners, stored lower-cased so they can be
# compared against lower-cased message text.
IDENTITY_WORDS = [
    word.lower()
    for word in (
        "vicuna",
        "lmsys",
        "koala",
        "uc berkeley",
        "open assistant",
        "laion",
        "chatglm",
        "chatgpt",
        "gpt-4",
        "openai",
        "anthropic",
        "claude",
        "bard",
        "palm",
        "lamda",
        "google",
        "llama",
        "qianwan",
        "alibaba",
        "mistral",
        "zhipu",
        "KEG lab",
        "01.AI",
        "AI2",
        "Tülu",
        "Tulu",
        "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
        "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
        "API REQUEST ERROR. Please increase the number of max tokens.",
        "**API REQUEST ERROR** Reason: The response was blocked.",
        "**API REQUEST ERROR**",
    )
]
def parse_model_name(model_name):
    """Split a ``<source>_<name>_<type>`` identifier into its three parts.

    The first ``_``-separated field is the source (platform) and the last is
    the type (task); everything in between is the model name, which may
    itself contain underscores.

    Args:
        model_name: Full identifier, e.g. ``"imagenhub_CycleDiffusion_edition"``.

    Returns:
        A ``(model_source, model_name, model_type)`` tuple of strings.

    Raises:
        ValueError: If the identifier has fewer than three ``_``-separated
            fields. Callers in this file catch ``ValueError`` to reject
            malformed names.
    """
    parts = model_name.split("_")
    if len(parts) < 3:
        raise ValueError(f"Invalid model name: {model_name!r}")
    model_source, *name_parts, model_type = parts
    return model_source, "_".join(name_parts), model_type
def remove_html(raw):
    """Strip the heading markup that wraps a displayed model name.

    Two wrappers are recognised:

    * an HTML heading such as ``"<h3>Model A: name</h3>\\n"`` -- the text
      after the first ``": "`` is returned without the closing tag and
      trailing newline;
    * a Markdown heading starting with ``"### Model A: "`` or
      ``"### Model B: "`` -- the prefix is dropped.

    Anything else is returned unchanged.
    """
    # NOTE(review): the tag literals were unreadable in the reviewed copy of
    # this file and have been restored as <h3>...</h3>; confirm against the
    # markup that actually appears in the logs.
    if raw.startswith("<h3>"):
        return raw[raw.find(": ") + 2 : -len("</h3>\n")]
    markdown_prefixes = ("### Model A: ", "### Model B: ")
    if raw.startswith(markdown_prefixes):
        # Both prefixes have the same length.
        return raw[len(markdown_prefixes[0]) :]
    return raw
def to_openai_format(messages):
    """Convert ``(role, text)`` message pairs into OpenAI-style chat dicts.

    Roles are assigned purely by position: even indices become ``"user"``
    and odd indices ``"assistant"``. The role stored in each input pair is
    ignored; only element ``[1]`` (the text) is used.
    """
    return [
        {"role": "assistant" if idx % 2 else "user", "content": message[1]}
        for idx, message in enumerate(messages)
    ]
def replace_model_name(old_name, tstamp):
    """Map a logged model name to its canonical name.

    ``gpt-4`` and ``gpt-3.5-turbo`` get a version suffix chosen by the vote
    timestamp: ``-0613`` strictly after 1687849200, ``-0314`` otherwise.
    A fixed set of other names is renamed through a lookup table, and any
    name not covered is returned as-is.
    """
    version_cutoff_tstamp = 1687849200
    if old_name in ("gpt-4", "gpt-3.5-turbo"):
        suffix = "-0613" if tstamp > version_cutoff_tstamp else "-0314"
        return old_name + suffix

    renames = {
        "bard": "palm-2",
        "claude-v1": "claude-1",
        "claude-instant-v1": "claude-instant-1",
        "oasst-sft-1-pythia-12b": "oasst-pythia-12b",
        "claude-2": "claude-2.0",
        "PlayGroundV2": "PlayGround V2",
        "PlayGroundV2.5": "PlayGround V2.5",
    }
    return renames.get(old_name, old_name)
def read_file(filename):
    """Load the vote records from one JSON-lines log file.

    Each line is parsed as JSON and kept only when its ``"type"`` is listed
    in ``VOTES``. A missing file is retried up to five times with a
    two-second pause after each failed attempt; if it never appears, an empty
    list is returned.

    Args:
        filename: Path of the log file to read.

    Returns:
        The list of parsed vote rows, in file order.

    Raises:
        SystemExit: If a line is not valid JSON. The file name and the
            offending line are printed first.
    """
    data = []
    for _ in range(5):
        try:
            # The context manager closes the handle even when parsing aborts.
            with open(filename) as fin:
                for line in fin:
                    try:
                        row = json.loads(line)
                    except json.JSONDecodeError:
                        print(f"Error in reading {filename}")
                        # Print the line that failed to parse. No row exists
                        # for it, so there is nothing else to show.
                        print(line)
                        # NOTE(review): exit status 0 is kept from the
                        # original code even though this is an error path.
                        raise SystemExit(0)
                    if row["type"] in VOTES:
                        data.append(row)
            break
        except FileNotFoundError:
            time.sleep(2)
    return data
def read_file_parallel(log_files, num_threads=16):
    """Read every log file with ``read_file`` and concatenate the rows.

    With ``num_threads == 1`` the files are read sequentially in this
    process. Otherwise a ``multiprocessing.Pool`` of ``num_threads`` workers
    reads them. Either way the rows are returned in ``log_files`` order and
    a tqdm progress bar is shown.
    """
    if num_threads != 1:
        from multiprocessing import Pool

        with Pool(num_threads) as pool:
            per_file_rows = list(
                tqdm(pool.imap(read_file, log_files), total=len(log_files))
            )
        merged = []
        for rows in per_file_rows:
            merged.extend(rows)
        return merged

    merged = []
    for path in tqdm(log_files, desc="Reading"):
        merged.extend(read_file(path))
    return merged
def load_image(image_path):
    """Open an image with PIL, returning ``None`` if it cannot be opened.

    Any ordinary exception raised while opening (missing file, unreadable or
    unrecognised image data, ...) is treated as "no image".
    """
    try:
        return PIL.Image.open(image_path)
    except Exception:
        # Deliberately best-effort. Unlike a bare ``except:``, this does not
        # swallow KeyboardInterrupt or SystemExit.
        return None
# Per-task acceptance rules:
#   (allowed platforms, allowed task suffixes, lower-case the platform first?)
_TASK_RULES = {
    "image_editing": (("playground", "imagenhub"), ("edition",), False),
    "t2i_generation": (
        ("playground", "imagenhub", "fal"),
        ("generation", "text2image"),
        True,
    ),
    "video_generation": (
        ("videogenhub", "fal"),
        ("generation", "text2video"),
        False,
    ),
}


def _preprocess_model_name(m):
    """Expand the Playground display names to full ``source_name_type`` form."""
    if m == "Playground v2":
        return 'playground_PlayGroundV2_generation'
    if m == "Playground v2.5":
        return 'playground_PlayGroundV2.5_generation'
    return m


def _bare_model_names(models, task_name):
    """Return the middle (name) part of each model id, or None if any is invalid.

    A model id is invalid when it cannot be parsed, or when its platform or
    task suffix is not allowed for ``task_name``. Raises ``ValueError`` for an
    unknown ``task_name``.
    """
    if task_name not in _TASK_RULES:
        raise ValueError(f"Invalid task_name: {task_name}")
    platforms, tasks, fold_platform_case = _TASK_RULES[task_name]

    names = []
    for full_name in models:
        try:
            platform, model_name, task = parse_model_name(full_name)
        except (ValueError, IndexError):
            # IndexError is caught as well so that a name with no underscore
            # is rejected rather than aborting the whole run.
            return None
        if fold_platform_case:
            platform = platform.lower()
        # The platform must be allowed AND the task must be one of the
        # accepted suffixes, for every task type.
        if platform not in platforms or task not in tasks:
            return None
        names.append(model_name)
    return names


def _input_image_problem(row):
    """Return a message if the two sides' input images are missing or differ.

    Returns ``None`` when both images exist, load, and have identical bytes.
    """
    date = datetime.datetime.fromtimestamp(
        row["tstamp"], tz=timezone("US/Pacific")
    ).strftime("%Y-%m-%d")  # e.g. 2024-02-29
    image_path_format = f"{LOG_ROOT_DIR}/{date}-convinput_images/input_image_"
    image_path_0 = image_path_format + str(row["states"][0]["conv_id"]) + ".png"
    image_path_1 = image_path_format + str(row["states"][1]["conv_id"]) + ".png"

    if not os.path.exists(image_path_0) or not os.path.exists(image_path_1):
        return f"Image not found for {image_path_0} or {image_path_1}"
    image_0 = load_image(image_path_0)
    image_1 = load_image(image_path_1)
    if image_0 is None or image_1 is None:
        return f"Image not found for {image_path_0} or {image_path_1}"
    if image_0.tobytes() != image_1.tobytes():
        return f"Image not the same for {image_path_0} and {image_path_1}"
    return None


def clean_battle_data(
    log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False, mode="simple", task_name="image_editing"
):
    """Turn raw vote logs into a list of cleaned battle records.

    Votes are read from ``log_files`` and dropped when the model names are
    missing or inconsistent, when a model does not belong to ``task_name``,
    when a model is in ``exclude_model_names``, when (in ``conv_release``
    mode) the two sides' input images are missing or differ, or when the
    voter's IP is banned. Summary statistics are printed at the end.

    Args:
        log_files: Paths of the JSON-lines log files to read.
        exclude_model_names: Model names whose battles are discarded; may be
            empty or ``None``.
        ban_ip_list: Optional collection of IPs whose votes are discarded.
        sanitize_ip: If true, voters are identified by a sequential id
            instead of their IP address.
        mode: ``"conv_release"`` additionally checks the input images; any
            other value skips that check.
        task_name: One of ``"image_editing"``, ``"t2i_generation"`` or
            ``"video_generation"``.

    Returns:
        A list of battle dicts sorted by ``tstamp``, each with the keys
        ``question_id``, ``model_a``, ``model_b``, ``winner``, ``judge``,
        ``anony`` and ``tstamp``. The list is empty when no vote survives.

    Raises:
        ValueError: If ``task_name`` is not a known task. This is raised only
            once a row reaches the task check.
    """
    data = read_file_parallel(log_files, num_threads=1)

    convert_type = {
        "leftvote": "model_a",
        "rightvote": "model_b",
        "tievote": "tie",
        "bothbad_vote": "tie (bothbad)",
    }
    # Use a set so the per-row ban check is O(1).
    banned_ips = None if ban_ip_list is None else set(ban_ip_list)

    all_models = set()
    all_ips = dict()
    ct_anony = 0
    ct_invalid = 0
    # Language detection and identity-leak filtering are not performed, so
    # this stays 0. It is still reported in the summary below.
    ct_leaked_identity = 0
    ct_banned = 0

    battles = []
    for row in tqdm(data, desc="Cleaning"):
        if row["models"][0] is None or row["models"][1] is None:
            print(f"Invalid model names: {row['models']}")
            continue

        # Resolve model names
        models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
        if "model_name" in row["states"][0]:
            models_hidden = [
                row["states"][0]["model_name"],
                row["states"][1]["model_name"],
            ]
            if models_hidden[0] is None:
                models_hidden = models_public
        else:
            models_hidden = models_public

        # Exactly one side having an empty public name is inconsistent.
        if (models_public[0] == "" and models_public[1] != "") or (
            models_public[1] == "" and models_public[0] != ""
        ):
            ct_invalid += 1
            print(f"Invalid model names: {models_public}")
            continue

        if models_public[0] == "" or models_public[0] == "Model A":
            anony = True
            models = models_hidden
            ct_anony += 1
        else:
            anony = False
            models = models_public
            if not models_public == models_hidden:
                print(f"Model names mismatch: {models_public} vs {models_hidden}")
                ct_invalid += 1
                continue

        # Validate the models against the task and strip platform/task affixes.
        models = [_preprocess_model_name(m) for m in models]
        models = _bare_model_names(models, task_name)
        if models is None:
            ct_invalid += 1
            continue

        # Map legacy names (e.g. bard -> palm-2) to their canonical names.
        models = [replace_model_name(m, row["tstamp"]) for m in models]

        # Exclude certain models
        if exclude_model_names and any(x in exclude_model_names for x in models):
            ct_invalid += 1
            continue

        if mode == "conv_release":
            # Both sides must have been given the same input image.
            problem = _input_image_problem(row)
            if problem is not None:
                print(problem)
                ct_invalid += 1
                continue

        question_id = row["states"][0]["conv_id"]

        ip = row["ip"]
        if ip not in all_ips:
            all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
        all_ips[ip]["count"] += 1
        if sanitize_ip:
            user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
        else:
            user_id = f"{all_ips[ip]['ip']}"

        if banned_ips is not None and ip in banned_ips:
            ct_banned += 1
            print(f"User {user_id} is banned")
            continue

        # Save the results
        battles.append(
            dict(
                question_id=question_id,
                model_a=models[0],
                model_b=models[1],
                winner=convert_type[row["type"]],
                # NOTE(review): with sanitize_ip the prefix is applied twice
                # ("arena_user_arena_user_N"). Left unchanged to keep the
                # output format stable; confirm with downstream consumers.
                judge=f"arena_user_{user_id}",
                anony=anony,
                tstamp=row["tstamp"],
            )
        )

        all_models.update(models_hidden)

    battles.sort(key=lambda x: x["tstamp"])
    if battles:
        last_updated_tstamp = battles[-1]["tstamp"]
        last_updated_datetime = datetime.datetime.fromtimestamp(
            last_updated_tstamp, tz=timezone("US/Pacific")
        ).strftime("%Y-%m-%d %H:%M:%S %Z")
    else:
        # No vote survived the filters, so there is no timestamp to report.
        last_updated_datetime = "N/A"

    print(
        f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
        f"#leaked_identity: {ct_leaked_identity} "
        f"#banned: {ct_banned} "
    )
    print(f"#battles: {len(battles)}, #anony: {ct_anony}")
    print(f"#models: {len(all_models)}, {all_models}")
    print(f"last-updated: {last_updated_datetime}")

    # Banned IPs were counted above; drop them before reporting top voters.
    if ban_ip_list is not None:
        for ban_ip in ban_ip_list:
            if ban_ip in all_ips:
                del all_ips[ban_ip]
    print("Top 30 IPs:")
    print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30])
    return battles
if __name__ == "__main__":
    # Command-line entry point: clean the logs, then write the battles as
    # JSON and record the cutoff date in cut_off_date.txt.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-num-files", type=int)
    parser.add_argument(
        "--mode", type=str, choices=["simple", "conv_release"], default="simple"
    )
    parser.add_argument("--task_name", type=str, default="image_editing", choices=["image_editing", "t2i_generation", "video_generation"])
    parser.add_argument("--exclude-model-names", type=str, nargs="+")
    parser.add_argument("--ban-ip-file", type=str)
    parser.add_argument("--sanitize-ip", action="store_true", default=False)
    args = parser.parse_args()

    log_files = get_log_files(args.max_num_files)
    if args.ban_ip_file:
        # Close the ban-list file promptly instead of leaking the handle.
        with open(args.ban_ip_file) as fin:
            ban_ip_list = json.load(fin)
    else:
        ban_ip_list = None

    battles = clean_battle_data(
        log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip, args.mode, args.task_name
    )
    if not battles:
        # Without any battle there is no cutoff date to derive, so stop here
        # instead of failing with an IndexError below.
        print("No valid battles found; nothing to write.")
        sys.exit(1)

    # The battles are sorted by time, so the last one gives the cutoff date.
    last_updated_tstamp = battles[-1]["tstamp"]
    cutoff_date = datetime.datetime.fromtimestamp(
        last_updated_tstamp, tz=timezone("US/Pacific")
    ).strftime("%Y%m%d")

    if args.mode == "simple":
        # The simple release omits conversation content and question ids.
        for x in battles:
            for key in [
                "conversation_a",
                "conversation_b",
                "question_id",
            ]:
                if key in x:
                    del x[key]
        print("Samples:")
        for sample in battles[:4]:
            print(sample)
        output = f"clean_battle_{args.task_name}_{cutoff_date}.json"
    elif args.mode == "conv_release":
        output = f"clean_battle_{args.task_name}_conv_{cutoff_date}.json"

    # ensure_ascii=False can emit non-ASCII characters, so write UTF-8
    # explicitly rather than relying on the platform default encoding.
    with open(output, "w", encoding="utf-8") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")

    with open("cut_off_date.txt", "w") as fout:
        fout.write(cutoff_date)