# FIRE / src /serve /monitor /deduplication.py
# zhangbofei
# feat: change to fstchat
# 6dc0c9c
import os
import json
import pandas as pd
import ast
import matplotlib.pyplot as plt
from matplotlib import rcParams
import argparse
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
def extract_user_prompt(conversation, max_chars=10000):
    """Return the user-side text of a conversation, used as the dedup key.

    Args:
        conversation: Iterable of turn dicts, each providing "role" and
            "content" keys.
        max_chars: Maximum length of the returned key; longer text is cut.

    Returns:
        The content of every turn whose role is "user", each followed by a
        newline, concatenated in order and truncated to ``max_chars``
        characters. Conversations without user turns yield "".
    """
    user_text = "".join(
        f"{turn['content']}\n" for turn in conversation if turn["role"] == "user"
    )
    return user_text[:max_chars]


def build_dedup_tags(prompts, high_frequency_prompts, keep_per_prompt, seed=42):
    """Compute one ``{"high_freq", "sampled"}`` tag per conversation.

    Args:
        prompts: Sequence (list or pandas Series) of dedup keys, one per row.
        high_frequency_prompts: Iterable of keys considered over-represented.
            Every key must occur in ``prompts``.
        keep_per_prompt: Number of rows to keep ("sampled": True) for each
            high-frequency key. Clamped to the group size.
        seed: Random state used for the per-key sample, so runs are
            reproducible.

    Returns:
        A list with ``len(prompts)`` dicts, in row order. Rows whose key is not
        high-frequency get ``{"high_freq": False, "sampled": True}``; rows of a
        high-frequency key get ``"high_freq": True`` and ``"sampled": True``
        only for the ``keep_per_prompt`` randomly retained rows.
    """
    # Re-index so that index labels equal row positions; the NumPy masks below
    # are addressed positionally.
    prompts = pd.Series(prompts).reset_index(drop=True)
    is_high_freq = np.zeros(len(prompts), dtype=bool)
    is_sampled = np.ones(len(prompts), dtype=bool)

    groups = prompts.groupby(prompts)
    for prompt in high_frequency_prompts:
        members = groups.get_group(prompt)
        kept = members.sample(
            n=min(keep_per_prompt, len(members)), random_state=seed
        ).index
        member_positions = members.index.to_numpy()
        is_high_freq[member_positions] = True
        # Unmark the whole group first, then re-mark only the retained sample.
        is_sampled[member_positions] = False
        is_sampled[kept.to_numpy()] = True

    # Plain Python bools keep the tags JSON-serializable as true/false.
    return [
        {"high_freq": bool(high), "sampled": bool(keep)}
        for high, keep in zip(is_high_freq, is_sampled)
    ]


def main(argv=None):
    """Tag over-represented prompts in a conversation log and write dedup.json.

    Reads the JSON list given by ``--input_file``, keys every row by the text
    of its user turns in "conversation_a", and marks prompts whose frequency
    exceeds the ``--percentile`` quantile of all prompt counts. For each such
    prompt only ``int(cutoff)`` randomly chosen rows stay "sampled". The rows,
    with an added "dedup_tag" column, are written to
    ``<output_dir>/dedup.json``.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--percentile", type=float, default=0.9999)
    args = parser.parse_args(argv)
    output_dir = args.output_dir

    # Decode explicitly as UTF-8 instead of the platform default: the output
    # is written with force_ascii=False, so non-ASCII text is expected here.
    with open(args.input_file, encoding="utf-8") as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "dedup.json")

    # Preprocessing: attach the dedup key to every row.
    rows = []
    for row in data:
        row["post_process_conv"] = extract_user_prompt(row["conversation_a"])
        rows.append(row)

    df = pd.DataFrame(rows)
    print("Number of conversations: ", len(df))

    if not rows:
        # An empty frame has no "post_process_conv" column, so the statistics
        # below would raise KeyError; emit an empty result instead.
        df.to_json(output_path, orient="records", indent=4, force_ascii=False)
        return

    prompt_counts = df["post_process_conv"].value_counts()
    # Select the top 20 most frequent prompts
    top_prompts = prompt_counts.head(20)
    print(top_prompts)

    # Determine the percentile count
    percentile_cutoff = prompt_counts.quantile(args.percentile)
    print(f"{args.percentile*100} percentile count: {percentile_cutoff}")

    # prompts that are more common than the percentile cutoff
    high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index
    print(
        f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}"
    )

    # tqdm only wraps the iterable to display progress over the hot prompts.
    df["dedup_tag"] = build_dedup_tags(
        df["post_process_conv"],
        tqdm(high_frequency_prompts),
        keep_per_prompt=int(percentile_cutoff),
    )

    # drop intermediate columns (post_process_conv)
    df = df.drop(columns=["post_process_conv"])

    df.to_json(
        output_path,
        orient="records",
        indent=4,
        force_ascii=False,
    )


if __name__ == "__main__":
    main()