#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Streamlit demo describing audio content with the CoNeTTE captioning model.

Audio comes either from the microphone (``st_audiorec``) or from uploaded
files. Model outputs, or markdown error messages for rejected audio, are
cached in ``st.session_state`` under keys starting with ``HASH_PREFIX``.
That cache is dropped when more than ``SECOND_BEFORE_CLEAR_CACHE`` seconds
separate two consecutive script runs.
"""

import os
import time
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Any, Optional, Union

import streamlit as st
import torchaudio
from st_audiorec import st_audiorec
from streamlit.runtime.uploaded_file_manager import UploadedFile
from torch import Tensor

from conette import CoNeTTEModel, conette
from conette.utils.collections import dict_list_to_list_dict

ALLOW_REP_MODES = ("stopwords", "all", "none")
DEFAULT_TASK = "audiocaps"
MAX_BEAM_SIZE = 20
MAX_PRED_SIZE = 30
MAX_BATCH_SIZE = 16
RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
DEFAULT_THRESHOLD = 0.3
THRESHOLD_PRECISION = 100
MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60
HASH_PREFIX = "hash_"
TMP_FILE_PREFIX = "audio_tmp_file_"
SECOND_BEFORE_CLEAR_CACHE = 10 * 60


@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model with ``conette(*args, **kwargs)``.

    ``st.cache_resource`` keeps the returned instance across script reruns,
    so the model is only loaded once per set of arguments.
    """
    return conette(*args, **kwargs)


def format_candidate(candidate: str) -> str:
    """Capitalize the first character of ``candidate`` and append a period.

    An empty candidate is returned as an empty string (no period added).
    """
    if len(candidate) == 0:
        return ""
    else:
        return f"{candidate[0].title()}{candidate[1:]}."


def format_tags(tags: Optional[list[str]]) -> str:
    """Join ``tags`` with ``", "``; return ``"None."`` when there are none."""
    if tags is None or len(tags) == 0:
        return "None."
    else:
        return ", ".join(tags)


def get_result_hash(audio_fname: str, generate_kwds: dict[str, Any]) -> str:
    """Return the session-state cache key for one file and one option set.

    The key is built from the file *name* and the string form of
    ``generate_kwds``; the audio content itself is not part of the key.
    """
    return f"{HASH_PREFIX}{audio_fname}-{generate_kwds}"


def _duration_error_message(audio_fname: str, duration: float, issue: str) -> str:
    """Build the markdown error shown when an audio duration is out of range.

    ``issue`` is ``"short"`` or ``"long"``.
    """
    return (
        f'##### Result for "{audio_fname}"\n'
        f"Audio file is too {issue}.\n"
        f"(found {duration:.2f}s but the model expect audio in range "
        f"[{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])"
    )


