allandclive mekaneeky commited on
Commit
b2db4f9
·
0 Parent(s):

Duplicate from Sunbird/luganda2english-stt

Browse files

Co-authored-by: Ali <[email protected]>

Files changed (5) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +45 -0
  4. requirements.txt +4 -0
  5. stitched_model.py +30 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Luganda2English Speech Translation
3
+ emoji: 🦓
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.35.2
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: Sunbird/luganda2english-stt
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import librosa
4
+ import json
5
+ from transformers import pipeline
6
+ from stitched_model import CombinedModel
7
+
8
+
9
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+
11
+ model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu")
12
+
13
+
14
+
15
+ def transcribe(audio_file_mic=None, audio_file_upload=None):
16
+ if audio_file_mic:
17
+ audio_file = audio_file_mic
18
+ elif audio_file_upload:
19
+ audio_file = audio_file_upload
20
+ else:
21
+ return "Please upload an audio file or record one"
22
+
23
+ # Make sure audio is 16kHz
24
+ speech, sample_rate = librosa.load(audio_file)
25
+ if sample_rate != 16000:
26
+ speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
27
+ speech = torch.tensor([speech])
28
+
29
+ with torch.no_grad():
30
+ transcription, translation = model({"audio":speech})
31
+
32
+ return transcription, translation[0]
33
+
34
+ description = '''Luganda to English Speech Translation'''
35
+
36
+ iface = gr.Interface(fn=transcribe,
37
+ inputs=[
38
+ gr.Audio(source="microphone", type="filepath", label="Record Audio"),
39
+ gr.Audio(source="upload", type="filepath", label="Upload Audio")],
40
+ outputs=[gr.Textbox(label="Transcription"),
41
+ gr.Textbox(label="Translation")
42
+ ],
43
+ description=description
44
+ )
45
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers[torch]
2
+ librosa
3
+ sentencepiece
4
+ jiwer
stitched_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ class CombinedModel(nn.Module):
6
+ def __init__(self, stt_model_name, nmt_model_name,device = "cuda"):
7
+ super(CombinedModel, self).__init__()
8
+
9
+ self.stt_processor = Wav2Vec2Processor.from_pretrained(stt_model_name)
10
+ self.stt_model = Wav2Vec2ForCTC.from_pretrained(stt_model_name)
11
+ self.nmt_tokenizer = AutoTokenizer.from_pretrained(nmt_model_name)
12
+ self.nmt_model = AutoModelForSeq2SeqLM.from_pretrained(nmt_model_name)
13
+ self.device = device
14
+
15
+ def forward(self, batch, *args, **kwargs):
16
+ # Use stt_model to transcribe the audio to text
17
+ device = self.device
18
+ audio = torch.tensor(batch["audio"][0]).to(self.device)
19
+ input_features = self.stt_processor(audio,sampling_rate=16000, return_tensors="pt",max_length=110000, padding=True, truncation=True)
20
+ stt_output = self.stt_model(input_features.input_values.to(device), attention_mask= input_features.attention_mask.to(device) )
21
+ transcription = self.stt_processor.decode(torch.squeeze(stt_output.logits.argmax(axis=-1)).to(device))
22
+ input_nmt_tokens = self.nmt_tokenizer(transcription, return_tensors="pt", padding=True, truncation=True)
23
+ output_nmt_output = self.nmt_model.generate(input_ids = input_nmt_tokens.input_ids.to(device), attention_mask= input_nmt_tokens.attention_mask.to(device))
24
+ decoded_nmt_output = self.nmt_tokenizer.batch_decode(output_nmt_output, skip_special_tokens=True)
25
+
26
+
27
+ return transcription, decoded_nmt_output
28
+
29
+ # Usage
30
+ #model = CombinedModel("ak3ra/wav2vec2-sunbird-speech-lug", "Sunbird/sunbird-mul-en-mbart-merged", device="cpu")