import torch
import gradio as gr

from modules import scripts
import ldm_patched.ldm.modules.attention as attention


def sdp(q, k, v, transformer_options):
    # Thin wrapper over Forge's optimized attention; the head count for the
    # current layer is read from transformer_options.
    return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
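# Illustrative shapes (an assumed example: SD1.5's first self-attention
# block at 512x512): q, k and v are (batch, 4096, 320) and the output keeps
# the same shape.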


class StyleAlignForForge(scripts.Script):
    sorting_priority = 17

    def title(self):
        return "StyleAlign Integrated"

    def show(self, is_img2img):
        # Make this extension visible in both the txt2img and img2img tabs.
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        with gr.Accordion(open=False, label=self.title()):
            shared_attention = gr.Checkbox(label='Share attention in batch', value=False)

        # The returned components are passed positionally to the processing
        # hooks below as *script_args.
        return [shared_attention]

    def process_before_every_sampling(self, p, *script_args, **kwargs):
        # Called before every sampling pass; with highres fix enabled this
        # runs twice (once for the base pass, once for the hires pass).

        shared_attention = script_args[0]

        if not shared_attention:
            return

        unet = p.sd_model.forge_objects.unet.clone()

        def join(x):
            # Fold the batch dimension into the token dimension so that all
            # images in the batch share a single attention context.
            b, f, c = x.shape
            return x.reshape(1, b * f, c)

        def aligned_attention(q, k, v, transformer_options):
            # Attend over the joined batch, then split the result back into
            # per-image rows (the inverse of join).
            b, f, c = q.shape
            o = sdp(join(q), join(k), join(v), transformer_options)
            b2, f2, c2 = o.shape
            o = o.reshape(b, b2 * f2 // b, c2)
            return o

        def attn1_proc(q, k, v, transformer_options):
            # Share self-attention within the cond group and within the
            # uncond group separately, so that CFG positives and negatives
            # never attend to each other's tokens.
            cond_indices = transformer_options['cond_indices']
            uncond_indices = transformer_options['uncond_indices']
            cond_or_uncond = transformer_options['cond_or_uncond']
            results = []

            for cx in cond_or_uncond:
                # 0 marks the cond group; anything else is the uncond group.
                if cx == 0:
                    indices = cond_indices
                else:
                    indices = uncond_indices

                if len(indices) > 0:
                    bq, bk, bv = q[indices], k[indices], v[indices]
                    bo = aligned_attention(bq, bk, bv, transformer_options)
                    results.append(bo)

            # Concatenate the per-group outputs back in the order the groups
            # appear in the batch.
            return torch.cat(results, dim=0)

        unet.set_model_replace_all(attn1_proc, 'attn1')

        p.sd_model.forge_objects.unet = unet
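        # Note: the clone-patch-assign sequence above keeps the attn1 patch
        # local to this generation; the original UNet object is never
        # modified in place.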

        # The code below adds entries to the generation info shown under the
        # image outputs in the UI; extra_generation_params does not affect
        # the generated images.
        p.extra_generation_params.update(dict(
            stylealign_enabled=shared_attention,
        ))

        return
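

if __name__ == "__main__":
    # Standalone sanity check; Forge never executes this branch when it
    # imports the script. A minimal sketch (it assumes the Forge imports at
    # the top of this file resolve, i.e. it is run from inside the Forge
    # tree): verify that the join / reshape round trip used by
    # aligned_attention preserves per-image rows.
    b, f, c = 4, 16, 8
    x = torch.arange(b * f * c, dtype=torch.float32).reshape(b, f, c)
    joined = x.reshape(1, b * f, c)  # same fold as join()
    b2, f2, c2 = joined.shape
    restored = joined.reshape(b, b2 * f2 // b, c2)  # same unfold as aligned_attention()
    assert torch.equal(x, restored)
    print("join/reshape round trip OK:", tuple(restored.shape))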