Vivien Chappelier commited on
Commit
ca86cf6
1 Parent(s): a48c785

update demo for SDXL-turbo

Browse files
Files changed (1) hide show
  1. app.py +128 -22
app.py CHANGED
@@ -1,34 +1,140 @@
1
- import socketserver
2
- socketserver.TCPServer.allow_reuse_address = True
3
-
4
  import gradio as gr
5
 
 
6
  import torch
7
 
8
- from diffusers import StableDiffusionPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16)
 
 
 
 
 
 
 
 
 
11
 
12
- # load the patched VQ-VAE
13
- patched_decoder_ckpt = "checkpoint_000.pth"
14
 
15
- if patched_decoder_ckpt is not None:
16
- sd2 = torch.load(patched_decoder_ckpt)['ldm_decoder']
17
- #print("patching keys for first_stage_model: ", sd2.keys())
18
-
19
- msg = pipe.vae.load_state_dict(sd2, strict=False)
20
- print(f"loaded LDM decoder state_dict with message\n{msg}")
21
- print("you should check that the decoder keys are correctly matched")
22
 
23
- pipe = pipe.to("cuda")
 
24
 
25
- prompt = "sailing ship in storm by Rembrandt"
 
 
 
 
 
26
 
27
- def generate(prompt):
28
- output = pipe(prompt, num_inference_steps=50, output_type="pil")
29
- output.images[0].save("result.png")
30
- return output.images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- iface = gr.Interface(fn=generate, inputs=[gr.Textbox(label="Prompt", value=prompt)], outputs=[gr.Image(type="pil")])
33
- iface.launch(server_name="0.0.0.0")
34
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ import os
4
  import torch
5
 
6
+ import numpy as np
7
+
8
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
+
10
+ from diffusers import DiffusionPipeline
11
+ import torchvision.transforms as transforms
12
+
13
+ from copy import deepcopy
14
+ from collections import OrderedDict
15
+
16
+ import requests
17
+ import json
18
+
19
+ from PIL import Image, ImageEnhance
20
+ import base64
21
+ import io
22
+
23
+ class BZHStableSignatureDemo(object):
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
28
+
29
+ # load the patched VQ-VAEs
30
+ sd1 = deepcopy(self.pipe.vae.state_dict()) # save initial state dict
31
+ self.decoders = decoders = OrderedDict([("no watermark", sd1)])
32
+ for name, patched_decoder_ckpt in (
33
+ ("weak", "models/stable_signature/checkpoint_000.pth.50000"),
34
+ ("medium", "models/stable_signature/checkpoint_000.pth.150000"),
35
+ ("strong", "models/stable_signature/checkpoint_000.pth.500000"),
36
+ ("extreme", "models/stable_signature/checkpoint_000.pth.1500000")):
37
+ sd2 = torch.load(patched_decoder_ckpt)['ldm_decoder']
38
+ msg = self.pipe.vae.load_state_dict(sd2, strict=False)
39
+ print(f"loaded LDM decoder state_dict with message\n{msg}")
40
+ print("you should check that the decoder keys are correctly matched")
41
+ decoders[name] = sd2
42
+ self.decoders = decoders
43
+
44
+ def generate(self, mode, seed, prompt):
45
+ generator = torch.Generator(device=device)
46
+ if seed:
47
+ torch.manual_seed(seed)
48
+
49
+ # load the patched VAE decoder
50
+ sd = self.decoders[mode]
51
+ self.pipe.vae.load_state_dict(sd, strict=False)
52
+
53
+ output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
54
+ return output.images[0]
55
+
56
+ @staticmethod
57
+ def pad(img, padding, mode="edge"):
58
+ npimg = np.asarray(img)
59
+ nppad = ((padding[1], padding[3]), (padding[0], padding[2]), (0,0))
60
+ npimg = np.pad(npimg, nppad, mode=mode)
61
+ return Image.fromarray(npimg)
62
+
63
+ def attack_detect(self, img, jpeg_compression, downscale, saturation):
64
+
65
+ # attack
66
+ if downscale != 1:
67
+ size = img.size
68
+ size = (int(size[0] / downscale), int(size[1] / downscale))
69
+ img = img.resize(size, Image.BICUBIC)
70
+
71
+ converter = ImageEnhance.Color(img)
72
+ img = converter.enhance(saturation)
73
+
74
+ # send to detection API and apply JPEG compression attack
75
+ mf = io.BytesIO()
76
+ img.save(mf, format='JPEG', quality=jpeg_compression) # includes JPEG attack
77
+ b64 = base64.b64encode(mf.getvalue())
78
+ data = {
79
+ 'image': b64.decode('utf8')
80
+ }
81
+
82
+ headers = {}
83
+ api_key = os.environ.get('BZH_API_KEY', None)
84
+ if api_key:
85
+ headers['BZH_API_KEY'] = api_key
86
+ response = requests.post('https://bzh.imatag.com/bzh/api/v1.0/detect',
87
+ json=data, headers=headers)
88
+ response.raise_for_status()
89
+ data = response.json()
90
+ pvalue = data['p-value']
91
 
92
+ mf.seek(0)
93
+ img0 = Image.open(mf) # reload to show JPEG attack
94
+ #result = "resolution = %dx%d p-value = %e" % (img.size[0], img.size[1], pvalue))
95
+ result = "No watermark detected."
96
+ chances = int(1 / pvalue + 1)
97
+ if pvalue < 1e-3:
98
+ result = "Weak watermark detected (< 1/%d chances of being wrong)" % chances
99
+ if pvalue < 1e-6:
100
+ result = "Strong watermark detected (< 1/%d chances of being wrong)" % chances
101
+ return (img0, result)
102
 
 
 
103
 
104
+ def interface():
105
+ prompt = "sailing ship in storm by Rembrandt"
 
 
 
 
 
106
 
107
+ backend = BZHStableSignatureDemo()
108
+ decoders = list(backend.decoders.keys())
109
 
110
+ with gr.Blocks() as demo:
111
+ gr.Markdown("""# Watermarked SDXL-Turbo demo
112
+ This demo presents watermarking of images generated via StableDiffusion XL Turbo.
113
+ Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
114
+ the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
115
+ this method with our in-house decoder which operates in zero-bit mode for improved robustness.""")
116
 
117
+ with gr.Row():
118
+ inp = gr.Textbox(label="Prompt", value=prompt)
119
+ seed = gr.Number(label="Seed", precision=0)
120
+ mode = gr.Dropdown(choices=decoders, label="Watermark strength", value="medium")
121
+ with gr.Row():
122
+ btn1 = gr.Button("Generate")
123
+ with gr.Row():
124
+ watermarked_image = gr.Image(type="pil", tool="select").style(width=512, height=512)
125
+ with gr.Column():
126
+ downscale = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
127
+ saturation = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
128
+ jpeg_compression = gr.Slider(value=100, step=5, label="JPEG quality")
129
+ btn2 = gr.Button("Attack & Detect")
130
+ with gr.Row():
131
+ attacked_image = gr.Image(type="pil", tool="select").style(width=256)
132
+ detection_label = gr.Label(label="Detection info")
133
+ btn1.click(fn=backend.generate, inputs=[mode, seed, inp], outputs=[watermarked_image], api_name="generate")
134
+ btn2.click(fn=backend.attack_detect, inputs=[watermarked_image, jpeg_compression, downscale, saturation], outputs=[attacked_image, detection_label], api_name="detect")
135
 
136
+ return demo
 
137
 
138
+ if __name__ == '__main__':
139
+ demo = interface()
140
+ demo.launch()