Upload 3 files
Browse files- server/config.py +4 -0
- server/main.py +1 -0
- server/wrapper.py +21 -1
server/config.py
CHANGED
@@ -2,6 +2,9 @@ from dataclasses import dataclass, field
|
|
2 |
from typing import List
|
3 |
|
4 |
import torch
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
@dataclass
|
@@ -46,3 +49,4 @@ class Config:
|
|
46 |
t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
|
47 |
# Number of warmup steps
|
48 |
warmup: int = 10
|
|
|
|
2 |
from typing import List
|
3 |
|
4 |
import torch
|
5 |
+
import os
|
6 |
+
|
7 |
+
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "False") == "True"
|
8 |
|
9 |
|
10 |
@dataclass
|
|
|
49 |
t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
|
50 |
# Number of warmup steps
|
51 |
warmup: int = 10
|
52 |
+
safety_checker: bool = SAFETY_CHECKER
|
server/main.py
CHANGED
@@ -62,6 +62,7 @@ class Api:
|
|
62 |
dtype=config.dtype,
|
63 |
t_index_list=config.t_index_list,
|
64 |
warmup=config.warmup,
|
|
|
65 |
)
|
66 |
self.app = FastAPI()
|
67 |
self.app.add_api_route(
|
|
|
62 |
dtype=config.dtype,
|
63 |
t_index_list=config.t_index_list,
|
64 |
warmup=config.warmup,
|
65 |
+
safety_checker=config.safety_checker,
|
66 |
)
|
67 |
self.app = FastAPI()
|
68 |
self.app.add_api_route(
|
server/wrapper.py
CHANGED
@@ -28,6 +28,7 @@ class StreamDiffusionWrapper:
|
|
28 |
dtype: str,
|
29 |
t_index_list: List[int],
|
30 |
warmup: int,
|
|
|
31 |
):
|
32 |
self.device = device
|
33 |
self.dtype = dtype
|
@@ -40,6 +41,16 @@ class StreamDiffusionWrapper:
|
|
40 |
t_index_list=t_index_list,
|
41 |
warmup=warmup,
|
42 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def _load_model(
|
45 |
self,
|
@@ -104,7 +115,16 @@ class StreamDiffusionWrapper:
|
|
104 |
|
105 |
x_output = self.stream.txt2img()
|
106 |
if i >= 3:
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
end.record()
|
109 |
|
110 |
torch.cuda.synchronize()
|
|
|
28 |
dtype: str,
|
29 |
t_index_list: List[int],
|
30 |
warmup: int,
|
31 |
+
safety_checker: bool,
|
32 |
):
|
33 |
self.device = device
|
34 |
self.dtype = dtype
|
|
|
41 |
t_index_list=t_index_list,
|
42 |
warmup=warmup,
|
43 |
)
|
44 |
+
self.safety_checker = None
|
45 |
+
if safety_checker:
|
46 |
+
from transformers import CLIPFeatureExtractor
|
47 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
48 |
+
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
49 |
+
"CompVis/stable-diffusion-safety-checker").to(self.device)
|
50 |
+
self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
51 |
+
"openai/clip-vit-base-patch32")
|
52 |
+
self.nsfw_fallback_img = PIL.Image.new(
|
53 |
+
"RGB", (512, 512), (0, 0, 0))
|
54 |
|
55 |
def _load_model(
|
56 |
self,
|
|
|
115 |
|
116 |
x_output = self.stream.txt2img()
|
117 |
if i >= 3:
|
118 |
+
image = postprocess_image(x_output, output_type="pil")[0]
|
119 |
+
if self.safety_checker:
|
120 |
+
safety_checker_input = self.feature_extractor(
|
121 |
+
image, return_tensors="pt").to(self.device)
|
122 |
+
_, has_nsfw_concept = self.safety_checker(
|
123 |
+
images=x_output, clip_input=safety_checker_input.pixel_values.to(
|
124 |
+
self.dtype)
|
125 |
+
)
|
126 |
+
image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
|
127 |
+
images.append(image)
|
128 |
end.record()
|
129 |
|
130 |
torch.cuda.synchronize()
|