def get_results(
    model: CoNeTTEModel,
    audio_files: dict[str, bytes],
    generate_kwds: dict[str, Any],
) -> dict[str, Union[dict[str, Any], str]]:
    """Predict captions for uncached audio and return the results for all files.

    Each audio is written to a temporary file, its duration is checked
    against ``[MIN_AUDIO_DURATION_SEC, MAX_AUDIO_DURATION_SEC]``, and the
    accepted files are passed to ``model`` in batches of ``MAX_BATCH_SIZE``.
    Every result is stored in ``st.session_state`` under
    ``get_result_hash(...)``.

    Args:
        model: CoNeTTE model, called with a list of file paths and
            ``**generate_kwds``.
        audio_files: Mapping from file name to raw audio bytes.
        generate_kwds: Generation options forwarded to ``model``. They are
            also part of the cache key.

    Returns:
        Mapping from file name to either the per-file model output dict or a
        markdown error-message string (duration out of range).
    """
    # Select audio without a cached result. The microphone record always uses
    # the same file name, so it is re-predicted on every call.
    audio_to_predict: dict[str, tuple[str, bytes]] = {}
    for audio_fname, audio in audio_files.items():
        result_hash = get_result_hash(audio_fname, generate_kwds)
        if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
            audio_to_predict[result_hash] = (audio_fname, audio)

    # Track every temporary file created below (accepted or rejected), so all
    # of them are removed, including when loading or inference raises.
    created_tmp_paths: list[str] = []
    try:
        tmp_files: dict[str, _TemporaryFileWrapper] = {}
        for result_hash, (audio_fname, audio) in audio_to_predict.items():
            tmp_file = NamedTemporaryFile(delete=False, prefix=TMP_FILE_PREFIX)
            created_tmp_paths.append(tmp_file.name)
            with tmp_file:
                tmp_file.write(audio)

            metadata = torchaudio.info(tmp_file.name)  # type: ignore
            duration = metadata.num_frames / metadata.sample_rate

            if duration < MIN_AUDIO_DURATION_SEC:
                st.session_state[result_hash] = _duration_error_message(
                    audio_fname, duration, "short"
                )
            elif duration > MAX_AUDIO_DURATION_SEC:
                st.session_state[result_hash] = _duration_error_message(
                    audio_fname, duration, "long"
                )
            else:
                tmp_files[result_hash] = tmp_file

        # Generate predictions in batches and store them in session state.
        result_hashes = list(tmp_files.keys())
        tmp_paths = [tmp_file.name for tmp_file in tmp_files.values()]
        for start in range(0, len(tmp_paths), MAX_BATCH_SIZE):
            end = start + MAX_BATCH_SIZE
            outputs_j = model(tmp_paths[start:end], **generate_kwds)
            # presumably converts a dict of per-batch lists into one dict per
            # file; results are matched to files by position.
            outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
            for result_hash, output_i in zip(result_hashes[start:end], outputs_lst):
                st.session_state[result_hash] = output_i
    finally:
        for path in created_tmp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    return {
        audio_fname: st.session_state[get_result_hash(audio_fname, generate_kwds)]
        for audio_fname in audio_files
    }


def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
    """Render one section per input: an error box or the caption details.

    Inputs are shown in reverse insertion order. A string output is shown
    with ``st.error``. A dict output is expected to contain ``"cands"``,
    ``"lprobs"``, ``"mult_cands"``, ``"mult_lprobs"`` and optionally
    ``"tags"``.
    """
    st.divider()

    for audio_fname, output in reversed(list(outputs.items())):
        if isinstance(output, str):
            st.error(output)
            st.divider()
            continue

        cand: str = output["cands"]
        lprobs: Tensor = output["lprobs"]
        tags_lst = output.get("tags")
        mult_cands: list[str] = output["mult_cands"]
        mult_lprobs: Tensor = output["mult_lprobs"]

        cand = format_candidate(cand)
        prob = lprobs.exp().tolist()
        tags = format_tags(tags_lst)
        mult_cands = [format_candidate(cand_i) for cand_i in mult_cands]

        # Sort the alternative candidates by decreasing probability and drop
        # the top one (presumably the same caption as `cand` -- confirm).
        mult_probs = mult_lprobs.exp()
        indexes = mult_probs.argsort(descending=True)[1:]
        mult_probs = mult_probs[indexes].tolist()
        mult_cands = [mult_cands[idx] for idx in indexes]

        if audio_fname == RECORD_AUDIO_FNAME:
            header = "##### Result for microphone input:"
        else:
            header = f'##### Result for "{audio_fname}"'

        # NOTE(review): the newline-only wrappers around the caption and the
        # empty HTML block below may originally have held HTML/CSS markup.
        # Confirm the intended markup.
        lines = [
            header,
            f'\n\n"{cand}"\n\n',
        ]
        st.markdown(
            """ """,
            unsafe_allow_html=True,
        )
        content = "\n".join(lines)
        st.markdown(content, unsafe_allow_html=True)

        lines = [
            f"- **Probability**: {prob*100:.1f}%",
        ]
        if len(mult_cands) > 0:
            lines.append("- **Other descriptions:**")
            for cand_i, prob_i in zip(mult_cands, mult_probs):
                lines.append(f'  - "{cand_i}" ({prob_i*100:.1f}%)')
        lines.append(f"- **Tags:** {tags}")

        content = "\n".join(lines)
        st.markdown(content, unsafe_allow_html=False)
        st.divider()


