Amir Zait committed
Commit be37091
1 Parent(s): 0d9345a

added dalle

Files changed (3)
  1. app.py +7 -3
  2. image_generator.py +53 -0
  3. requirements.txt +4 -0
app.py CHANGED

@@ -8,6 +8,8 @@ import torch
 import sox
 import os
 
+from image_generator import generate_image
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 api_token = os.getenv("API_TOKEN")
@@ -58,10 +60,12 @@ def parse_transcription(wav_file):
     logits = asr_model(input_values).logits
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
-    translated = he_en_translator(transcription)
-    return translated
+    translated = he_en_translator(transcription)[0]['translation_text']
+
+    image = generate_image(translated)
+    return image
 
-output = gr.outputs.Textbox(label="TEXT")
+output = gr.outputs.Image(label='')
 input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
 input_upload = gr.inputs.Audio(source="upload", type="file", optional=True)
 
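Note: the gr.Interface wiring itself sits further down in app.py and is not part of this diff. As a sketch of how the changed pieces presumably fit together — the exact call below is an assumption, using the legacy Gradio 2.x API already seen in this file:

import gradio as gr

# Hypothetical wiring (not shown in this commit): parse_transcription now
# returns a PIL image, so it plugs straight into the gr.outputs.Image
# component. Only the microphone input is wired here for brevity; the app
# also defines input_upload.
gr.Interface(fn=parse_transcription, inputs=input_mic, outputs=output).launch()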
image_generator.py ADDED

@@ -0,0 +1,53 @@
+import random
+
+import jax
+import numpy as np
+from PIL import Image
+
+from dalle_mini import DalleBart, DalleBartProcessor
+from vqgan_jax.modeling_flax_vqgan import VQModel
+
+# Model references
+
+# dalle-mega
+DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be a wandb artifact, a 🤗 Hub repo, a local folder, or a Google bucket
+DALLE_COMMIT_ID = None
+
+# if the Space crashes too often, you can use the smaller dalle-mini model instead by uncommenting the line below
+# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
+
+# VQGAN model, used to decode the generated image codes into pixels
+VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
+VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
+
+model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
+vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
+processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
+
+def generate_image(text):
+    tokenized_prompt = processor([text])
+
+    # sampling parameters
+    gen_top_k = None
+    gen_top_p = None
+    temperature = 0.85
+    cond_scale = 3.0
+
+    # fresh PRNG key so each call samples a different image
+    key = jax.random.PRNGKey(random.randint(0, 2**32 - 1))
+
+    encoded_images = model.generate(
+        **tokenized_prompt,
+        prng_key=key,
+        params=model.params,
+        top_k=gen_top_k,
+        top_p=gen_top_p,
+        temperature=temperature,
+        condition_scale=cond_scale,
+    )
+    # drop the BOS token, then decode the image codes with the VQGAN
+    encoded_images = encoded_images.sequences[..., 1:]
+    decoded_images = vqgan.decode_code(encoded_images, params=vqgan.params)
+    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
+    img = decoded_images[0]
+    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
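Because the module loads both DALL·E and the VQGAN at import time, the simplest smoke test is to call generate_image directly. A minimal usage sketch; the prompt and file name are illustrative:

from image_generator import generate_image

# One sampling pass through DALL·E Mega plus a VQGAN decode; expect this
# to be very slow on CPU — these models are meant for a GPU/TPU.
img = generate_image("a watercolor painting of a city at sunset")
img.save("sample.png")  # generate_image returns a PIL.Image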
requirements.txt CHANGED

@@ -5,3 +5,7 @@ torch
 transformers
 sox
 sentencepiece
+vqgan-jax
+dalle-mini
+Pillow
+numpy
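Two packaging notes: the imaging library is imported as PIL but installs from PyPI as Pillow, and dalle-mini pulls in jax and flax as dependencies — on a GPU machine the jaxlib build typically has to match the installed CUDA version.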