|
import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean

import pandas as pd
import requests

from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset
|
|
|
|
|
def fetch_evaluation_data(url):
    """
    Fetches evaluation data from the given URL.

    :param url: The URL to fetch the evaluation data from.
    :returns: The evaluation data as a dictionary.
    :raises SystemExit: Via sys.exit if the request fails.
    """
    response = requests.get(url)
    if response.status_code == 200:
        return json.loads(response.text)
    else:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")
|
|
|
|
|
def get_device_name(device):
    """
    Gets the device name from the device map if it exists.

    :param device: String representing the device name.
    :returns: The device name from the device map if it exists, otherwise the input device name.
    """
|
with open("dashboard_data/device_map.json", "r") as f: |
|
device_map = json.load(f) |
|
return device_map.get(device, device).replace(" ", "_") |
|
|
|
|
|
def process_quality_file(file_path, dataset_dfs, quality_results):
    """
    Processes a single quality file and updates the quality_results dictionary.

    :param file_path: Path to the quality JSON file.
    :param dataset_dfs: Dictionary of DataFrames containing dataset information.
    :param quality_results: Dictionary to store the processed quality results.

    This function reads a quality JSON file, extracts relevant information,
    and updates the quality_results dictionary with various metrics including WER
    and Quality of Inference (QoI) for different datasets.
    """
    with open(file_path, "r") as file:
        test_results = json.load(file)

    if len(test_results) == 0:
        return
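    # The file path is assumed to follow .../<model>/<dataset>/<timestamp>.json
    # (see the os.walk over source_quality_directory in main()); the index-based
    # splits below rely on that layout.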
|
|
|
    model = file_path.split("/")[-3].replace("_", "/")
    device = "Linux"
    timestamp = file_path.split("/")[-1].split(".")[0]
    key = model
    dataset_name = file_path.split("/")[-2]
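    # Each test result is expected to carry a "testInfo" object with at least
    # "audioFile", "prediction", "reference", and "wer" fields. Illustrative
    # sketch (values made up):
    #   {"testInfo": {"audioFile": "...", "prediction": "...", "reference": "...", "wer": 0.12}}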
|
|
|
    for test_result in test_results:
        audio_file_name = test_result["testInfo"]["audioFile"]

        dataset_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech"
        dataset_df = dataset_dfs[dataset_key]

        wer_entry = {
            "prediction": text_normalizer(test_result["testInfo"]["prediction"]),
            "reference": text_normalizer(test_result["testInfo"]["reference"]),
        }
        quality_results[key]["timestamp"] = timestamp
        quality_results[key]["dataset_wer"][dataset_name].append(wer_entry)

        audio = audio_file_name.split("-")[0]
        dataset_row = dataset_df.loc[dataset_df["file"].str.contains(audio)].iloc[0]
        reference_wer = dataset_row["wer"]
        prediction_wer = test_result["testInfo"]["wer"]

        # QoI is a per-file indicator: 1 when the prediction's WER is no worse
        # than the reference model's WER on the same file, 0 otherwise.
        quality_results[key]["qoi"].append(1 if prediction_wer <= reference_wer else 0)
|
|
|
|
|
def calculate_and_save_quality_results(quality_results, quality_output_path):
    """
    Calculates final quality metrics and saves them to a JSON file.

    :param quality_results: Dictionary containing raw quality data.
    :param quality_output_path: Path to save the processed quality results.

    This function processes the raw quality data, calculates average metrics,
    and writes the final results to a JSON file, with each entry representing
    a unique model's quality metrics across different datasets, including
    Word Error Rate (WER) and Quality of Inference (QoI).
    """
|
    with open(quality_output_path, "w") as quality_file:
        for key, data in quality_results.items():
            model = key

            dataset_wers = {
                dataset: compute_average_wer(wer)
                for dataset, wer in data["dataset_wer"].items()
            }
            average_wer = (
                sum(dataset_wers.values()) / len(dataset_wers)
                if len(dataset_wers) != 0
                else 0
            )

            quality_entry = {
                "model": model.replace("_", "/"),
                "timestamp": data["timestamp"],
                "average_wer": round(average_wer, 2),
                "dataset_wer": dataset_wers,
                "qoi": round(mean(data["qoi"]), 2),
            }

            json.dump(quality_entry, quality_file)
            quality_file.write("\n")
|
|
|
|
|
def main():
    """
    Main function to orchestrate the quality data generation process.

    This function performs the following steps:
    1. Downloads quality data if requested.
    2. Fetches evaluation data for various datasets.
    3. Processes quality files for specific datasets.
    4. Calculates and saves quality results, including WER and QoI metrics.
    """
|
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree("english")
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit")
|
|
|
    datasets = {
        "Earnings-22": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co./datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }
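    # Note: several entries above share the same reference eval file, and
    # process_quality_file only looks up the "Earnings-22" and "LibriSpeech"
    # DataFrames when computing QoI.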
|
|
|
    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        evals = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals["results"])
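    # Each reference DataFrame is assumed to expose "file" and "wer" columns;
    # process_quality_file matches on "file" and reads the per-file reference WER from "wer".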
|
|
|
source_quality_directory = "argmaxinc/english/WhisperKit/" |
|
|
|
    quality_results = defaultdict(
        lambda: {
            "average_wer": [],
            "dataset_wer": defaultdict(list),
            "qoi": [],
            "timestamp": None,
        }
    )
|
|
|
    for subdir, _, files in os.walk(source_quality_directory):
        dataset = subdir.split("/")[-1]
        if dataset not in ["earnings22-10mins", "librispeech-10mins"]:
            continue

        for filename in files:
            if not filename.endswith(".json"):
                continue

            file_path = os.path.join(subdir, filename)
            process_quality_file(file_path, dataset_dfs, quality_results)

    calculate_and_save_quality_results(
        quality_results, "dashboard_data/quality_data.json"
    )
|
|
|
|
|
if __name__ == "__main__":
    main()
|
|