shariqfarooq commited on
Commit
13f1a87
·
1 Parent(s): 5a85f92
Files changed (3) hide show
  1. app.py +125 -0
  2. cross_frame_attention.py +120 -0
  3. loosecontrol.py +135 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dataclasses import dataclass
3
+ import PIL
4
+ import PIL.Image
5
+
6
+ import torch
7
+ import numpy as np
8
+ from gradio_editor3d import Editor3D as g3deditor
9
+ import copy
10
+ from loosecontrol import LooseControlNet
11
+
12
+
13
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
14
+ cn = LooseControlNet()
15
+ cn.pipe = cn.pipe.to(torch_device=device, torch_dtype=torch.float16)
16
+
17
+ # Need to figure out a better way how to do this per user, making 'cf attention' act like a state per user.
18
+ # For now, we just copy the model.
19
+ cn_with_cf = copy.deepcopy(cn)
20
+ cn_with_cf.set_cf_attention()
21
+
22
+
23
+ @dataclass
24
+ class FixedInputs:
25
+ prompt: str
26
+ seed: int
27
+ depth: PIL.Image.Image
28
+
29
+
30
+ negative_prompt = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
31
+ def depth2image(prompt, seed, depth):
32
+ seed = int(seed)
33
+ gen = cn(prompt, control_image=depth, controlnet_conditioning_scale=1.0, generator=torch.Generator().manual_seed(seed), num_inference_steps=20, negative_prompt=negative_prompt)
34
+ return gen
35
+
36
+ def edit_previous(prompt, seed, depth, fixed_inputs):
37
+ seed = int(seed)
38
+ control_image = [fixed_inputs.depth, depth]
39
+ prompt = [fixed_inputs.prompt, prompt]
40
+ neg_prompt = [negative_prompt, negative_prompt]
41
+ generator = [torch.Generator().manual_seed(fixed_inputs.seed), torch.Generator().manual_seed(seed)]
42
+ gen = cn_with_cf(prompt, control_image=control_image, controlnet_conditioning_scale=1.0, generator=generator, num_inference_steps=20, negative_prompt=neg_prompt)[-1]
43
+ return gen
44
+
45
+ def run(prompt, seed, depth, should_edit, fixed_inputs):
46
+ depth = depth.convert("RGB")
47
+ # all values below [3,3,3] in depth should actually be set to [255,255,255]
48
+ # This is to due the nature of training data and is experimental right now.
49
+ # Not in use for now.
50
+ # depth = np.array(depth)
51
+ # depth[depth < 3] = 255
52
+ # depth = PIL.Image.fromarray(depth)
53
+
54
+ fixed_inputs = fixed_inputs[0]
55
+ if should_edit and fixed_inputs is not None:
56
+ return edit_previous(prompt, seed, depth, fixed_inputs)
57
+ else:
58
+ return depth2image(prompt, seed, depth)
59
+
60
+ def handle_edit_change(edit, prompt, seed, image_input, fixed_inputs):
61
+ if edit:
62
+ fixed_inputs[0] = FixedInputs(prompt, int(seed), image_input)
63
+ else:
64
+ fixed_inputs[0] = None
65
+ return fixed_inputs
66
+
67
+
68
+ css = """
69
+
70
+ #image_output {
71
+ width: 512px;
72
+ height: 512px;
73
+ """
74
+
75
+
76
+ main_description = """
77
+ # LooseControl
78
+
79
+ This is the official demo for the paper [LooseControl: Lifting ControlNet for Generalized Depth Conditioning](https://shariqfarooq123.github.io/loose-control/).
80
+ Our 3D Box Editing allows users to interactively edit the 3D boxes representing objects in the scene. Users can change the position, size, and orientation of 3D boxes, allowing to quickly create and edit the scenes to their liking in a 3D-aware manner.
81
+ Best viewed on desktop.
82
+ """
83
+
84
+ instructions_editor3d = """
85
+ ## Instructions for Editor3D UI
86
+ - Use 'WASD' keys to move the camera.
87
+ - Click on an object to select it.
88
+ - Use the sliders to change the position, size, and orientation of the selected object. Sliders support click and drag for faster editing.
89
+ - Use the 'Add Box', 'Delete', and 'Duplicate' buttons to add, delete, and duplicate objects.
90
+ - Delete and Duplicate buttons work on the selected object. Duplicate creates a copy and selects it.
91
+ - Use the 'Toggle Mode' to switch between "normal" and "depth" mode. Final image sent to the model should be in "depth" mode.
92
+ - Use the 'Render' button to render the scene and send it to the model for generation.
93
+
94
+ ### Lock style checkbox - Fixes the style of the latest generated image.
95
+ This allows users to edit the 3D boxes without changing the style of the generated image. This is useful when the user is satisfied with the style/content of the generated image and wants to edit the 3D boxes without changing the overall essence of the scene.
96
+ It can be used to create stop motion videos like those shown [here](https://shariqfarooq123.github.io/loose-control/).
97
+
98
+ """
99
+
100
+
101
+
102
+ with gr.Blocks(css=css) as demo:
103
+ gr.Markdown(main_description)
104
+
105
+ fixed_inputs = gr.State([None])
106
+ with gr.Row():
107
+ prompt = gr.Textbox(label="Prompt", placeholder="Write your prompt", elem_id="input")
108
+ seed = gr.Textbox(value=42, label="Seed", elem_id="seed")
109
+ should_edit = gr.Checkbox(label="Lock style", elem_id="edit")
110
+
111
+ with gr.Row():
112
+ image_input = g3deditor(elem_id="image_input")
113
+
114
+ with gr.Row():
115
+ image_output = gr.Image(elem_id="image_output", type='pil')
116
+
117
+ should_edit.change(fn=handle_edit_change, inputs=[should_edit, prompt, seed, image_input, fixed_inputs], outputs=[fixed_inputs])
118
+ image_input.change(fn=run, inputs=[prompt, seed, image_input, should_edit, fixed_inputs], outputs=[image_output])
119
+ with gr.Accordion("Instructions"):
120
+ gr.Markdown(instructions_editor3d)
121
+
122
+ demo.queue().launch()
123
+
124
+
125
+
cross_frame_attention.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+
4
+ class CrossFrameAttnProcessor:
5
+ def __init__(self, unet_chunk_size=2):
6
+ self.unet_chunk_size = unet_chunk_size
7
+
8
+ def __call__(
9
+ self,
10
+ attn,
11
+ hidden_states,
12
+ encoder_hidden_states=None,
13
+ attention_mask=None, **kwargs):
14
+ batch_size, sequence_length, _ = hidden_states.shape
15
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
16
+ query = attn.to_q(hidden_states)
17
+
18
+ is_cross_attention = encoder_hidden_states is not None
19
+ if encoder_hidden_states is None:
20
+ encoder_hidden_states = hidden_states
21
+ elif attn.norm_cross:
22
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
23
+ key = attn.to_k(encoder_hidden_states)
24
+ value = attn.to_v(encoder_hidden_states)
25
+ # Sparse Attention
26
+ if not is_cross_attention:
27
+ video_length = key.size()[0] // self.unet_chunk_size
28
+ # print("Video length is", video_length)
29
+ # former_frame_index = torch.arange(video_length) - 1
30
+ # former_frame_index[0] = 0
31
+ former_frame_index = [0] * video_length
32
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
33
+ key = key[:, former_frame_index]
34
+ key = rearrange(key, "b f d c -> (b f) d c")
35
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
36
+ value = value[:, former_frame_index]
37
+ value = rearrange(value, "b f d c -> (b f) d c")
38
+
39
+ query = attn.head_to_batch_dim(query)
40
+ key = attn.head_to_batch_dim(key)
41
+ value = attn.head_to_batch_dim(value)
42
+
43
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
44
+ hidden_states = torch.bmm(attention_probs, value)
45
+ hidden_states = attn.batch_to_head_dim(hidden_states)
46
+
47
+ # linear proj
48
+ hidden_states = attn.to_out[0](hidden_states)
49
+ # dropout
50
+ hidden_states = attn.to_out[1](hidden_states)
51
+
52
+ return hidden_states
53
+
54
+
55
+
56
+ class AttnProcessorX:
57
+ r"""
58
+ Default processor for performing attention-related computations.
59
+ """
60
+
61
+ def __call__(
62
+ self,
63
+ attn,
64
+ hidden_states,
65
+ encoder_hidden_states=None,
66
+ attention_mask=None,
67
+ temb=None,
68
+ scale=1.0,
69
+ ):
70
+ residual = hidden_states
71
+
72
+ if attn.spatial_norm is not None:
73
+ hidden_states = attn.spatial_norm(hidden_states, temb)
74
+
75
+ input_ndim = hidden_states.ndim
76
+
77
+ if input_ndim == 4:
78
+ batch_size, channel, height, width = hidden_states.shape
79
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
80
+
81
+ batch_size, sequence_length, _ = (
82
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
83
+ )
84
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
85
+
86
+ if attn.group_norm is not None:
87
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
88
+
89
+ query = attn.to_q(hidden_states, scale=scale)
90
+
91
+ if encoder_hidden_states is None:
92
+ encoder_hidden_states = hidden_states
93
+ elif attn.norm_cross:
94
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
95
+
96
+ key = attn.to_k(encoder_hidden_states, scale=scale)
97
+ value = attn.to_v(encoder_hidden_states, scale=scale)
98
+
99
+ query = attn.head_to_batch_dim(query)
100
+ key = attn.head_to_batch_dim(key)
101
+ value = attn.head_to_batch_dim(value)
102
+
103
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
104
+ hidden_states = torch.bmm(attention_probs, value)
105
+ hidden_states = attn.batch_to_head_dim(hidden_states)
106
+
107
+ # linear proj
108
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
109
+ # dropout
110
+ hidden_states = attn.to_out[1](hidden_states)
111
+
112
+ if input_ndim == 4:
113
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
114
+
115
+ if attn.residual_connection:
116
+ hidden_states = hidden_states + residual
117
+
118
+ hidden_states = hidden_states / attn.rescale_output_factor
119
+
120
+ return hidden_states
loosecontrol.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ ControlNetModel,
3
+ StableDiffusionControlNetPipeline,
4
+ UniPCMultistepScheduler,
5
+ )
6
+ import torch
7
+ import PIL
8
+ import PIL.Image
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from typing import Dict
11
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
12
+ import functools
13
+ from cross_frame_attention import CrossFrameAttnProcessor
14
+
15
+ TEXT_ENCODER_NAME = "text_encoder"
16
+ UNET_NAME = "unet"
17
+ NEGATIVE_PROMPT = "blurry, text, caption, lowquality, lowresolution, low res, grainy, ugly"
18
+
19
+ def attach_loaders_mixin(model):
20
+ # hacky way to make ControlNet work with LoRA. This may not be required in future versions of diffusers.
21
+ model.text_encoder_name = TEXT_ENCODER_NAME
22
+ model.unet_name = UNET_NAME
23
+ r"""
24
+ Attach the [`UNet2DConditionLoadersMixin`] to a model. This will add the
25
+ all the methods from the mixin 'UNet2DConditionLoadersMixin' to the model.
26
+ """
27
+ # mixin_instance = UNet2DConditionLoadersMixin()
28
+ for attr_name, attr_value in vars(UNet2DConditionLoadersMixin).items():
29
+ # print(attr_name)
30
+ if callable(attr_value):
31
+ # setattr(model, attr_name, functools.partialmethod(attr_value, model).__get__(model, model.__class__))
32
+ setattr(model, attr_name, functools.partial(attr_value, model))
33
+ return model
34
+
35
+ def set_attn_processor(module, processor, _remove_lora=False):
36
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
37
+ if hasattr(module, "set_processor"):
38
+ if not isinstance(processor, dict):
39
+ module.set_processor(processor, _remove_lora=_remove_lora)
40
+ else:
41
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
42
+
43
+ for sub_name, child in module.named_children():
44
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
45
+
46
+ for name, module in module.named_children():
47
+ fn_recursive_attn_processor(name, module, processor)
48
+
49
+
50
+
51
+ class ControlNetX(ControlNetModel, UNet2DConditionLoadersMixin):
52
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
53
+ # This may not be required in future versions of diffusers.
54
+ @property
55
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
56
+ r"""
57
+ Returns:
58
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
59
+ indexed by its weight name.
60
+ """
61
+ # set recursively
62
+ processors = {}
63
+
64
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
65
+ if hasattr(module, "get_processor"):
66
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
67
+
68
+ for sub_name, child in module.named_children():
69
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
70
+
71
+ return processors
72
+
73
+ for name, module in self.named_children():
74
+ fn_recursive_add_processors(name, module, processors)
75
+
76
+ return processors
77
+
78
+ class ControlNetPipeline:
79
+ def __init__(self, checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
80
+ controlnet = ControlNetX.from_pretrained(checkpoint)
81
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
82
+ sd_checkpoint, controlnet=controlnet, requires_safety_checker=False, safety_checker=None,
83
+ torch_dtype=torch.float16)
84
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
85
+
86
+ @torch.no_grad()
87
+ def __call__(self,
88
+ prompt: str="",
89
+ height=512,
90
+ width=512,
91
+ control_image=None,
92
+ controlnet_conditioning_scale=1.0,
93
+ num_inference_steps: int=20,
94
+ **kwargs) -> PIL.Image.Image:
95
+
96
+ out = self.pipe(prompt, control_image,
97
+ height=height, width=width,
98
+ num_inference_steps=num_inference_steps,
99
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
100
+ **kwargs).images
101
+
102
+ return out[0] if len(out) == 1 else out
103
+
104
+ def to(self, *args, **kwargs):
105
+ self.pipe.to(*args, **kwargs)
106
+ return self
107
+
108
+
109
+ class LooseControlNet(ControlNetPipeline):
110
+ def __init__(self, loose_control_weights="shariqfarooq/loose-control-3dbox", cn_checkpoint="lllyasviel/control_v11f1p_sd15_depth", sd_checkpoint="runwayml/stable-diffusion-v1-5") -> None:
111
+ super().__init__(cn_checkpoint, sd_checkpoint)
112
+ self.pipe.controlnet = attach_loaders_mixin(self.pipe.controlnet)
113
+ self.pipe.controlnet.load_attn_procs(loose_control_weights)
114
+
115
+ def set_normal_attention(self):
116
+ self.pipe.unet.set_attn_processor(AttnProcessor())
117
+
118
+ def set_cf_attention(self, _remove_lora=False):
119
+ for upblocks in self.pipe.unet.up_blocks[-2:]:
120
+ set_attn_processor(upblocks, CrossFrameAttnProcessor(), _remove_lora=_remove_lora)
121
+
122
+ def edit(self, depth, depth_edit, prompt, prompt_edit=None, seed=42, seed_edit=None, negative_prompt=NEGATIVE_PROMPT, controlnet_conditioning_scale=1.0, num_inference_steps=20, **kwargs):
123
+ if prompt_edit is None:
124
+ prompt_edit = prompt
125
+
126
+ if seed_edit is None:
127
+ seed_edit = seed
128
+
129
+ seed = int(seed)
130
+ seed_edit = int(seed_edit)
131
+ control_image = [depth, depth_edit]
132
+ prompt = [prompt, prompt_edit]
133
+ generator = [torch.Generator().manual_seed(seed), torch.Generator().manual_seed(seed_edit)]
134
+ gen = self.pipe(prompt, control_image=control_image, controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator, num_inference_steps=num_inference_steps, negative_prompt=negative_prompt, **kwargs)[-1]
135
+ return gen