Labbeti committed
Commit 5f47c66
Parent(s): ae94a43

Mod: Update the UI to store microphone input in a microphone_conette_record.wav file, raise an error when the audio is too short or too long, update the main description, and show the other candidates in the outputs.

Files changed (2):
  1. .gitignore +1 -1
  2. app.py +63 -26
.gitignore CHANGED
@@ -1 +1 @@
- record.wav
+ microphone_conette_record.wav
app.py CHANGED
@@ -2,12 +2,14 @@
  # -*- coding: utf-8 -*-

  from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
- from typing import Any, Optional
+ from typing import Any, Optional, Union

  import streamlit as st
+ import torchaudio

  from st_audiorec import st_audiorec
  from streamlit.runtime.uploaded_file_manager import UploadedFile
+ from torch import Tensor

  from conette import CoNeTTEModel, conette
  from conette.utils.collections import dict_list_to_list_dict
@@ -17,9 +19,11 @@ ALLOW_REP_MODES = ("stopwords", "all", "none")
  MAX_BEAM_SIZE = 20
  MAX_PRED_SIZE = 30
  MAX_BATCH_SIZE = 32
- RECORD_AUDIO_FNAME = "record.wav"
+ RECORD_AUDIO_FNAME = "microphone_conette_record.wav"
  DEFAULT_THRESHOLD = 0.3
  THRESHOLD_PRECISION = 100
+ MIN_AUDIO_DURATION_SEC = 0.3
+ MAX_AUDIO_DURATION_SEC = 60


  @st.cache_resource
@@ -49,20 +53,34 @@ def get_results(
      model: CoNeTTEModel,
      audio_files: dict[str, bytes],
      generate_kwds: dict[str, Any],
- ) -> dict[str, dict[str, Any]]:
+ ) -> dict[str, Union[dict[str, Any], str]]:
      # Get audio to be processed
-     audio_to_predict: dict[str, bytes] = {}
+     audio_to_predict: dict[str, tuple[str, bytes]] = {}
      for audio_fname, audio in audio_files.items():
          result_hash = get_result_hash(audio_fname, generate_kwds)
          if result_hash not in st.session_state or audio_fname == RECORD_AUDIO_FNAME:
-             audio_to_predict[result_hash] = audio
+             audio_to_predict[result_hash] = (audio_fname, audio)

      # Save audio to be processed
      tmp_files: dict[str, _TemporaryFileWrapper] = {}
-     for result_hash, audio in audio_to_predict.items():
-         tmp_file = NamedTemporaryFile()
+     for result_hash, (audio_fname, audio) in audio_to_predict.items():
+         tmp_file = NamedTemporaryFile(delete=False)
          tmp_file.write(audio)
-         tmp_files[result_hash] = tmp_file
+         tmp_file.close()
+
+         metadata = torchaudio.info(tmp_file.name)  # type: ignore
+         duration = metadata.num_frames / metadata.sample_rate
+
+         if MIN_AUDIO_DURATION_SEC > duration:
+             error_msg = f"Audio file is too short. (found {duration:.2f}s but the model expects audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])"
+             st.session_state[result_hash] = error_msg
+
+         elif duration > MAX_AUDIO_DURATION_SEC:
+             error_msg = f"Audio file is too long. (found {duration:.2f}s but the model expects audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])"
+             st.session_state[result_hash] = error_msg
+
+         else:
+             tmp_files[result_hash] = tmp_file

      # Generate predictions and store them in session state
      for start in range(0, len(tmp_files), MAX_BATCH_SIZE):
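The duration check above uses `torchaudio.info`, which reads only the file header instead of decoding the whole signal, so rejecting a too-short or too-long file is cheap. A minimal standalone sketch of the same validation, where `check_duration` is a hypothetical helper name (the bounds mirror the constants added in this commit):

```python
import torchaudio

MIN_AUDIO_DURATION_SEC = 0.3
MAX_AUDIO_DURATION_SEC = 60

def check_duration(path: str) -> None:
    # Hypothetical helper: read metadata only, then validate the duration.
    metadata = torchaudio.info(path)
    duration = metadata.num_frames / metadata.sample_rate
    if not (MIN_AUDIO_DURATION_SEC <= duration <= MAX_AUDIO_DURATION_SEC):
        raise ValueError(
            f"Invalid audio duration (found {duration:.2f}s but the model expects "
            f"audio in range [{MIN_AUDIO_DURATION_SEC}, {MAX_AUDIO_DURATION_SEC}])."
        )
```

The switch to `NamedTemporaryFile(delete=False)` with an immediate `close()` is what makes this check possible: `torchaudio.info` reopens the file by name, which requires the file to outlive the close, and reopening a still-open temporary file by name fails on Windows.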
@@ -74,8 +92,6 @@ def get_results(
          tmp_paths_j,
          **generate_kwds,
      )
-     for tmp_file in tmp_files_j:
-         tmp_file.close()
      outputs_lst = dict_list_to_list_dict(outputs_j)  # type: ignore
      for result_hash, output_i in zip(result_hashes_j, outputs_lst):
          st.session_state[result_hash] = output_i
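These two hunks sit inside the batching loop of `get_results`: saved files are passed to the model at most `MAX_BATCH_SIZE` at a time, and each output is cached in `st.session_state` under its result hash. The removed `close()` calls are consistent with the earlier change, since every temporary file is now closed right after it is written and validated. Below is a hedged sketch of the loop, assuming the model is callable on a list of file paths as the hunk suggests; `predict_in_batches` and the slicing details are reconstructed, not shown verbatim in the diff:

```python
from typing import Any

from conette.utils.collections import dict_list_to_list_dict

MAX_BATCH_SIZE = 32

def predict_in_batches(model, tmp_files, generate_kwds: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical helper mirroring the loop in get_results.
    results: dict[str, Any] = {}
    hashes = list(tmp_files.keys())
    for start in range(0, len(hashes), MAX_BATCH_SIZE):
        batch_hashes = hashes[start : start + MAX_BATCH_SIZE]
        batch_paths = [tmp_files[h].name for h in batch_hashes]
        outputs = model(batch_paths, **generate_kwds)
        # dict of lists -> list of dicts, one output per input file
        for result_hash, output_i in zip(batch_hashes, dict_list_to_list_dict(outputs)):
            results[result_hash] = output_i
    return results
```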
@@ -90,46 +106,67 @@ def get_results(
      return outputs


- def show_results(outputs: dict[str, dict[str, Any]]) -> None:
+ def show_results(outputs: dict[str, Union[dict[str, Any], str]]) -> None:
      st.divider()

      for audio_fname, output in outputs.items():
-         cand = output["cands"]
-         lprobs = output["lprobs"]
-         tags = output.get("tags")
+         if isinstance(output, str):
+             st.error(output)
+             st.divider()
+             continue
+
+         cand: str = output["cands"]
+         lprobs: Tensor = output["lprobs"]
+         tags_lst = output.get("tags")
+         mult_cands: list[str] = output["mult_cands"]
+         mult_lprobs: Tensor = output["mult_lprobs"]

          cand = format_candidate(cand)
-         tags = format_tags(tags)
          prob = lprobs.exp().tolist()
+         tags = format_tags(tags_lst)
+         mult_cands = [format_candidate(cand_i) for cand_i in mult_cands]
+         mult_probs = mult_lprobs.exp()
+
+         indexes = mult_probs.argsort(descending=True)[1:]
+         mult_probs = mult_probs[indexes].tolist()
+         mult_cands = [mult_cands[idx] for idx in indexes]

          if audio_fname == RECORD_AUDIO_FNAME:
              header = "##### Result for microphone input:"
          else:
              header = f'##### Result for "{audio_fname}"'

-         content = f"""
-         {header}
-         - **Description:** "{cand}"
-         - **Mean confidence:** {prob*100:.0f}%
-         - **Tags:** {tags}"""
-         st.markdown(content)
+         content = [
+             header,
+             f'- **Description:** "{cand}" ({prob*100:.1f}%)',
+             f"- **Tags:** {tags}",
+         ]
+         if len(mult_cands) > 0:
+             msg = "- **Other descriptions:**"
+             content.append(msg)
+
+             for cand_i, prob_i in zip(mult_cands, mult_probs):
+                 msg = f'  - "{cand_i}" ({prob_i*100:.1f}%)'
+                 content.append(msg)
+
+         st.success("\n".join(content))
          st.divider()


  def main() -> None:
-     st.header("Describe audio content with CoNeTTE")
-
      model = load_conette(model_kwds=dict(device="cpu"))

-     # st.warning(
-     #     "Recommanded audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum."
-     # )
+     st.header("Describe audio content with CoNeTTE")
+     st.markdown(
+         "This interface allows you to generate a short description of the sound events of any recording. You can try it from your microphone or upload a file below."
+     )

      record_data = st_audiorec()
      audio_files: Optional[list[UploadedFile]] = st.file_uploader(
          "**Or upload audio files here:**",
          type=["wav", "flac", "mp3", "ogg", "avi"],
          accept_multiple_files=True,
+         help="Recommended audio: lasting from **1 to 30s**, sampled at **32 kHz** minimum.",
      )

      with st.expander("Model hyperparameters"):
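The "other descriptions" shown by `show_results` come from re-ranking all beam candidates by probability and dropping the top one, which is already displayed as the main description. A minimal sketch of that step, with hypothetical candidates and log-probabilities standing in for a real model output:

```python
import torch

# Hypothetical candidates with their log-probabilities (aligned one-to-one)
mult_cands = ["a dog barks", "a dog barks twice", "an animal is barking"]
mult_lprobs = torch.tensor([-0.2, -1.5, -0.9])

mult_probs = mult_lprobs.exp()
# Sort by probability, descending; [1:] skips the best candidate
indexes = mult_probs.argsort(descending=True)[1:]
other_probs = mult_probs[indexes].tolist()
other_cands = [mult_cands[idx] for idx in indexes]

for cand_i, prob_i in zip(other_cands, other_probs):
    print(f'"{cand_i}" ({prob_i * 100:.1f}%)')
```

Note that the `[1:]` relies on `argsort` placing the highest-probability candidate first, so the printed list starts at the second-best description.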