Vivien Chappelier committed · Commit ca86cf6 · Parent(s): a48c785

update demo for SDXL-turbo

app.py CHANGED
@@ -1,34 +1,140 @@
-import socketserver
-socketserver.TCPServer.allow_reuse_address = True
-
 import gradio as gr
 
 import torch
 
-# load the patched VQ-VAE
-patched_decoder_ckpt = "checkpoint_000.pth"
-
-
-#print("patching keys for first_stage_model: ", sd2.keys())
-
-msg = pipe.vae.load_state_dict(sd2, strict=False)
-print(f"loaded LDM decoder state_dict with message\n{msg}")
-print("you should check that the decoder keys are correctly matched")
-
-iface.launch(server_name="0.0.0.0")

 import gradio as gr
 
+import os
 import torch
 
+import numpy as np
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+from diffusers import DiffusionPipeline
+import torchvision.transforms as transforms
+
+from copy import deepcopy
+from collections import OrderedDict
+
+import requests
+import json
+
+from PIL import Image, ImageEnhance
+import base64
+import io
+
+
class BZHStableSignatureDemo(object):
|
24 |
+
|
25 |
+
def __init__(self, *args, **kwargs):
|
26 |
+
super().__init__(*args, **kwargs)
|
27 |
+
self.pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")
|
28 |
+
|
29 |
+
# load the patched VQ-VAEs
|
30 |
+
sd1 = deepcopy(self.pipe.vae.state_dict()) # save initial state dict
|
31 |
+
self.decoders = decoders = OrderedDict([("no watermark", sd1)])
|
32 |
+
for name, patched_decoder_ckpt in (
|
33 |
+
("weak", "models/stable_signature/checkpoint_000.pth.50000"),
|
34 |
+
("medium", "models/stable_signature/checkpoint_000.pth.150000"),
|
35 |
+
("strong", "models/stable_signature/checkpoint_000.pth.500000"),
|
36 |
+
("extreme", "models/stable_signature/checkpoint_000.pth.1500000")):
|
37 |
+
sd2 = torch.load(patched_decoder_ckpt)['ldm_decoder']
|
38 |
+
msg = self.pipe.vae.load_state_dict(sd2, strict=False)
|
39 |
+
print(f"loaded LDM decoder state_dict with message\n{msg}")
|
40 |
+
print("you should check that the decoder keys are correctly matched")
|
41 |
+
decoders[name] = sd2
|
42 |
+
self.decoders = decoders
|
43 |
+
|
+    def generate(self, mode, seed, prompt):
+        generator = torch.Generator(device=device)
+        if seed:
+            torch.manual_seed(seed)
+
+        # load the patched VAE decoder
+        sd = self.decoders[mode]
+        self.pipe.vae.load_state_dict(sd, strict=False)
+
+        output = self.pipe(prompt, num_inference_steps=4, guidance_scale=0.0, output_type="pil")
+        return output.images[0]
+
+
@staticmethod
|
57 |
+
def pad(img, padding, mode="edge"):
|
58 |
+
npimg = np.asarray(img)
|
59 |
+
nppad = ((padding[1], padding[3]), (padding[0], padding[2]), (0,0))
|
60 |
+
npimg = np.pad(npimg, nppad, mode=mode)
|
61 |
+
return Image.fromarray(npimg)
|
62 |
+
|
63 |
+
def attack_detect(self, img, jpeg_compression, downscale, saturation):
|
64 |
+
|
65 |
+
# attack
|
66 |
+
if downscale != 1:
|
67 |
+
size = img.size
|
68 |
+
size = (int(size[0] / downscale), int(size[1] / downscale))
|
69 |
+
img = img.resize(size, Image.BICUBIC)
|
70 |
+
|
71 |
+
converter = ImageEnhance.Color(img)
|
72 |
+
img = converter.enhance(saturation)
|
73 |
+
|
74 |
+
# send to detection API and apply JPEG compression attack
|
75 |
+
mf = io.BytesIO()
|
76 |
+
img.save(mf, format='JPEG', quality=jpeg_compression) # includes JPEG attack
|
77 |
+
b64 = base64.b64encode(mf.getvalue())
|
78 |
+
data = {
|
79 |
+
'image': b64.decode('utf8')
|
80 |
+
}
|
81 |
+
|
82 |
+
headers = {}
|
83 |
+
api_key = os.environ.get('BZH_API_KEY', None)
|
84 |
+
if api_key:
|
85 |
+
headers['BZH_API_KEY'] = api_key
|
86 |
+
response = requests.post('https://bzh.imatag.com/bzh/api/v1.0/detect',
|
87 |
+
json=data, headers=headers)
|
88 |
+
response.raise_for_status()
|
89 |
+
data = response.json()
|
90 |
+
pvalue = data['p-value']
|
91 |
|
92 |
+
mf.seek(0)
|
93 |
+
img0 = Image.open(mf) # reload to show JPEG attack
|
94 |
+
#result = "resolution = %dx%d p-value = %e" % (img.size[0], img.size[1], pvalue))
|
95 |
+
result = "No watermark detected."
|
96 |
+
chances = int(1 / pvalue + 1)
|
97 |
+
if pvalue < 1e-3:
|
98 |
+
result = "Weak watermark detected (< 1/%d chances of being wrong)" % chances
|
99 |
+
if pvalue < 1e-6:
|
100 |
+
result = "Strong watermark detected (< 1/%d chances of being wrong)" % chances
|
101 |
+
return (img0, result)
|
102 |
|
|
|
|
|
103 |
|
+def interface():
+    prompt = "sailing ship in storm by Rembrandt"
 
+    backend = BZHStableSignatureDemo()
+    decoders = list(backend.decoders.keys())
 
+    with gr.Blocks() as demo:
+        gr.Markdown("""# Watermarked SDXL-Turbo demo
+        This demo presents watermarking of images generated via StableDiffusion XL Turbo.
+        Using the method presented in [StableSignature](https://ai.meta.com/blog/stable-signature-watermarking-generative-ai/),
+        the VAE decoder of StableDiffusion is fine-tuned to produce images including a specific invisible watermark. We combined
+        this method with our in-house decoder which operates in zero-bit mode for improved robustness.""")
 
+        with gr.Row():
+            inp = gr.Textbox(label="Prompt", value=prompt)
+            seed = gr.Number(label="Seed", precision=0)
+            mode = gr.Dropdown(choices=decoders, label="Watermark strength", value="medium")
+        with gr.Row():
+            btn1 = gr.Button("Generate")
+        with gr.Row():
+            watermarked_image = gr.Image(type="pil", tool="select").style(width=512, height=512)
+            with gr.Column():
+                downscale = gr.Slider(1, 3, value=1, step=0.1, label="Downscale ratio")
+                saturation = gr.Slider(0, 2, value=1, step=0.1, label="Color saturation")
+                jpeg_compression = gr.Slider(value=100, step=5, label="JPEG quality")
+                btn2 = gr.Button("Attack & Detect")
+        with gr.Row():
+            attacked_image = gr.Image(type="pil", tool="select").style(width=256)
+            detection_label = gr.Label(label="Detection info")
+        btn1.click(fn=backend.generate, inputs=[mode, seed, inp], outputs=[watermarked_image], api_name="generate")
+        btn2.click(fn=backend.attack_detect, inputs=[watermarked_image, jpeg_compression, downscale, saturation], outputs=[attacked_image, detection_label], api_name="detect")
 
+    return demo
 
+if __name__ == '__main__':
+    demo = interface()
+    demo.launch()
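
Since the Blocks above register named endpoints (api_name="generate" and api_name="detect"), the running demo can also be driven programmatically. The sketch below uses gradio_client under stated assumptions: the Space identifier is a placeholder, the return types follow the usual gradio_client conventions (file paths for image outputs), and image-input handling may differ between gradio_client versions, so treat it as an illustration rather than a verified client.

# Hypothetical client-side sketch; the Space id is a placeholder and
# image-input handling depends on the installed gradio_client version.
from gradio_client import Client

client = Client("user/watermarked-sdxl-turbo-demo")  # placeholder Space id

# "generate" endpoint: (watermark strength, seed, prompt) -> watermarked image
image_path = client.predict(
    "medium",                               # watermark strength (dropdown choice)
    42,                                     # seed
    "sailing ship in storm by Rembrandt",   # prompt
    api_name="/generate",
)

# "detect" endpoint: (image, JPEG quality, downscale ratio, color saturation)
# -> (attacked image, detection label); the label format follows gr.Label output.
attacked_path, detection = client.predict(
    image_path,   # image returned by /generate
    80,           # JPEG quality
    1.5,          # downscale ratio
    1.0,          # color saturation
    api_name="/detect",
)
print(detection)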