File size: 12,682 Bytes
b462bee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23a93a3
b462bee
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import os, sys, time, re, pdb
import torch, torchvision
import numpy
from PIL import Image
import hashlib
from tqdm import tqdm
import openai
from utils.direction_utils import *

p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)
from diffusers import DDIMScheduler
from edit_directions import construct_direction
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM



def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode each sentence and return the mean text embedding.

    Each sentence is tokenized on its own, padded/truncated to
    ``tokenizer.model_max_length``, and passed through ``text_encoder``. The
    first element of the encoder output is kept. The per-sentence results are
    concatenated along dim 0 and averaged over that dim, keeping a leading
    dimension of size 1. For a CLIP text encoder this is presumably
    (1, seq_len, hidden) — confirm against the encoder in use.

    Args:
        l_sentences: non-empty sequence of sentences (str).
        tokenizer: HF-style tokenizer exposing ``model_max_length``.
        text_encoder: model whose output's first element is the embedding.
        device: device the token ids are moved to before encoding.

    Returns:
        torch.Tensor: mean of the per-sentence embeddings, with keepdim on dim 0.

    Raises:
        ValueError: if ``l_sentences`` is empty (averaging nothing is undefined).
    """
    if not l_sentences:
        raise ValueError("load_sentence_embeddings needs at least one sentence")

    l_embeddings = []
    with torch.no_grad():
        for sent in tqdm(l_sentences):
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            prompt_embeds = text_encoder(text_inputs.input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    # torch.cat is available on every torch version (torch.concatenate is a
    # newer alias); keepdim=True is equivalent to mean(...).unsqueeze(0).
    return torch.cat(l_embeddings, dim=0).mean(dim=0, keepdim=True)



def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    """Sample an image with Stable Diffusion v1.4 and persist its starting noise.

    The initial latent noise is saved under ``tmp/``, named by its SHA-256
    digest. This lets a later edit reload the exact same starting point.

    Args:
        prompt: text prompt for generation.
        seed: converted with ``int()`` and used to seed the CUDA RNG.
        negative_scale: forwarded to the pipeline as ``guidance_scale``.
        num_ddim: number of DDIM inference steps.

    Returns:
        tuple: (first element of the pipeline output, path of the saved noise tensor).
    """
    os.makedirs("tmp", exist_ok=True)

    # Load the editing pipeline and switch it to a DDIM scheduler.
    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Seed the CUDA generator, then draw the initial latent noise.
    torch.cuda.manual_seed(int(seed))
    noise = torch.randn((1, 4, 64, 64), device="cuda")

    # Content-addressed filename for the noise tensor.
    digest = hashlib.sha256(noise.cpu().numpy().tobytes()).hexdigest()
    noise_path = f"tmp/{digest}_ddim_{num_ddim}_inv.pt"
    torch.save(noise, noise_path)

    outputs = pipe(
        prompt,
        num_inference_steps=num_ddim,
        x_in=noise,
        only_sample=True,  # produce only the sampled image, no edited image
        guidance_scale=negative_scale,
        negative_prompt="",  # empty negative prompt
    )

    # Release the pipeline's GPU memory before returning.
    del pipe
    torch.cuda.empty_cache()

    return outputs[0], noise_path



def clean_l_sentences(ls):
    """Strip list-numbering artifacts from generated captions.

    For each string, this removes all digits (``\\d``, i.e. any Unicode
    decimal digit), deletes every '.', '-' and ')' character, and strips
    surrounding whitespace.

    Args:
        ls: iterable of str.

    Returns:
        list of str, in the same order.
    """
    # Raw string: '\d' in a normal literal is an invalid escape sequence
    # (DeprecationWarning / SyntaxWarning on newer Pythons).
    digit_re = re.compile(r"\d")
    # All three replacements were deletions, so one translate pass is equivalent.
    drop_punct = str.maketrans("", "", ".-)")
    return [digit_re.sub("", x).translate(drop_punct).strip() for x in ls]



def gpt3_compute_word2sentences(task_type, word, num=100):
    """Collect at least ``num`` GPT-3 captions related to ``word``.

    The function repeatedly queries ``text-davinci-002`` through the legacy
    ``openai.Completion`` API. Each response line longer than 10 characters
    is kept. A line missing any space-separated part of ``word`` gets
    ``", {word}"`` appended, so every caption mentions the concept.

    Args:
        task_type: "object" (captions containing ``word``) or "style"
            (captions in the ``word`` style).
        word: the concept to caption.
        num: minimum number of captions to collect.

    Returns:
        list of str cleaned by ``clean_l_sentences``; may contain more than ``num``.

    Raises:
        ValueError: if ``task_type`` is not "object" or "style".
    """
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unsupported task_type {task_type!r}; expected 'object' or 'style'")

    subwords = word.split(" ")
    l_sentences = []
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0)
        for line in ret.choices[0].text.split("\n"):
            line = line.strip()
            if len(line) <= 10:
                continue  # drop blank or fragmentary lines
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        time.sleep(0.05)  # small pause between API calls
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    return clean_l_sentences(l_sentences)


def flant5xl_compute_word2sentences(word, num=100):
    """Sample at least ``num`` captions mentioning ``word`` from FLAN-T5-XL.

    The model is loaded in float16 with ``device_map="auto"``. It is sampled
    16 sequences at a time until ``num`` captions are collected, then released
    and the CUDA cache is emptied. A caption missing any space-separated part
    of ``word`` gets ``", {word}"`` appended.

    Args:
        word: the concept to caption (e.g. "cat wearing sunglasses").
        num: minimum number of captions to collect.

    Returns:
        list of str cleaned by ``clean_l_sentences``; may contain more than ``num``.
    """
    text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
    subwords = word.split(" ")

    l_sentences = []
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to("cuda")
    while True:
        outputs = model.generate(input_ids, temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
        # T5 is an encoder-decoder model: ``generate`` returns decoder tokens
        # only, without the prompt. Slicing off ``input_ids.shape[1]`` tokens
        # (as is correct for decoder-only BLOOM) would cut the start of every
        # caption, so the whole sequence is decoded.
        for line in tokenizer.batch_decode(outputs, skip_special_tokens=True):
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)

    # Free GPU memory before returning.
    del model
    del tokenizer
    torch.cuda.empty_cache()

    return l_sentences

def bloomz_compute_sentences(word, num=100, max_consecutive_failures=20):
    """Sample at least ``num`` captions mentioning ``word`` from BLOOMZ-7B1.

    The model is loaded in float16 with ``device_map="auto"``. It is sampled
    16 sequences at a time (temperature 0.95, eta_cutoff 1e-5, min_length 15)
    until ``num`` captions are collected, then released. BLOOM is decoder-only,
    so the echoed prompt tokens are sliced off before decoding. A caption
    missing any space-separated part of ``word`` gets ``", {word}"`` appended.

    Args:
        word: the concept to caption.
        num: minimum number of captions to collect.
        max_consecutive_failures: how many generation RuntimeErrors in a row
            are retried before the last one is re-raised. Without this limit,
            a persistent failure would loop forever.

    Returns:
        list of str cleaned by ``clean_l_sentences``; may contain more than ``num``.

    Raises:
        RuntimeError: if generation fails ``max_consecutive_failures`` times in a row.
    """
    l_sentences = []
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
    input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    min_length = 15
    subwords = word.split(" ")

    failures = 0
    while True:
        try:
            outputs = model.generate(input_ids, temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
            output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
        except RuntimeError:
            # Presumably guards against sporadic sampling failures (e.g.
            # NaN/inf probabilities in float16). Retry, but only a bounded
            # number of times; a bare except used to also swallow
            # KeyboardInterrupt and permanent errors, looping forever.
            failures += 1
            if failures >= max_consecutive_failures:
                raise
            continue
        failures = 0
        for line in output:
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)

    # Free GPU memory before returning.
    del model
    del tokenizer
    torch.cuda.empty_cache()

    return l_sentences



def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    """Build a mean text embedding for a user-defined concept.

    Sentences describing ``description`` come from one of these sources,
    chosen by ``sent_type``:

    - "fixed-template": ``generate_image_prompts_with_templates``.
    - contains "GPT3": OpenAI ``text-davinci-002`` (1000 captions).
    - contains "flan-t5-xl": FLAN-T5-XL (1000 captions, also written to
      ``tmp/flant5xl_sentences_{description}.txt``).
    - contains "BLOOMZ-7B": BLOOMZ-7B1 (1000 captions, also written to
      ``tmp/bloomz_sentences_{description}.txt``).
    - "custom sentences": the non-blank lines of ``l_custom_sentences``.

    The sentences are encoded with Stable Diffusion v1.4's tokenizer and text
    encoder and averaged with ``load_sentence_embeddings``.

    Args:
        description: the concept text.
        sent_type: sentence source selector (see above).
        api_key, org_key: OpenAI credentials, used only for "GPT3".
        l_custom_sentences: newline-separated sentences, used only for
            "custom sentences".

    Returns:
        torch.Tensor: the averaged embedding.

    Raises:
        ValueError: if ``sent_type`` is unrecognized or the custom sentences
            contain no non-blank line.
    """
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        openai.organization = org_key
        openai.api_key = api_key
        # The result is discarded; presumably this is an early
        # credentials/model-access check before the long sampling loop.
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        # Save the sentences to file (tmp/ may not exist if called directly).
        os.makedirs("tmp", exist_ok=True)
        with open(f"tmp/flant5xl_sentences_{description}.txt", "w") as f:
            f.writelines(line + "\n" for line in l_sentences)
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        # Save the sentences to file (tmp/ may not exist if called directly).
        os.makedirs("tmp", exist_ok=True)
        with open(f"tmp/bloomz_sentences_{description}.txt", "w") as f:
            f.writelines(line + "\n" for line in l_sentences)
    elif sent_type == "custom sentences":
        # splitlines() also handles "\r\n" from text boxes. Blank lines are
        # dropped so empty strings don't get averaged into the embedding.
        l_sentences = [s for s in l_custom_sentences.splitlines() if s.strip()]
        if not l_sentences:
            raise ValueError("custom sentences must contain at least one non-blank line")
        print(f"length of new sentence is {len(l_sentences)}")
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unknown sentence type: {sent_type!r}")

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
    emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device="cuda")
    del pipe
    torch.cuda.empty_cache()
    return emb


def _resolve_direction_emb(choice, custom_desc, sent_type, api_key, org_key, custom_sentences, d_desc2name):
    """Return the embedding for one end (source or destination) of the edit direction.

    A predefined direction is loaded with ``get_emb``. For "make your own!",
    the embedding is built with ``make_custom_dir`` and cached as a ``.pt``
    file under ``tmp/``.
    """
    if choice != "make your own!":
        return get_emb(d_desc2name[choice])

    outf_name = f"tmp/template_emb_{custom_desc}_{sent_type}.pt"
    if sent_type == "custom sentences":
        # The description alone does not identify user-typed sentences. Key the
        # cache on their content too, so edited sentences are not answered
        # with a stale embedding.
        digest = hashlib.sha256((custom_sentences or "").encode("utf-8")).hexdigest()[:16]
        outf_name = f"tmp/template_emb_{custom_desc}_{sent_type}_{digest}.pt"

    if os.path.exists(outf_name):
        return torch.load(outf_name)
    emb = make_custom_dir(custom_desc, sent_type, api_key, org_key, custom_sentences)
    torch.save(emb, outf_name)
    return emb


def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
    """Edit a real or a synthetic image along the direction ``dest - src``.

    Args:
        img_in_real: PIL image to edit (real-image path), or None.
        img_in_synth: non-None selects the synthetic-image path. Its value is
            only checked for None here.
        src, dest: direction descriptions, or "make your own!" for a custom one.
        src_custom, dest_custom: concept text for custom directions.
        num_ddim: number of DDIM steps for inversion and editing.
        xa_guidance: forwarded to the pipeline as ``guidance_amount``.
        edit_mul: scale applied to ``dest_emb - src_emb``.
        fpath_z_gen: path of the saved latent (synthetic path).
        gen_prompt: prompt for the synthetic path.
        sent_type_src, sent_type_dest: sentence sources for custom directions
            (see ``make_custom_dir``).
        api_key, org_key: OpenAI credentials for the "GPT3" sentence source.
        custom_sentences_src, custom_sentences_dest: newline-separated
            sentences for the "custom sentences" source.

    Returns:
        The first edited image returned by the editing pipeline.

    Raises:
        ValueError: unless exactly one of ``img_in_real`` / ``img_in_synth`` is given.
    """
    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)

    # Build (or load from cache) both ends of the edit direction first.
    src_emb = _resolve_direction_emb(src, src_custom, sent_type_src, api_key, org_key, custom_sentences_src, d_desc2name)
    dest_emb = _resolve_direction_emb(dest, dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest, d_desc2name)
    text_dir = (dest_emb.cuda() - src_emb.cuda()) * edit_mul

    if img_in_real is not None and img_in_synth is None:
        print("using real image")
        # Resize so the longer side is 512, preserving the aspect ratio.
        width, height = img_in_real.size
        scale_factor = 512 / max(width, height)
        new_size = (int(width * scale_factor), int(height * scale_factor))
        img_in_real = img_in_real.resize(new_size, Image.Resampling.LANCZOS)
        img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
        inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
        caption_fname = f"tmp/{img_hash}_caption.txt"

        # Caption the image with BLIP unless a cached caption exists.
        if not os.path.exists(caption_fname):
            model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
            _image = vis_processors["eval"](img_in_real).unsqueeze(0).cuda()
            prompt_str = model_blip.generate({"image": _image})[0]
            del model_blip
            torch.cuda.empty_cache()
            with open(caption_fname, "w") as f:
                f.write(prompt_str)
        else:
            with open(caption_fname, "r") as f:
                prompt_str = f.read().strip()
        print(f"CAPTION: {prompt_str}")

        # DDIM-invert the image unless a cached inversion exists.
        if not os.path.exists(inv_fname):
            pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
            pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
            x_inv, _, _ = pipe_inv(prompt_str,
                    guidance_scale=1, num_inversion_steps=num_ddim,
                    img=img_in_real, torch_dtype=torch.float32)
            x_inv = x_inv.detach()
            torch.save(x_inv, inv_fname)
            del pipe_inv
            torch.cuda.empty_cache()
        else:
            x_inv = torch.load(inv_fname)

        # Do the editing.
        edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
        edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
        _, edit_pil = edit_pipe(prompt_str,
                num_inference_steps=num_ddim,
                x_in=x_inv,
                edit_dir=text_dir,
                guidance_amount=xa_guidance,
                guidance_scale=5.0,
                negative_prompt=prompt_str,  # the unedited prompt is the negative prompt
        )
        del edit_pipe
        torch.cuda.empty_cache()
        return edit_pil[0]

    elif img_in_real is None and img_in_synth is not None:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen)
        pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        _, edit_pil = pipe(gen_prompt,
            num_inference_steps=num_ddim,
            x_in=x_inv,
            edit_dir=text_dir,
            guidance_amount=xa_guidance,
            guidance_scale=5,
            negative_prompt="",  # empty negative prompt
        )
        del pipe
        torch.cuda.empty_cache()
        return edit_pil[0]

    else:
        raise ValueError(f"Invalid image type found: {img_in_real} {img_in_synth}")



if __name__ == "__main__":
    # Manual smoke test: sample FLAN-T5-XL captions for a sample concept and print them.
    demo_captions = flant5xl_compute_word2sentences("cat wearing sunglasses", num=100)
    print(demo_captions)