saassa commited on
Commit
463e137
1 Parent(s): 5517704

Upload control.py

Browse files
Files changed (1) hide show
  1. control.py +122 -0
control.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from fal_serverless import isolated, cached
3
+
4
+ from pathlib import Path
5
+ import base64
6
+ import io
7
+
8
+ requirements = [
9
+ "controlnet-aux",
10
+ "diffusers",
11
+ "torch",
12
+ "mediapipe",
13
+ "transformers",
14
+ "accelerate",
15
+ "xformers"
16
+ ]
17
+
18
+
19
+ def get_image_from_url_as_bytes(url: str) -> bytes:
20
+ import requests
21
+
22
+ response = requests.get(url)
23
+ # This will raise an exception if the request returned an HTTP error code
24
+ response.raise_for_status()
25
+ return response.content
26
+
27
+ def read_image_bytes(file_path):
28
+ with open(file_path, "rb") as file:
29
+ image_bytes = file.read()
30
+ return image_bytes
31
+
32
+ @cached
33
+ def load_model():
34
+ import torch
35
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
36
+
37
+ controlnet = ControlNetModel.from_pretrained(
38
+ "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
39
+ )
40
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
41
+ "peterwilli/deliberate-2", controlnet=controlnet, torch_dtype=torch.float16
42
+ )
43
+
44
+ pipe = pipe.to("cuda:0")
45
+ pipe.unet.to(memory_format=torch.channels_last)
46
+ pipe.controlnet.to(memory_format=torch.channels_last)
47
+ return pipe
48
+
49
+
50
+ def resize_image(input_image, resolution):
51
+ import cv2
52
+ import numpy as np
53
+
54
+ H, W, C = input_image.shape
55
+ H = float(H)
56
+ W = float(W)
57
+ k = float(resolution) / min(H, W)
58
+ H *= k
59
+ W *= k
60
+ H = int(np.round(H / 64.0)) * 64
61
+ W = int(np.round(W / 64.0)) * 64
62
+ img = cv2.resize(
63
+ input_image,
64
+ (W, H),
65
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
66
+ )
67
+ return img
68
+
69
+ @isolated(
70
+ requirements=requirements,
71
+ machine_type="GPU",
72
+ keep_alive=30,
73
+ serve=True
74
+ )
75
+ def generate(
76
+ image_url: str, prompt: str, num_samples: int, num_steps: int, gcs=False
77
+ ) -> list[bytes] | None:
78
+
79
+ from controlnet_aux import CannyDetector
80
+ from PIL import Image
81
+ import numpy as np
82
+ import uuid
83
+ import os
84
+ from base64 import b64encode
85
+
86
+ image_bytes = get_image_from_url_as_bytes(image_url)
87
+
88
+ pipe = load_model()
89
+ image = Image.open(io.BytesIO(image_bytes))
90
+
91
+ canny = CannyDetector()
92
+ init_image = image.convert("RGB")
93
+
94
+ init_image = resize_image(np.asarray(init_image), 512)
95
+ detected_map = canny(init_image, 100, 200)
96
+ image = Image.fromarray(detected_map)
97
+
98
+ negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
99
+ results = pipe(
100
+ prompt=prompt,
101
+ image=image,
102
+ negative_prompt=negative_prompt,
103
+ num_inference_steps=num_steps,
104
+ num_images_per_prompt=num_samples
105
+ ).images
106
+
107
+ result_id = uuid.uuid4()
108
+ out_dir = Path(f"/data/cn-results/{result_id}")
109
+ out_dir.mkdir(parents=True, exist_ok=True)
110
+
111
+
112
+ for i, res in enumerate(results):
113
+ res.save(out_dir / f"res_{i}.png")
114
+
115
+ file_names = [
116
+ f for f in os.listdir(out_dir) if os.path.isfile(os.path.join(out_dir, f))
117
+ ]
118
+
119
+ list_of_bytes = [read_image_bytes(out_dir / f) for f in file_names]
120
+ raw_image = list_of_bytes[0]
121
+
122
+ return b64encode(raw_image).decode("utf-8")