Manyue-DataScientist committed on
Commit
b3635dd
·
verified ·
1 Parent(s): 2a6784d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -29
app.py CHANGED
@@ -1,41 +1,98 @@
1
  import streamlit as st
2
  from pyannote.audio import Pipeline
3
- from transformers import pipeline
4
  import whisper
 
 
 
 
5
 
6
- # Title
7
- st.title("Multi-Speaker Audio Analyzer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Upload Audio File
10
- uploaded_file = st.file_uploader("Upload an audio file (MP3/WAV)", type=["mp3", "wav"])
 
 
 
11
 
12
- # Process Button
13
- if uploaded_file:
14
- st.audio(uploaded_file, format='audio/wav')
 
15
 
16
- # Load pre-trained models
17
- diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
18
- transcription_model = whisper.load_model("base")
19
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 
 
 
 
20
 
21
- # Perform Speaker Diarization
22
- st.write("Processing Speaker Diarization...")
23
- diarized_output = diarization_pipeline(uploaded_file)
 
 
 
 
 
 
 
 
 
24
 
25
- # Perform Speech-to-Text Transcription
26
- st.write("Transcribing Audio...")
27
- transcription = transcription_model.transcribe(uploaded_file)
28
 
29
- # Generate Summary
30
- st.write("Generating Summary...")
31
- summary = summarizer(transcription["text"])
32
 
33
- # Display Outputs
34
- st.write("Speaker-Diarized Transcript:")
35
- st.text(diarized_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- st.write("Full Transcription:")
38
- st.text(transcription["text"])
39
-
40
- st.write("Summary:")
41
- st.text(summary[0]['summary_text'])
 
1
  import streamlit as st
2
  from pyannote.audio import Pipeline
 
3
  import whisper
4
+ import tempfile
5
+ import os
6
+ import torch
7
+ from transformers import pipeline as tf_pipeline
8
 
9
# Models are expensive to construct, so build them once per session.
@st.cache_resource
def load_models():
    """Load and cache the three models used by the app.

    Returns:
        tuple: (diarization pipeline, whisper transcriber, summarizer),
        or (None, None, None) if any model fails to load (an error is
        shown in the Streamlit UI).
    """
    try:
        # Speaker-diarization pipeline; gated model, so it needs the
        # Hugging Face token stored in Streamlit secrets.
        dia_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=st.secrets["hf_token"],
        )

        # The "base" Whisper checkpoint keeps transcription fast.
        asr_model = whisper.load_model("base")

        # Summarizer on GPU when one is available, otherwise CPU.
        gpu_index = 0 if torch.cuda.is_available() else -1
        bart_summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=gpu_index,
        )
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

    return dia_pipeline, asr_model, bart_summarizer
33
 
34
def process_audio(audio_file, max_duration=300):  # limit to 5 minutes initially
    """Run diarization, transcription, and summarization on an upload.

    Args:
        audio_file: Streamlit UploadedFile holding the raw audio bytes.
        max_duration: Intended cap on processed audio, in seconds.
            NOTE(review): currently not enforced — the whole file is
            processed regardless; kept for backward compatibility.

    Returns:
        dict with keys "diarization", "transcription", "summary" on
        success, or None on any failure (an error is shown in the UI).
    """
    tmp_path = None
    try:
        # The models need a real file path, so spill the upload to disk.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_file.getvalue())
            tmp_path = tmp.name

        # Get cached models
        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # BUG FIX: this used to return the truthy string
            # "Model loading failed", which main()'s `if results:` check
            # accepted as a results dict and then crashed on. Surface the
            # error in the UI and return None instead.
            st.error("Model loading failed")
            return None

        # Process with progress indicators
        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer(transcription["text"], max_length=130, min_length=30)

        return {
            "diarization": diarization_result,
            "transcription": transcription["text"],
            "summary": summary[0]["summary_text"],
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # BUG FIX: the temp file previously leaked whenever an exception
        # was raised after it was created (unlink only ran on the success
        # path). Always clean it up.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
67
 
68
def main():
    """Streamlit entry point: upload widget plus tabbed result display."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    # Guard clauses: bail out early until there is something to analyze.
    if not uploaded_file:
        return

    st.audio(uploaded_file, format='audio/wav')
    if not st.button("Analyze Audio"):
        return

    results = process_audio(uploaded_file)
    if not results:
        return

    # One tab per kind of output.
    speakers_tab, transcript_tab, summary_tab = st.tabs(
        ["Speakers", "Transcription", "Summary"]
    )

    with speakers_tab:
        st.write("Speaker Segments:")
        tracks = results["diarization"].itertracks(yield_label=True)
        for turn, _, speaker in tracks:
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")

    with transcript_tab:
        st.write("Transcription:")
        st.write(results["transcription"])

    with summary_tab:
        st.write("Summary:")
        st.write(results["summary"])


if __name__ == "__main__":
    main()