def main() -> None:
    """Render the demo page, run the model on the provided audio, show results.

    Also clears cached results (session-state keys starting with
    ``HASH_PREFIX``) when more than ``SECOND_BEFORE_CLEAR_CACHE`` seconds
    elapsed since the previous run.
    """
    model = load_conette(model_kwds=dict(device="cpu"))

    st.header("Describe audio content with CoNeTTE")
    st.markdown(
        "This interface allows you to generate a short description of the sound events of any recording using an Audio Captioning system. You can try it from your microphone or upload a file below."
    )
    st.markdown(
        "Use '**Start Recording**' and '**Stop**' to record an audio from your microphone."
    )
    # presumably raw WAV bytes, or None when nothing was recorded -- confirm.
    record_data = st_audiorec()

    with st.expander("Or upload audio files here:"):
        audio_files: Optional[list[UploadedFile]] = st.file_uploader(
            f"Audio files are automatically resampled to 32 kHz.\nTheir duration must be in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}] seconds.",
            type=["wav", "flac", "mp3", "ogg", "avi"],
            accept_multiple_files=True,
            help="Supports wav, flac, mp3, ogg and avi files.",
        )

    with st.expander("Model options"):
        if DEFAULT_TASK in model.tasks:
            default_task_idx = list(model.tasks).index(DEFAULT_TASK)
        else:
            default_task_idx = 0

        task = st.selectbox("Task embedding input", model.tasks, default_task_idx)
        allow_rep_mode = st.selectbox("Allow repetition of words", ALLOW_REP_MODES, 0)
        beam_size: int = st.select_slider(  # type: ignore
            "Beam size",
            list(range(1, MAX_BEAM_SIZE + 1)),
            model.config.beam_size,
        )
        min_pred_size, max_pred_size = st.slider(
            "Minimal and maximal number of words",
            1,
            MAX_PRED_SIZE,
            (model.config.min_pred_size, model.config.max_pred_size),
        )
        threshold = st.select_slider(
            "Tags threshold",
            [(i / THRESHOLD_PRECISION) for i in range(THRESHOLD_PRECISION + 1)],
            DEFAULT_THRESHOLD,
        )

    # The UI asks which words may repeat; the model expects which are forbidden.
    if allow_rep_mode == "all":
        forbid_rep_mode = "none"
    elif allow_rep_mode == "none":
        forbid_rep_mode = "all"
    elif allow_rep_mode == "stopwords":
        forbid_rep_mode = "content_words"
    else:
        msg = (
            f"Unknown option {allow_rep_mode=}. (expected one of {ALLOW_REP_MODES})"
        )
        raise ValueError(msg)
    del allow_rep_mode

    generate_kwds: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
        forbid_rep_mode=forbid_rep_mode,
        threshold=threshold,
    )

    audios: dict[str, bytes] = {}
    if audio_files is not None:
        audios |= {audio.name: audio.getvalue() for audio in audio_files}
    if record_data is not None:
        audios |= {RECORD_AUDIO_FNAME: record_data}

    if len(audios) > 0:
        with st.spinner("Generating descriptions..."):
            outputs = get_results(model, audios, generate_kwds)
        st.header("Results:")
        show_results(outputs)

    current = time.perf_counter()
    last_generation = st.session_state.get("last_generation", current)
    if current > last_generation + SECOND_BEFORE_CLEAR_CACHE:
        print("Removing result cache...")
        # Snapshot the keys first: entries are deleted inside the loop.
        for key in list(st.session_state.keys()):
            if isinstance(key, str) and key.startswith(HASH_PREFIX):
                del st.session_state[key]
    st.session_state["last_generation"] = current


if __name__ == "__main__":
    main()