File size: 6,196 Bytes
19912cb
 
 
 
94d76bf
8baa9e5
94d76bf
 
 
19912cb
 
6eff469
0d8c379
 
19912cb
8baa9e5
 
f69bb9b
 
 
 
 
 
 
 
8baa9e5
 
5e00e82
 
 
 
 
8baa9e5
 
 
 
 
 
 
 
 
 
 
 
f69bb9b
8baa9e5
 
 
19912cb
 
8baa9e5
19912cb
8baa9e5
19912cb
 
8baa9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d76bf
19912cb
 
 
 
 
 
 
 
 
0d8c379
 
19912cb
 
 
 
 
 
0d8c379
19912cb
 
0d8c379
19912cb
 
94d76bf
19912cb
 
 
 
8baa9e5
 
 
19912cb
 
8baa9e5
 
 
 
 
 
 
5e00e82
8baa9e5
 
 
86e7832
8baa9e5
 
 
 
 
 
 
 
 
 
 
475db72
 
8baa9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475db72
 
 
 
 
8baa9e5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import spaces
import re

# Run on the first visible CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the WhisperNER processor and model once at import time; the request
# handler below reuses these module-level objects for every call.
processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-v1")
model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1").to(device)


# Demo inputs: each row is [audio file path, comma-separated entity labels],
# in the same order as the `inputs` components wired to gr.Examples below.
examples = [
    ["audio/sports.wav", "football-club, football-player, action"],
    ["audio/entertainment.wav", "movie, date, actor, tv-show, musician"],
    [
        "audio/672-122797-0026.wav",
        "biological-classification, desire, demographic-group, object-category, relationship-role, reflexive-pronoun, furniture-type",
    ],
    [
        "audio/7021-85628-0025.wav",
        "action-goal, person's-title, emotional-connection, personal-qualities, pronoun-target, assignmentaction, physical-action, family-role",
    ],
    ["audio/672-122797-0024.wav", "health-warning, importance-indicator, event, sentiment"],
    ["audio/672-122797-0027.wav", "action, emotional-resilience, comparative-path-characteristic, social-role"],
    ["audio/672-122797-0048.wav", "weapon, emotional-state, household-chore, atmosphere-quality"],
]


def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
    """Normalize an entity label into the model's dash-separated lowercase form.

    Runs of whitespace are first collapsed to single spaces (and trimmed), then
    every occurrence of each entry in ``symbols_to_replace`` is replaced by "-",
    in the order given, and the result is lowercased.

    Example: "Football Club" -> "football-club".
    """
    normalized = " ".join(text.split())
    for symbol in symbols_to_replace:
        normalized = normalized.replace(symbol, "-")
    return normalized.lower()


# Inline entity markup emitted by the model: "<type>span<type>>". The closing
# tag must repeat the opening type (backreference \1). "." does not match
# newlines, so a tag split across lines is left in the text untouched.
_ENTITY_PATTERN = re.compile(r"<(.*?)>(.*?)<\1>>")


def extract_entities_and_clean_text_fixed(text):
    """Strip inline entity tags from ``text`` and report where each entity sits.

    Args:
        text: Model output possibly containing "<type>span<type>>" markup.

    Returns:
        A tuple ``(clean_text, entities)``. ``clean_text`` is ``text`` with every
        tag removed and only the tagged span kept. ``entities`` is a list of dicts
        with keys "entity" (tag type), "text" (span), and "start"/"end". The
        offsets index into ``clean_text``, and ``end`` is exclusive, so
        ``clean_text[start:end] == text``.
    """
    pieces = []
    entities = []
    clean_len = 0  # running length of "".join(pieces); avoids re-joining per match
    current_pos = 0

    for match in _ENTITY_PATTERN.finditer(text):
        # Untagged text between the previous match and this one is kept verbatim.
        before = text[current_pos:match.start()]
        pieces.append(before)
        clean_len += len(before)

        entity_type = match.group(1)
        entity_text = match.group(2)
        start_pos = clean_len
        end_pos = start_pos + len(entity_text)

        pieces.append(entity_text)
        clean_len = end_pos

        entities.append({
            "entity": entity_type,
            "text": entity_text,
            "start": start_pos,
            "end": end_pos,
        })

        current_pos = match.end()

    # Whatever follows the last tag (or the whole string if there were none).
    pieces.append(text[current_pos:])

    return "".join(pieces), entities


