barakmeiri committed on
Commit
45efc4b
·
verified ·
1 Parent(s): 39ce86d

Create editor.py

Browse files
Files changed (1) hide show
  1. src/editor.py +85 -0
src/editor.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from src.config import RunConfig
3
+ import PIL
4
+ from src.euler_scheduler import MyEulerAncestralDiscreteScheduler
5
+ from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
6
+ from src.sdxl_inversion_pipeline import SDXLDDIMPipeline
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+
10
+
11
def inversion_callback(pipe, step, timestep, callback_kwargs):
    """Step-end hook handed to the inversion pipeline; changes nothing.

    Args:
        pipe: Pipeline invoking the hook (unused).
        step: Step index supplied by the caller (unused).
        timestep: Timestep supplied by the caller (unused).
        callback_kwargs: Values supplied by the caller.

    Returns:
        The very same ``callback_kwargs`` object, untouched.
    """
    unchanged = callback_kwargs
    return unchanged
13
+
14
def inference_callback(pipe, step, timestep, callback_kwargs):
    """Step-end hook handed to the inference pipeline; changes nothing.

    Args:
        pipe: Pipeline invoking the hook (unused).
        step: Step index supplied by the caller (unused).
        timestep: Timestep supplied by the caller (unused).
        callback_kwargs: Values supplied by the caller.

    Returns:
        The very same ``callback_kwargs`` object, untouched.
    """
    unchanged = callback_kwargs
    return unchanged
16
+
17
def center_crop(im):
    """Crop the largest centered square out of an image.

    Args:
        im: Image-like object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method, e.g. a
            ``PIL.Image.Image``.

    Returns:
        The result of ``im.crop`` for the centered box whose side equals
        ``min(width, height)``.
    """
    width, height = im.size
    side = min(width, height)

    # Use integer offsets and derive the far edges from the near ones.
    # Dividing by 2 with ``/`` yields ``.5`` coordinates whenever the
    # width/height difference is odd; the imaging library then rounds each
    # edge on its own (Pillow rounds half to even), so e.g. a 6x3 image could
    # come out 2 pixels wide instead of 3. This form is always exactly square.
    left = (width - side) // 2
    top = (height - side) // 2

    return im.crop((left, top, left + side, top + side))
28
+
29
+
30
def load_im_into_format_from_path(im_path):
    """Load an image, center-crop it to a square and resize it to 512x512.

    Args:
        im_path: Value passed straight to ``PIL.Image.open``.

    Returns:
        The 512x512 image produced by ``center_crop(...).resize((512, 512))``.
    """
    # Import the submodule explicitly: a bare ``import PIL`` only exposes
    # ``PIL.Image`` if some other module happened to import it first.
    from PIL import Image

    # The context manager releases the underlying file once the cropped and
    # resized copy has been produced.
    with Image.open(im_path) as source:
        return center_crop(source).resize((512, 512))
32
+
33
+
34
class ImageEditorDemo:
    """Invert one input image once, then regenerate it under new prompts.

    Construction loads the image, installs one shared noise list on the
    schedulers of both pipelines, moves the pipelines to CUDA and runs the
    inversion. ``edit`` then runs the inference pipeline starting from the
    stored inversion result.
    """

    def __init__(self, pipe_inversion, pipe_inference, input_image, description_prompt, cfg):
        """Prepare both pipelines and invert ``input_image``.

        Args:
            pipe_inversion: Inversion pipeline; must expose ``scheduler`` and
                ``scheduler_inference`` with ``set_noise_list``.
            pipe_inference: Inference pipeline; must expose ``scheduler`` with
                ``set_noise_list``.
            input_image: Passed to ``load_im_into_format_from_path``.
            description_prompt: Prompt used for the inversion call.
            cfg: Run configuration; stored on ``self`` and on both pipelines.
        """
        self.pipe_inversion = pipe_inversion
        self.pipe_inference = pipe_inference
        self.original_image = load_im_into_format_from_path(input_image).convert("RGB")
        self.load_image = True

        # One noise tensor per inversion step, drawn from a fixed-seed CPU
        # generator. Shape is (1, 4, 64, 64): the 512x512 image size divided
        # by a scale factor of 8.
        seeded_generator = torch.Generator().manual_seed(7865)
        image_height, image_width = 512, 512
        latent_downscale = 8
        latent_shape = (1, 4, image_height // latent_downscale, image_width // latent_downscale)
        noise = []
        for _ in range(cfg.num_inversion_steps):
            noise.append(
                randn_tensor(
                    latent_shape,
                    dtype=torch.float16,
                    device=torch.device("cuda:0"),
                    generator=seeded_generator,
                )
            )

        # The same list object is handed to all three schedulers.
        pipe_inversion.scheduler.set_noise_list(noise)
        pipe_inference.scheduler.set_noise_list(noise)
        pipe_inversion.scheduler_inference.set_noise_list(noise)

        pipe_inversion.set_progress_bar_config(disable=True)
        pipe_inference.set_progress_bar_config(disable=True)

        self.cfg = cfg
        self.pipe_inversion.cfg = cfg
        self.pipe_inference.cfg = cfg

        # Forwarded as ``inv_hp`` to the inversion pipeline; the meaning of
        # the three values is defined there, not in this file.
        self.inv_hp = [2, 0.1, 0.2]
        # Guidance scale used by ``edit``.
        self.edit_cfg = 1.2

        self.pipe_inference.to("cuda")
        self.pipe_inversion.to("cuda")

        inverted = self.invert(self.original_image, description_prompt)
        self.last_latent = inverted
        self.original_latent = inverted

    def invert(self, init_image, base_prompt):
        """Run the inversion pipeline on ``init_image`` with ``base_prompt``.

        Returns:
            Element ``[0][0]`` of the inversion pipeline's output.
        """
        settings = self.cfg
        output = self.pipe_inversion(
            prompt=base_prompt,
            num_inversion_steps=settings.num_inversion_steps,
            num_inference_steps=settings.num_inference_steps,
            image=init_image,
            guidance_scale=settings.guidance_scale,
            callback_on_step_end=inversion_callback,
            strength=settings.inversion_max_step,
            denoising_start=1.0 - settings.inversion_max_step,
            inv_hp=self.inv_hp,
        )
        return output[0][0]

    def edit(self, target_prompt):
        """Run the inference pipeline from ``self.last_latent`` with ``target_prompt``.

        Returns:
            The first entry of the inference pipeline output's ``images``.
        """
        settings = self.cfg
        output = self.pipe_inference(
            prompt=target_prompt,
            num_inference_steps=settings.num_inference_steps,
            negative_prompt="",
            callback_on_step_end=inference_callback,
            image=self.last_latent,
            strength=settings.inversion_max_step,
            denoising_start=1.0 - settings.inversion_max_step,
            guidance_scale=self.edit_cfg,
        )
        return output.images[0]
85
+