aiola committed
Commit 8baa9e5
1 parent: bbca453

Update app.py

Files changed (1)
  1. app.py +119 -11
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import WhisperProcessor, WhisperForConditionalGeneration
 import torch
 import torchaudio
 import spaces
+import re
 
 # Initialize devices
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -12,13 +13,78 @@ processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-v1")
 model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1")
 model = model.to(device)
 
+
+examples = [
+    [
+        "audio/672-122797-0026.wav",
+        "monetary-value, biological-classification, desire, demographic-group, object-category, relationship-role, reflexive-pronoun, furniture-type"
+    ],
+    [
+        "audio/672-122797-0024.wav",
+        "health-warning, importance-indicator, event, sentiment"
+    ],
+    [
+        "audio/672-122797-0027.wav",
+        "action, emotional-resilience, comparative-path-characteristic, social-role"
+    ],
+    [
+        "audio/672-122797-0048.wav",
+        "weapon, emotional-state, household-chore, atmosphere-quality"
+    ],
+    [
+        "audio/7021-85628-0025.wav",
+        "action-goal, person's-title, emotional-connection, personal-qualities, pronoun-target, assignment-action, physical-action, family-role"
+    ]
+]
+
+
 def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
     """Process and standardize entity text by replacing certain symbols and normalizing spaces."""
-    text = " ".join(text.split())
+    text = " ".join(text.split())
     for symbol in symbols_to_replace:
-        text = text.replace(symbol, "-")
+        text = text.replace(symbol, "-")
     return text.lower()
 
+
+def extract_entities_and_clean_text_fixed(text):
+    entity_pattern = r"<(.*?)>(.*?)<\1>>"
+    entities = []
+    clean_text = []
+    current_pos = 0
+
+    # Iterate through the matches for entity tags
+    for match in re.finditer(entity_pattern, text):
+        # Add text before the entity to the clean text
+        clean_text.append(text[current_pos:match.start()])
+
+        entity_type = match.group(1)
+        entity_text = match.group(2)
+        start_pos = len("".join(clean_text))  # Start position in the clean text
+        end_pos = start_pos + len(entity_text)
+
+        # Append the entity text to the clean text
+        clean_text.append(entity_text)
+
+        # Add the entity details to the list
+        entities.append({
+            "entity": entity_type,
+            "text": entity_text,
+            "start": start_pos,
+            "end": end_pos
+        })
+
+        # Update the current position to the end of the match
+        current_pos = match.end()
+
+    # Append the remaining part of the text after the last entity
+    clean_text.append(text[current_pos:])
+
+    # Join all parts of the clean text
+    clean_text_str = "".join(clean_text)
+
+    return clean_text_str, entities
+
+
 @spaces.GPU  # This decorator ensures your function can use GPU on Hugging Face Spaces
 def transcribe_and_recognize_entities(audio_file, prompt):
     target_sample_rate = 16000
@@ -48,14 +114,56 @@ def transcribe_and_recognize_entities(audio_file, prompt):
     )
     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
-    return transcription
+    clean_text_fixed, extracted_entities_fixed = extract_entities_and_clean_text_fixed(transcription)
+
+    return transcription, {"text": clean_text_fixed, "entities": extracted_entities_fixed}
 
-iface = gr.Interface(
-    fn=transcribe_and_recognize_entities,
-    inputs=[gr.Audio(label="Upload Audio", type="filepath"), gr.Textbox(label="Entity Recognition Prompt")],
-    outputs=gr.Textbox(label="Transcription and Entities"),
-    title="Whisper-NER Demo",
-    description="Upload an audio file and enter entities to identify. The model will transcribe the audio and recognize entities."
-)
 
-iface.launch(share=True)
+with gr.Blocks(title="WhisperNER v1") as demo:
+
+    gr.Markdown(
+        """
+        # Whisper-NER: ASR with zero-shot NER
+
+        WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER), with zero-shot capabilities.
+
+        ## Links
+
+        * Paper: [WhisperNER: Unified Open Named Entity and Speech Recognition](https://arxiv.org/abs/2409.08107)
+        * Model: https://huggingface.co/aiola/whisper-ner-v1
+        * Code: https://github.com/aiola-lab/whisper-ner
+        """
+    )
+
+    with gr.Row() as row1:
+        with gr.Column() as col1:
+            audio_input = gr.Audio(label="Audio Example", type="filepath")
+        with gr.Column() as col2:
+            label_input = gr.Textbox(label="Entity Labels")
+
+    gr.Markdown("## Output")
+
+    with gr.Row() as row3:
+        transcript_output = gr.Textbox(label="Transcription and Entities")
+
+    with gr.Row() as row4:
+        highlighted_text_output = gr.HighlightedText(label="Predicted Highlighted Entities")
+
+    submit_btn = gr.Button("Submit")
+    examples = gr.Examples(
+        examples,
+        fn=transcribe_and_recognize_entities,
+        inputs=[audio_input, label_input],
+        outputs=[transcript_output, highlighted_text_output],
+        cache_examples=True,
+        run_on_click=True,
+    )
+
+    # Submitting
+    label_input.submit(
+        fn=transcribe_and_recognize_entities,
+        inputs=[audio_input, label_input],
+        outputs=[transcript_output, highlighted_text_output],
+    )
+
+demo.launch()
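For reference, `unify_ner_text` collapses runs of whitespace, replaces the listed symbols with hyphens, and lowercases, which matches the hyphenated label style used in `examples`; the call site is presumably the prompt-building code inside `transcribe_and_recognize_entities`, which falls outside the diff context. A minimal illustration (not part of the commit):

```python
# Illustration only, not part of the commit.
print(unify_ner_text("First   Name"))   # -> "first-name"
print(unify_ner_text("date/of_birth"))  # -> "date-of-birth"
```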
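The new `extract_entities_and_clean_text_fixed` helper assumes the model tags entities as `<label>text<label>>`, per `entity_pattern`, and returns the untagged text plus character offsets into it. A made-up tagged transcription (the real output depends on the audio and the prompt) shows the round trip:

```python
# Made-up string in the <label>text<label>> format that entity_pattern matches.
sample = "my name is <person>john<person>> and i live in <location>paris<location>>"
clean, ents = extract_entities_and_clean_text_fixed(sample)
print(clean)
# my name is john and i live in paris
print(ents)
# [{'entity': 'person', 'text': 'john', 'start': 11, 'end': 15},
#  {'entity': 'location', 'text': 'paris', 'start': 30, 'end': 35}]
```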
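The second value returned by `transcribe_and_recognize_entities` is shaped for `gr.HighlightedText`, which accepts a dict holding the full `text` plus an `entities` list whose items carry `entity`, `start`, and `end` keys; that is why the helper tracks offsets in the cleaned text rather than in the tagged string. A standalone sketch of that format:

```python
import gradio as gr

# Standalone sketch of the dict format gr.HighlightedText renders.
with gr.Blocks() as sketch:
    gr.HighlightedText(
        value={
            "text": "my name is john",
            "entities": [{"entity": "person", "start": 11, "end": 15}],
        },
        label="Predicted Highlighted Entities",
    )

sketch.launch()
```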
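One loose end: `submit_btn` is created but never bound to a handler in this commit, so inference runs only when the user presses Enter in the label textbox (or clicks a cached example). If the button is meant to work, the wiring would mirror the `label_input.submit` call; a sketch, not part of the commit, that would sit inside the `with gr.Blocks(...)` block:

```python
# Sketch only, not part of the commit; place inside the gr.Blocks context.
submit_btn.click(
    fn=transcribe_and_recognize_entities,
    inputs=[audio_input, label_input],
    outputs=[transcript_output, highlighted_text_output],
)
```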