saassa committed on
Commit
76d27e6
1 Parent(s): 463e137

Upload controlc.py

Browse files
Files changed (1) hide show
  1. controlc.py +117 -0
controlc.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+
4
+ from pathlib import Path
5
+ import base64
6
+ import io
7
+
8
# Third-party distributions this module relies on; the heavy ones are imported
# lazily inside the functions below rather than at module import time.
requirements = [
    "controlnet-aux",
    "diffusers",
    "torch",
    "mediapipe",
    "transformers",
    "accelerate",
    "xformers",
]
17
+
18
+
19
def get_image_from_url_as_bytes(url: str, timeout: float = 30.0) -> bytes:
    """Download the resource at *url* and return the raw response body.

    Args:
        url: Address fetched with an HTTP GET request.
        timeout: Value forwarded to ``requests.get(timeout=...)``, in
            seconds. Defaults to 30 so a stalled server cannot block the
            caller forever.

    Returns:
        The response body as bytes (``response.content``).

    Raises:
        requests.HTTPError: If the server answered with an HTTP error status.
        requests.Timeout: If the request exceeded ``timeout``.
    """
    import requests

    # requests applies no timeout unless one is given explicitly, so always
    # pass one; otherwise an unresponsive host hangs this call indefinitely.
    response = requests.get(url, timeout=timeout)
    # This will raise an exception if the request returned an HTTP error code
    response.raise_for_status()
    return response.content
26
+
27
def read_image_bytes(file_path):
    """Return the complete contents of the file at *file_path* as bytes.

    The file is opened in binary mode and read in one call.
    """
    with open(file_path, mode="rb") as handle:
        return handle.read()
31
+
32
+
33
def load_model():
    """Build the canny-conditioned Stable Diffusion ControlNet pipeline.

    Loads the ``lllyasviel/sd-controlnet-canny`` ControlNet and the
    ``peterwilli/deliberate-2`` pipeline in float16, moves the pipeline to
    ``cuda:0`` and switches the UNet and ControlNet to the channels-last
    memory format.

    Returns:
        The pipeline object, placed on ``cuda:0``.
    """
    import torch
    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

    half = torch.float16

    canny_controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=half,
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "peterwilli/deliberate-2",
        controlnet=canny_controlnet,
        torch_dtype=half,
    )

    pipeline = pipeline.to("cuda:0")
    # Both sub-models are switched to the channels-last memory layout.
    for submodel in (pipeline.unet, pipeline.controlnet):
        submodel.to(memory_format=torch.channels_last)
    return pipeline
48
+
49
+
50
def resize_image(input_image, resolution):
    """Rescale *input_image* so its shorter side is about *resolution* pixels.

    The scale factor is ``resolution / min(height, width)``. Each scaled
    dimension is then rounded to the nearest multiple of 64 before resizing.

    Args:
        input_image: Array whose ``shape`` unpacks into exactly three values,
            read here as (height, width, channels).
        resolution: Target size for the shorter side, before rounding.

    Returns:
        The array produced by ``cv2.resize`` at the rounded dimensions.
    """
    import cv2
    import numpy as np

    rows, cols, _channels = input_image.shape
    rows, cols = float(rows), float(cols)
    scale = float(resolution) / min(rows, cols)

    def snap_to_64(length):
        # Scale first, then round to the nearest multiple of 64.
        return int(np.round(length * scale / 64.0)) * 64

    target_h = snap_to_64(rows)
    target_w = snap_to_64(cols)

    # Lanczos when enlarging, area averaging when shrinking or keeping size.
    if scale > 1:
        method = cv2.INTER_LANCZOS4
    else:
        method = cv2.INTER_AREA

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_image, (target_w, target_h), interpolation=method)
68
+
69
+
70
def generate(
    image_url: str, prompt: str, num_samples: int, num_steps: int, gcs=False
) -> str:
    """Generate images from *prompt*, guided by the canny edges of a remote image.

    The image at *image_url* is downloaded, converted to RGB, resized via
    ``resize_image(..., 512)`` and passed through a canny edge detector. The
    edge map conditions the pipeline returned by ``load_model()``. Every
    generated image is written to a fresh ``/data/cn-results/<uuid>``
    directory as ``res_<i>.png``.

    Args:
        image_url: URL of the conditioning image.
        prompt: Text prompt passed to the pipeline.
        num_samples: Forwarded as ``num_images_per_prompt``.
        num_steps: Forwarded as ``num_inference_steps``.
        gcs: Accepted for interface compatibility; not used by this function.

    Returns:
        The first generated image (``res_0.png``) as a base64-encoded
        UTF-8 string.

    Raises:
        IndexError: If the pipeline produced no images.
    """
    from base64 import b64encode
    import uuid

    from controlnet_aux import CannyDetector
    from PIL import Image
    import numpy as np

    image_bytes = get_image_from_url_as_bytes(image_url)

    # NOTE: the pipeline is rebuilt on every call.
    pipe = load_model()

    # Close the source image as soon as the RGB copy exists.
    with Image.open(io.BytesIO(image_bytes)) as source_image:
        init_image = source_image.convert("RGB")

    canny = CannyDetector()
    init_image = resize_image(np.asarray(init_image), 512)
    detected_map = canny(init_image, 100, 200)
    image = Image.fromarray(detected_map)

    negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
    results = pipe(
        prompt=prompt,
        image=image,
        negative_prompt=negative_prompt,
        num_inference_steps=num_steps,
        num_images_per_prompt=num_samples,
    ).images

    result_id = uuid.uuid4()
    out_dir = Path(f"/data/cn-results/{result_id}")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Record the paths in generation order. Listing the directory afterwards
    # gives no ordering guarantee, so the first listed file need not be the
    # first sample.
    saved_paths = []
    for i, res in enumerate(results):
        target = out_dir / f"res_{i}.png"
        res.save(target)
        saved_paths.append(target)

    raw_image = read_image_bytes(saved_paths[0])

    return b64encode(raw_image).decode("utf-8")