Amir Zait committed
Commit 5a87575 • 1 Parent(s): 8439859

unify files

Files changed (2)
  1. app.py +46 -6
  2. image_generator.py +0 -44
app.py CHANGED
@@ -1,13 +1,16 @@
-from transformers import AutoProcessor, AutoModelForCTC
-from transformers import pipeline
-
 import soundfile as sf
 import gradio as gr
-import torch
-import sox
+import numpy as np
 import os
+from PIL import Image
+import random
+import sox
+import torch
 
-from image_generator import generate_image
+from transformers import AutoProcessor, AutoModelForCTC
+from transformers import pipeline
+from dalle_mini import DalleBart, DalleBartProcessor
+from vqgan_jax.modeling_flax_vqgan import VQModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -17,6 +20,43 @@ asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebr
 
 he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")
 
+# Model references
+# dalle-mini, mega too large
+# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
+DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
+DALLE_COMMIT_ID = None
+
+# VQGAN model
+VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
+VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
+
+model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
+vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
+processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
+
+def generate_image(text):
+    tokenized_prompt = processor([text])
+
+    gen_top_k = None
+    gen_top_p = None
+    temperature = 0.85
+    cond_scale = 3.0
+
+    encoded_images = model.generate(
+        tokenized_prompt,
+        random.randint(0, 1e7),
+        model.params,
+        gen_top_k,
+        gen_top_p,
+        temperature,
+        cond_scale,
+    )
+    encoded_images = encoded_images.sequences[..., 1:]
+    decoded_images = model.decode(encoded_images, vqgan.params)
+    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
+    img = decoded_images[0]
+    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
+
 def convert(inputfile, outfile):
     sox_tfm = sox.Transformer()
     sox_tfm.set_output_format(
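
For context, here is how the now-inlined generate_image would typically be invoked from the rest of app.py. This is a minimal sketch, not part of the commit: the Gradio callback itself is outside this diff, and transcribe is a hypothetical name for the existing wav2vec2 ASR step.

    # Hypothetical end-to-end handler: Hebrew speech -> Hebrew text ->
    # English text -> DALL-E mini image (all models loaded above in app.py).
    def speech_to_image(audio_path):
        hebrew_text = transcribe(audio_path)  # assumed wav2vec2 CTC helper
        english_text = he_en_translator(hebrew_text)[0]["translation_text"]
        return generate_image(english_text)   # returns a PIL.Image
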
image_generator.py DELETED
@@ -1,44 +0,0 @@
-import numpy as np
-import os
-from PIL import Image
-import random
-
-from dalle_mini import DalleBart, DalleBartProcessor
-from vqgan_jax.modeling_flax_vqgan import VQModel
-
-# Model references
-# dalle-mini, mega too large
-# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
-DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
-DALLE_COMMIT_ID = None
-
-# VQGAN model
-VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
-VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
-
-model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
-vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
-processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
-
-def generate_image(text):
-    tokenized_prompt = processor([text])
-
-    gen_top_k = None
-    gen_top_p = None
-    temperature = 0.85
-    cond_scale = 3.0
-
-    encoded_images = model.generate(
-        tokenized_prompt,
-        random.randint(0, 1e7),
-        model.params,
-        gen_top_k,
-        gen_top_p,
-        temperature,
-        cond_scale,
-    )
-    encoded_images = encoded_images.sequences[..., 1:]
-    decoded_images = model.decode(encoded_images, vqgan.params)
-    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
-    img = decoded_images[0]
-    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
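
One note on the seeding in this function: model.generate receives a plain Python int positionally, whereas dalle-mini's reference inference notebook derives an explicit JAX PRNG key and passes the generation arguments by keyword. A sketch in that style, assuming the same model, processor, and tokenized_prompt objects are in scope:

    import random
    import jax

    # Notebook-style call: explicit PRNG key, keyword arguments.
    seed = random.randint(0, 2**32 - 1)
    encoded_images = model.generate(
        **tokenized_prompt,                 # input_ids / attention_mask from the processor
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=None,
        top_p=None,
        temperature=0.85,
        condition_scale=3.0,
    )
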