ysharma HF staff committed on
Commit
3b8e6ee
1 Parent(s): 2cc96ed
Files changed (1)
  1. app.py +193 -4
app.py CHANGED
@@ -1,3 +1,191 @@
+ ## **** below codelines are borrowed from multimodalart space
+ from pydoc import describe
+ import gradio as gr
+ import torch
+ from omegaconf import OmegaConf
+ import sys
+ sys.path.append(".")
+ sys.path.append('./taming-transformers')
+ sys.path.append('./latent-diffusion')
+ from taming.models import vqgan
+ from ldm.util import instantiate_from_config
+ from huggingface_hub import hf_hub_download
+
+ model_path_e = hf_hub_download(repo_id="multimodalart/compvis-latent-diffusion-text2img-large", filename="txt2img-f8-large.ckpt")
+
+ #@title Import stuff
+ import argparse, os, sys, glob
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ import transformers
+ import gc
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+ from open_clip import tokenizer
+ import open_clip
+
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cuda")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     model = model.half().cuda()
+     model.eval()
+     return model
+
+ def load_safety_model(clip_model):
+     """load the safety model"""
+     import autokeras as ak  # pylint: disable=import-outside-toplevel
+     from tensorflow.keras.models import load_model  # pylint: disable=import-outside-toplevel
+     from os.path import expanduser  # pylint: disable=import-outside-toplevel
+
+     home = expanduser("~")
+
+     cache_folder = home + "/.cache/clip_retrieval/" + clip_model.replace("/", "_")
+     if clip_model == "ViT-L/14":
+         model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
+         dim = 768
+     elif clip_model == "ViT-B/32":
+         model_dir = cache_folder + "/clip_autokeras_nsfw_b32"
+         dim = 512
+     else:
+         raise ValueError("Unknown clip model")
+     if not os.path.exists(model_dir):
+         os.makedirs(cache_folder, exist_ok=True)
+
+         from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel
+
+         path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
+         if clip_model == "ViT-L/14":
+             url_model = "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
+         elif clip_model == "ViT-B/32":
+             url_model = (
+                 "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip"
+             )
+         else:
+             raise ValueError("Unknown model {}".format(clip_model))
+         urlretrieve(url_model, path_to_zip_file)
+         import zipfile  # pylint: disable=import-outside-toplevel
+
+         with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
+             zip_ref.extractall(cache_folder)
+
+     loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
+     loaded_model.predict(np.random.rand(10 ** 3, dim).astype("float32"), batch_size=10 ** 3)
+
+     return loaded_model
+
+ def is_unsafe(safety_model, embeddings, threshold=0.5):
+     """find unsafe embeddings"""
+     nsfw_values = safety_model.predict(embeddings, batch_size=embeddings.shape[0])
+     x = np.array([e[0] for e in nsfw_values])
+     return True if x > threshold else False
+
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
+ model = load_model_from_config(config,model_path_e)
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ #NSFW CLIP Filter
+ safety_model = load_safety_model("ViT-B/32")
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
+
+
+ def run(prompt, steps, width, height, images, scale):
+     opt = argparse.Namespace(
+         prompt = prompt,
+         outdir='latent-diffusion/outputs',
+         ddim_steps = int(steps),
+         ddim_eta = 0,
+         n_iter = 1,
+         W=int(width),
+         H=int(height),
+         n_samples=int(images),
+         scale=scale,
+         plms=True
+     )
+
+     if opt.plms:
+         opt.ddim_eta = 0
+         sampler = PLMSSampler(model)
+     else:
+         sampler = DDIMSampler(model)
+
+     os.makedirs(opt.outdir, exist_ok=True)
+     outpath = opt.outdir
+
+     prompt = opt.prompt
+
+
+     sample_path = os.path.join(outpath, "samples")
+     os.makedirs(sample_path, exist_ok=True)
+     base_count = len(os.listdir(sample_path))
+
+     all_samples=list()
+     all_samples_images=list()
+     with torch.no_grad():
+         with torch.cuda.amp.autocast():
+             with model.ema_scope():
+                 uc = None
+                 if opt.scale > 0:
+                     uc = model.get_learned_conditioning(opt.n_samples * [""])
+                 for n in range(opt.n_iter):
+                     c = model.get_learned_conditioning(opt.n_samples * [prompt])
+                     shape = [4, opt.H//8, opt.W//8]
+                     samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                      conditioning=c,
+                                                      batch_size=opt.n_samples,
+                                                      shape=shape,
+                                                      verbose=False,
+                                                      unconditional_guidance_scale=opt.scale,
+                                                      unconditional_conditioning=uc,
+                                                      eta=opt.ddim_eta)
+
+                     x_samples_ddim = model.decode_first_stage(samples_ddim)
+                     x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+
+                     for x_sample in x_samples_ddim:
+                         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                         image_vector = Image.fromarray(x_sample.astype(np.uint8))
+                         image_preprocess = preprocess(image_vector).unsqueeze(0)
+                         with torch.no_grad():
+                             image_features = clip_model.encode_image(image_preprocess)
+                         image_features /= image_features.norm(dim=-1, keepdim=True)
+                         query = image_features.cpu().detach().numpy().astype("float32")
+                         unsafe = is_unsafe(safety_model,query,0.5)
+                         if(not unsafe):
+                             all_samples_images.append(image_vector)
+                         else:
+                             return(None,None,"Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
+                         #Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
+                         base_count += 1
+                     all_samples.append(x_samples_ddim)
+
+
+     # additionally, save as grid
+     grid = torch.stack(all_samples, 0)
+     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+     grid = make_grid(grid, nrow=2)
+     # to image
+     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+
+     Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
+     #return(Image.fromarray(grid.astype(np.uint8)),all_samples_images,None)
+     return Image.fromarray(grid.astype(np.uint8))
+
+ ## **** above codelines are borrowed from multimodalart space
+
  import gradio as gr

  fastspeech = gr.Interface.load("huggingface/facebook/fastspeech2-en-ljspeech")
@@ -12,13 +200,14 @@ def engine(text_input):
      entities = [tupl for tupl in entities if None not in tupl]
      entities_num = len(entities)

+     img = run(entities[0],'50','256','256','1','10')
      #img_intfc = gr.Interface.load("spaces/multimodalart/latentdiffusion")
-     img_intfc = gr.Interface.load("spaces/multimodalart/latentdiffusion", inputs=[gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text")],
-     outputs=[gr.outputs.Image(type="pil", label="output image"),gr.outputs.Carousel(label="Individual images",components=["image"]),gr.outputs.Textbox(label="Error")], )
+     #img_intfc = gr.Interface.load("spaces/multimodalart/latentdiffusion", inputs=[gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text"), gr.inputs.Textbox(lines=1, label="Input Text")],
+     #outputs=[gr.outputs.Image(type="pil", label="output image"),gr.outputs.Carousel(label="Individual images",components=["image"]),gr.outputs.Textbox(label="Error")], )
      #title="Convert text to image")
      #img = img_intfc[0]
-     img = img_intfc(('George','50','256','256','1','10'))
-     img = img[0]
+     #img = img_intfc('George','50','256','256','1','10')
+     #img = img[0]
      #inputs=['George',50,256,256,1,10]
      #run(prompt, steps, width, height, images, scale)

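
Note: the snippet below is not part of the commit; it is a minimal usage sketch of the run() helper that this commit wires into engine(), following the call signature run(prompt, steps, width, height, images, scale) shown above. The prompt string and output filename are hypothetical, scale is passed as a number because run() compares it against 0, and the model, CLIP encoder, and safety model defined in the added block are assumed to have loaded successfully.

# Hypothetical standalone call to run() as defined in the added block above.
prompt = "a watercolor painting of a lighthouse"     # hypothetical example prompt
result = run(prompt, '50', '256', '256', '1', 10.0)  # steps, width, height, images, guidance scale
if isinstance(result, tuple):
    print(result[2])                     # run() returns (None, None, error_message) when the NSFW filter triggers
else:
    result.save("example_grid.png")      # otherwise run() returns a PIL image grid of the samples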