def _normalize_entity_prompt(prompt):
    """Build the model prompt from a comma-separated label string.

    Each label is stripped and normalized with ``unify_ner_text``. Labels that
    end up empty are dropped (e.g. from "a,,b" or a trailing comma), so they
    do not inject stray ", ," separators into the prompt.
    """
    labels = (unify_ner_text(label.strip()) for label in (prompt or "").split(","))
    return ", ".join(label for label in labels if label)


@spaces.GPU  # This decorator ensures your function can use GPU on Hugging Face Spaces
def transcribe_and_recognize_entities(audio_file, prompt):
    """Transcribe an audio file and tag the requested entity types.

    Args:
        audio_file: Path to an audio file (the Gradio Audio input uses type="filepath").
        prompt: Comma-separated entity labels, e.g. "movie, actor, date".

    Returns:
        A tuple ``(transcription, highlighted)``. ``transcription`` is the decoded
        model output including inline "<type>span<type>>" tags. ``highlighted``
        is ``{"text": clean_text, "entities": [...]}`` from
        ``extract_entities_and_clean_text_fixed``, for the HighlightedText output.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if not audio_file:
        # Without this, torchaudio.load(None) fails with an opaque error in the UI.
        raise gr.Error("Please provide an audio file.")

    target_sample_rate = 16000
    signal, sampling_rate = torchaudio.load(audio_file)
    if sampling_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
        signal = resampler(signal)
    if signal.ndim == 2:
        # torchaudio.load returns (channels, frames); average channels down to mono.
        signal = torch.mean(signal, dim=0)

    input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
    input_features = input_features.to(device)

    prompt = _normalize_entity_prompt(prompt)

    print(f"Prompt after unify_ner_text: {prompt}")
    prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)

    predicted_ids = model.generate(
        input_features,
        max_new_tokens=256,
        prompt_ids=prompt_ids,
        language='en',
        generation_config=model.generation_config,
    )
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    clean_text_fixed, extracted_entities_fixed = extract_entities_and_clean_text_fixed(transcription)

    return transcription, {"text": clean_text_fixed, "entities": extracted_entities_fixed}


# Gradio UI: audio + entity-label inputs, a raw transcription output and a
# highlighted-entities view, plus clickable cached examples.
with gr.Blocks(title="WhisperNER v1") as demo:

    gr.Markdown(
        """
        # Whisper-NER: ASR with zero-shot NER

        WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER), with zero-shot capabilities.
        The WhisperNER model is designed as a strong base model for the downstream task of ASR with NER, and can be fine-tuned on specific datasets for improved performance.

        ## Links

        * Paper: [WhisperNER: Unified Open Named Entity and Speech Recognition](https://arxiv.org/abs/2409.08107).
        * Model: https://huggingface.co/aiola/whisper-ner-v1
        * Code: https://github.com/aiola-lab/whisper-ner
        """
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Audio Example", type="filepath")
        with gr.Column():
            label_input = gr.Textbox(label="Entity Labels")

    submit_btn = gr.Button("Submit")

    gr.Markdown("## Output")

    with gr.Row():
        transcript_output = gr.Textbox(label="Transcription and Entities")

    with gr.Row():
        highlighted_text_output = gr.HighlightedText(label="Predicted Highlighted Entities")

    # Every trigger below runs the same handler with the same wiring.
    ner_inputs = [audio_input, label_input]
    ner_outputs = [transcript_output, highlighted_text_output]

    # Not assigned back to `examples`, so the module-level list is not shadowed;
    # constructing the component inside the Blocks context is enough to render it.
    gr.Examples(
        examples,
        fn=transcribe_and_recognize_entities,
        inputs=ner_inputs,
        outputs=ner_outputs,
        cache_examples=True,
        run_on_click=True,
    )

    # Submitting: pressing Enter in the label box or clicking the button.
    label_input.submit(
        fn=transcribe_and_recognize_entities,
        inputs=ner_inputs,
        outputs=ner_outputs,
    )
    submit_btn.click(
        fn=transcribe_and_recognize_entities,
        inputs=ner_inputs,
        outputs=ner_outputs,
    )

# Launch after the Blocks context has closed, as in Gradio's documented usage.
demo.launch()