aiola commited on
Commit
0d8c379
1 Parent(s): 8ec7a3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -6,7 +6,9 @@ import re # Import regex library
6
 
7
  # Load model and processor
8
  processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
9
- model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1").to("cuda")
 
 
10
 
11
  def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
12
  """Process and standardize entity text by replacing certain symbols and normalizing spaces."""
@@ -26,6 +28,8 @@ def transcribe_and_recognize_entities(audio_file, prompt):
26
 
27
  signal = signal.cpu() # Ensure signal is on CPU for processing
28
  input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
 
 
29
 
30
  # Split the prompt into individual NER types and process each one
31
  ner_types = prompt.split(',')
@@ -34,10 +38,10 @@ def transcribe_and_recognize_entities(audio_file, prompt):
34
 
35
  print(f"Prompt after unify_ner_text: {prompt}")
36
  prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
37
- prompt_ids = prompt_ids.to("cuda")
38
 
39
  predicted_ids = model.generate(
40
- input_features.to("cuda"),
41
  max_new_tokens=256,
42
  prompt_ids=prompt_ids,
43
  language='en', # Ensure transcription is translated to English
 
6
 
7
  # Load model and processor
8
  processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
9
+ model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-v1")
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model = model.to(device)
12
 
13
  def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
14
  """Process and standardize entity text by replacing certain symbols and normalizing spaces."""
 
28
 
29
  signal = signal.cpu() # Ensure signal is on CPU for processing
30
  input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
31
+ input_features = input_features.to(device)
32
+
33
 
34
  # Split the prompt into individual NER types and process each one
35
  ner_types = prompt.split(',')
 
38
 
39
  print(f"Prompt after unify_ner_text: {prompt}")
40
  prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
41
+ prompt_ids = prompt_ids.to(device)
42
 
43
  predicted_ids = model.generate(
44
+ input_features,
45
  max_new_tokens=256,
46
  prompt_ids=prompt_ids,
47
  language='en', # Ensure transcription is translated to English