aiola committed on
Commit 0f003aa
1 Parent(s): 4dbdaa7

Update README.md

Files changed (1)
README.md +55 -1
README.md CHANGED
@@ -9,4 +9,58 @@ tags:
- Medusa
- Speech
- Speculative Decoding
---

# Whisper Medusa

Whisper is an advanced encoder-decoder model for speech transcription and
translation, processing audio through encoding and decoding stages. Given its
large size and slow inference speed, various optimization strategies such as
Faster-Whisper and Speculative Decoding have been proposed to enhance
performance. Our Medusa model builds on Whisper by predicting multiple tokens
per decoding iteration, which significantly improves speed with only a small
degradation in word error rate (WER). We train and evaluate our model on the
LibriSpeech dataset, demonstrating speed improvements.
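
As a rough illustration of the idea (a minimal, hypothetical sketch; `medusa_decode_step`, the toy layer sizes, and the greedy draft are illustrative assumptions, not the whisper-medusa implementation), Medusa attaches extra prediction heads on top of the decoder's final hidden state, so a single forward pass drafts several future tokens for the base model to verify:

```python
import torch

def medusa_decode_step(hidden, lm_head, medusa_heads):
    # The base LM head predicts token t+1; Medusa head k speculates
    # token t+1+k from the same decoder hidden state.
    logits = [lm_head(hidden)] + [head(hidden) for head in medusa_heads]
    # Greedy draft: one proposed token per head.
    return [l.argmax(dim=-1) for l in logits]

# Toy sizes: hidden dimension 4, vocabulary of 10, two Medusa heads.
torch.manual_seed(0)
hidden = torch.randn(1, 4)
lm_head = torch.nn.Linear(4, 10)
medusa_heads = [torch.nn.Linear(4, 10) for _ in range(2)]

draft = medusa_decode_step(hidden, lm_head, medusa_heads)
print([t.item() for t in draft])  # three tokens drafted in one pass
# A verification pass of the base model then accepts the longest correct
# prefix of the draft, so several tokens can be emitted per iteration.
```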

---------

## Training Details
`aiola/whisper-medusa-v1` was trained on the LibriSpeech dataset to perform audio translation.
The Medusa heads were optimized for English, so for the best performance and speed improvements, please use English audio only.

---------

## Usage
Inference can be done using the following code:
```python
import torch
import torchaudio

from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor

model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

path_to_audio = "path/to/audio.wav"
SAMPLING_RATE = 16000
language = "en"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the audio and resample it to Whisper's expected 16 kHz if needed.
input_speech, sr = torchaudio.load(path_to_audio)
if sr != SAMPLING_RATE:
    input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)

# Compute log-mel input features for the encoder.
input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)

# Move the model to the same device and generate the token IDs.
model = model.to(device)
model_output = model.generate(
    input_features,
    language=language,
)

# Decode the predicted token IDs into text.
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
print(pred)
```
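
Note that `torchaudio.load` returns a tensor of shape `(channels, samples)`, and the `squeeze()` call above assumes a mono recording. For stereo or multi-channel files, one option (an illustrative addition, not part of the original snippet) is to downmix right after loading:

```python
# Illustrative addition: average channels down to mono, since the
# feature extractor expects a single waveform.
if input_speech.shape[0] > 1:
    input_speech = input_speech.mean(dim=0, keepdim=True)
```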