from typing import Dict import torch import gradio as gr import whisper from whisper.tokenizer import get_tokenizer import classify from datasets import load_dataset model_cache = {} def zero_shot_classify(audio_path: str, class_names: str, model_name: str) -> Dict[str, float]: class_names = class_names.split(",") tokenizer = get_tokenizer(multilingual=".en" not in model_name) print("#########", model_name) if model_name not in model_cache: model = whisper.load_model(model_name) model_cache[model_name] = model else: model = model_cache[model_name] print("#### Model ####", model) internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs( model=model, class_names=class_names, tokenizer=tokenizer, ) audio_features = classify.calculate_audio_features(audio_path, model) average_logprobs = classify.calculate_average_logprobs( model=model, audio_features=audio_features, class_names=class_names, tokenizer=tokenizer, ) average_logprobs -= internal_lm_average_logprobs scores = average_logprobs.softmax(-1).tolist() return {class_name: score for class_name, score in zip(class_names, scores)} def main(): CLASS_NAMES = "[dog barking],[helicopter whirring],[laughing],[birds chirping],[clock ticking],[popping],[sneezing],[sigh],[slurping],[mouth sounds],[clearing thoat]," AUDIO_PATHS = [ "./data/(dog)1-100032-A-0.wav", "./data/(helicopter)1-181071-A-40.wav", "./data/(laughing)1-1791-A-26.wav", "./data/(chirping_birds)1-34495-A-14.wav", "./data/(clock_tick)1-21934-A-38.wav", "./data/clears_throat1.wav", "./data/mouth_sounds1.wav", "./data/pop1.wav", "./data/sigh1.wav", "./data/slurp1.wav", ] EXAMPLES = [] for audio_path in AUDIO_PATHS: EXAMPLES.append([audio_path, CLASS_NAMES, "small"]) DESCRIPTION = ( '
This demo allows you to try out zero-shot audio classification using " "Whisper.
" "Github: https://github.com/jumon/zac
" "Example audio files are from the ESC-50" " dataset (CC BY-NC 3.0).