File size: 3,691 Bytes
f7c2e78
 
6c6d0a0
5a87575
 
 
 
 
f7c2e78
5a87575
 
 
 
be37091
f7c2e78
 
f7895ed
 
 
 
f7c2e78
077c45d
f7c2e78
5a87575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23dd537
6c6d0a0
7c6a43f
 
 
 
 
5a87575
 
2d31a01
5a87575
 
 
 
f7c2e78
 
 
 
 
 
 
 
e8b13db
f7c2e78
 
 
e8b13db
 
d8ec8f4
 
f7c2e78
d8ec8f4
be37091
23dd537
 
e8b13db
 
 
23dd537
 
e8b13db
be37091
793ea82
f7c2e78
793ea82
f7c2e78
 
8ee95b0
f7c2e78
 
 
 
d8ec8f4
f7c2e78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import random

import gradio as gr
import jax
import numpy as np
import soundfile as sf
import sox
import torch
from PIL import Image

from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# Pick GPU when available for the PyTorch models.
# NOTE(review): `device` is never used below — the ASR model is never moved
# with .to(device), so inference runs on CPU regardless; confirm intent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Smaller 300M checkpoint kept for reference; the 1B model is loaded below.
#asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
#asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
# Hebrew speech recognition (wav2vec2 XLS-R, CTC head).
asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-1b-hebrew")
asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-1b-hebrew")

# Hebrew -> English translation pipeline (image model expects English prompts).
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")

# Model references
# dalle-mini, mega too large
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model (decodes DALL-E token sequences into pixels); commit pinned
# for reproducibility.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

def generate_image(text):
    """Generate a 256x256 PIL image from an English text prompt.

    Uses the module-level DALL-E mini `model`/`processor` to sample image
    tokens and the `vqgan` model to decode them into pixels.

    Args:
        text: English prompt string.

    Returns:
        PIL.Image.Image: the first generated 256x256 RGB image.
    """
    tokenized_prompt = processor([text])

    # Sampling hyper-parameters (None disables top-k / top-p filtering).
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    encoded_images = model.generate(
        **tokenized_prompt,
        # BUG FIX: random.randint requires integer bounds; passing the float
        # 1e7 was deprecated in Python 3.10 and raises on 3.12+.
        prng_key=jax.random.PRNGKey(random.randint(0, 10**7)),
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
        )
    # Drop the leading BOS token, then decode the VQGAN codes into pixels.
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = vqgan.decode_code(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    img = decoded_images[0]
    # Scale [0, 1] floats to [0, 255] uint8 for PIL.
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))

def convert(inputfile, outfile):
    """Resample *inputfile* into a 16 kHz, mono, 16-bit signed WAV at *outfile*."""
    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="wav",
        channels=1,
        encoding="signed-integer",
        rate=16000,
        bits=16,
    )
    transformer.build(inputfile, outfile)

def parse_transcription(wav_file):
    """Transcribe a Hebrew recording, translate it, and draw the result.

    Args:
        wav_file: file-like object from the Gradio microphone input; only
            its ``.name`` (a filesystem path) is used.

    Returns:
        tuple: (hebrew_transcription, english_translation, PIL.Image).
    """
    # Resample to the 16 kHz mono WAV the ASR model expects.
    # BUG FIX: splitext keeps the path intact even when a directory name
    # contains a dot (the old split('.')[0] truncated at the first dot).
    base, _ = os.path.splitext(wav_file.name)
    converted = base + "16k.wav"
    convert(wav_file.name, converted)
    speech, _ = sf.read(converted)

    # Transcribe to Hebrew. Pure inference: no_grad avoids building an
    # autograd graph and wasting memory.
    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)

    print(transcription)

    # Translate Hebrew -> English (the image model only understands English).
    translated = he_en_translator(transcription)[0]['translation_text']

    print(translated) 

    # Generate the image from the English prompt.
    image = generate_image(translated)
    return transcription, translated, image 

# Wire the pipeline into a Gradio UI: microphone in; transcript, translated
# prompt, and generated image out.
outputs = [
    gr.outputs.Textbox(label="transcript"),
    # BUG FIX: label typo "translated prompet" -> "translated prompt".
    gr.outputs.Textbox(label="translated prompt"),
    gr.outputs.Image(label=''),
]
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)

gr.Interface(parse_transcription, inputs=[input_mic], outputs=outputs,
             analytics_enabled=False,
             show_tips=False,
             theme='huggingface',
             layout='horizontal',
             title="Draw Me A Sheep in Hebrew",
             enable_queue=True).launch(inline=False)