radames commited on
Commit
b788820
1 Parent(s): c0b464a

Upload 3 files

Browse files
Files changed (3) hide show
  1. server/config.py +4 -0
  2. server/main.py +1 -0
  3. server/wrapper.py +21 -1
server/config.py CHANGED
@@ -2,6 +2,9 @@ from dataclasses import dataclass, field
2
  from typing import List
3
 
4
  import torch
 
 
 
5
 
6
 
7
  @dataclass
@@ -46,3 +49,4 @@ class Config:
46
  t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
47
  # Number of warmup steps
48
  warmup: int = 10
 
 
2
  from typing import List
3
 
4
  import torch
5
+ import os
6
+
7
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "False") == "True"
8
 
9
 
10
  @dataclass
 
49
  t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
50
  # Number of warmup steps
51
  warmup: int = 10
52
+ safety_checker: bool = SAFETY_CHECKER
server/main.py CHANGED
@@ -62,6 +62,7 @@ class Api:
62
  dtype=config.dtype,
63
  t_index_list=config.t_index_list,
64
  warmup=config.warmup,
 
65
  )
66
  self.app = FastAPI()
67
  self.app.add_api_route(
 
62
  dtype=config.dtype,
63
  t_index_list=config.t_index_list,
64
  warmup=config.warmup,
65
+ safety_checker=config.safety_checker,
66
  )
67
  self.app = FastAPI()
68
  self.app.add_api_route(
server/wrapper.py CHANGED
@@ -28,6 +28,7 @@ class StreamDiffusionWrapper:
28
  dtype: str,
29
  t_index_list: List[int],
30
  warmup: int,
 
31
  ):
32
  self.device = device
33
  self.dtype = dtype
@@ -40,6 +41,16 @@ class StreamDiffusionWrapper:
40
  t_index_list=t_index_list,
41
  warmup=warmup,
42
  )
 
 
 
 
 
 
 
 
 
 
43
 
44
  def _load_model(
45
  self,
@@ -104,7 +115,16 @@ class StreamDiffusionWrapper:
104
 
105
  x_output = self.stream.txt2img()
106
  if i >= 3:
107
- images.append(postprocess_image(x_output, output_type="pil")[0])
 
 
 
 
 
 
 
 
 
108
  end.record()
109
 
110
  torch.cuda.synchronize()
 
28
  dtype: str,
29
  t_index_list: List[int],
30
  warmup: int,
31
+ safety_checker: bool,
32
  ):
33
  self.device = device
34
  self.dtype = dtype
 
41
  t_index_list=t_index_list,
42
  warmup=warmup,
43
  )
44
+ self.safety_checker = None
45
+ if safety_checker:
46
+ from transformers import CLIPFeatureExtractor
47
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
48
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
49
+ "CompVis/stable-diffusion-safety-checker").to(self.device)
50
+ self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
51
+ "openai/clip-vit-base-patch32")
52
+ self.nsfw_fallback_img = PIL.Image.new(
53
+ "RGB", (512, 512), (0, 0, 0))
54
 
55
  def _load_model(
56
  self,
 
115
 
116
  x_output = self.stream.txt2img()
117
  if i >= 3:
118
+ image = postprocess_image(x_output, output_type="pil")[0]
119
+ if self.safety_checker:
120
+ safety_checker_input = self.feature_extractor(
121
+ image, return_tensors="pt").to(self.device)
122
+ _, has_nsfw_concept = self.safety_checker(
123
+ images=x_output, clip_input=safety_checker_input.pixel_values.to(
124
+ self.dtype)
125
+ )
126
+ image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
127
+ images.append(image)
128
  end.record()
129
 
130
  torch.cuda.synchronize()