Amir Zait committed
Commit e8b13db
Parent(s): d5d060c
Files changed (3):
  1. app.py +7 -30
  2. image_generator.py +3 -5
  3. requirements.txt +0 -2
app.py CHANGED
@@ -3,7 +3,6 @@ from transformers import pipeline
 
 import soundfile as sf
 import gradio as gr
-import librosa
 import torch
 import sox
 import os
@@ -18,32 +17,6 @@ asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebr
 
 he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")
 
-def process_audio_file(file):
-    data, sr = librosa.load(file)
-    if sr != 16000:
-        data = librosa.resample(data, sr, 16000)
-
-    input_values = asr_processor(data, sampling_rate=16_000, return_tensors="pt").input_values  #.to(device)
-    return input_values
-
-def transcribe(file_mic, file_upload):
-    warn_output = ""
-    if (file_mic is not None) and (file_upload is not None):
-        warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
-        file = file_mic
-    elif (file_mic is None) and (file_upload is None):
-        return "ERROR: You have to either use the microphone or upload an audio file"
-    elif file_mic is not None:
-        file = file_mic
-    else:
-        file = file_upload
-
-    input_values = process_audio_file(file)
-    logits = asr_model(input_values).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
-    return warn_output + transcription
-
 def convert(inputfile, outfile):
     sox_tfm = sox.Transformer()
     sox_tfm.set_output_format(
@@ -52,22 +25,26 @@ def convert(inputfile, outfile):
     sox_tfm.build(inputfile, outfile)
 
 def parse_transcription(wav_file):
+    # Get the wav file from the microphone
     filename = wav_file.name.split('.')[0]
     convert(wav_file.name, filename + "16k.wav")
     speech, _ = sf.read(filename + "16k.wav")
-    print(speech.shape)
+
+    # transcribe to hebrew
     input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
     logits = asr_model(input_values).logits
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
-    translated = he_en_translator(transcription)[0]['translation_text']
 
+    # translate to english
+    translated = he_en_translator(transcription)[0]['translation_text']
+
+    # generate image
     image = generate_image(translated)
     return image
 
 output = gr.outputs.Image(label='')
 input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
-input_upload = gr.inputs.Audio(source="upload", type="file", optional=True)
 
 gr.Interface(parse_transcription, inputs=[input_mic], outputs=output,
              analytics_enabled=False,
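
After this change, app.py keeps only the sox-based path: convert the microphone recording to 16 kHz, transcribe it with the Hebrew wav2vec2 model, translate the transcription to English, and pass the result to generate_image. The body of convert() is mostly elided by the hunk context above, so the output format in the following sketch is an assumption (16 kHz mono 16-bit signed WAV, which is what the wav2vec2 processor expects), not code taken from the repository:

import sox

def convert(inputfile, outfile):
    # Sketch only: resample/remix the recording for the ASR model.
    # The 16 kHz / mono / 16-bit signed settings are assumed; the actual
    # set_output_format arguments are not visible in the diff above.
    sox_tfm = sox.Transformer()
    sox_tfm.set_output_format(
        file_type="wav", channels=1, rate=16000, bits=16, encoding="signed-integer"
    )
    sox_tfm.build(inputfile, outfile)
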
image_generator.py CHANGED
@@ -7,13 +7,11 @@ from dalle_mini import DalleBart, DalleBartProcessor
 from vqgan_jax.modeling_flax_vqgan import VQModel
 
 # Model references
-# dalle-mega
-DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
+# dalle-mini, mega too large
+# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
+DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
 DALLE_COMMIT_ID = None
 
-# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line
-# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
-
 # VQGAN model
 VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
 VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
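
These constants are what image_generator.py hands to the model loaders at startup; switching DALLE_MODEL to mini-1 only makes that load step lighter. A minimal sketch of how generate_image could consume them, loosely following the public dalle-mini inference example; the loading calls and sampling defaults here are assumptions, not the Space's actual implementation:

import jax
import numpy as np
from PIL import Image
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# Load the smaller mini-1 checkpoint selected by this commit, plus the VQGAN decoder.
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, _do_init=False)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

def generate_image(prompt, seed=0):
    # Tokenize the English prompt and sample an image-token sequence.
    tokenized = processor([prompt])
    encoded = model.generate(**tokenized, prng_key=jax.random.PRNGKey(seed), params=params)
    # Drop the BOS token and decode the token grid into pixels with the VQGAN.
    decoded = vqgan.decode_code(encoded.sequences[..., 1:], params=vqgan_params)
    pixels = np.asarray(decoded[0].clip(0.0, 1.0) * 255, dtype=np.uint8)
    return Image.fromarray(pixels)
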
requirements.txt CHANGED
@@ -1,10 +1,8 @@
 gradio
-librosa
 soundfile
 torch
 transformers
 sox
-sentencepiece
 dalle-mini
 Pillow
 numpy