radames committed
Commit 8eae2b8
1 parent: b788820

Upload 21 files
server/.DS_Store ADDED
Binary file (6.15 kB)
 
server/config.py CHANGED
@@ -21,18 +21,11 @@ class Config:
     port: int = 9090
     workers: int = 1
 
-    ####################################################################
-    # Generation configuration
-    ####################################################################
-    # The threshold for the Levenstein distance.
-    levenstein_distance_threshold: int = 3
-
     ####################################################################
     # Model configuration
     ####################################################################
     # SD1.x variant model
     model_id: str = "SimianLuo/LCM_Dreamshaper_v7"
-
     # LCM-LORA model
     lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
     # TinyVAE model
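
With the generation block removed, the prompt-similarity gate moves out of the server config and into the React client (see view/src/App.tsx below). For orientation, a minimal sketch of the dataclass after this change, assuming the fields outside the hunk keep the defaults shown in the context and leaving the TinyVAE id out because the hunk cuts off before it:

from dataclasses import dataclass

@dataclass
class Config:
    # Server configuration (unchanged context)
    port: int = 9090
    workers: int = 1

    ####################################################################
    # Model configuration
    ####################################################################
    # SD1.x variant model
    model_id: str = "SimianLuo/LCM_Dreamshaper_v7"
    # LCM-LORA model
    lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
    # TinyVAE model id follows in the full file (not shown in this hunk)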
server/main.py CHANGED
@@ -32,7 +32,7 @@ class PredictResponseModel(BaseModel):
     The response model for the /predict endpoint.
     """
 
-    base64_images: list[str]
+    base64_image: str
 
 
 class UpdatePromptResponseModel(BaseModel):
@@ -86,7 +86,6 @@ class Api:
         self._update_prompt_lock = asyncio.Lock()
 
         self.last_prompt: str = ""
-        self.last_images: list[str] = [""]
 
     async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
         """
@@ -103,15 +102,7 @@ class Api:
             The prediction result.
         """
         async with self._predict_lock:
-            if (
-                self._calc_levenstein_distance(inp.prompt, self.last_prompt)
-                < self.config.levenstein_distance_threshold
-            ):
-                logger.info("Using cached images")
-                return PredictResponseModel(base64_images=self.last_images)
-            self.last_prompt = inp.prompt
-            self.last_images = [self._pil_to_base64(image) for image in self.stream_diffusion(inp.prompt)]
-            return PredictResponseModel(base64_images=self.last_images)
+            return PredictResponseModel(base64_image=self._pil_to_base64(self.stream_diffusion(inp.prompt)))
 
     def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> bytes:
         """
@@ -152,52 +143,6 @@ class Api:
         base64_image = base64_image.split("base64,")[1]
         return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")
 
-    def _calc_levenstein_distance(self, a: str, b: str) -> int:
-        """
-        Calculate the Levenstein distance between two strings.
-
-        Parameters
-        ----------
-        a : str
-            The first string.
-
-        b : str
-            The second string.
-
-        Returns
-        -------
-        int
-            The Levenstein distance.
-        """
-        if a == b:
-            return 0
-        a_k = len(a)
-        b_k = len(b)
-        if a == "":
-            return b_k
-        if b == "":
-            return a_k
-        matrix = [[] for i in range(a_k + 1)]
-        for i in range(a_k + 1):
-            matrix[i] = [0 for j in range(b_k + 1)]
-        for i in range(a_k + 1):
-            matrix[i][0] = i
-        for j in range(b_k + 1):
-            matrix[0][j] = j
-        for i in range(1, a_k + 1):
-            ac = a[i - 1]
-            for j in range(1, b_k + 1):
-                bc = b[j - 1]
-                cost = 0 if (ac == bc) else 1
-                matrix[i][j] = min(
-                    [
-                        matrix[i - 1][j] + 1,
-                        matrix[i][j - 1] + 1,
-                        matrix[i - 1][j - 1] + cost,
-                    ]
-                )
-        return matrix[a_k][b_k]
-
 
 if __name__ == "__main__":
     from config import Config
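
The /predict handler now generates on every request and returns a single base64_image string instead of the cached base64_images list. A quick smoke test, assuming the API is reachable through the front end's /api/predict route on the default port 9090 (both the URL and the sample prompt are illustrative):

import base64

import requests

resp = requests.post(
    "http://localhost:9090/api/predict",  # assumed route; matches the fetch() in view/src/App.tsx
    json={"prompt": "a watercolor fox in a forest"},
)
resp.raise_for_status()

# The response model now carries exactly one image
with open("out.jpg", "wb") as f:
    f.write(base64.b64decode(resp.json()["base64_image"]))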
server/requirements.txt CHANGED
@@ -9,4 +9,5 @@ torchvision
 torchaudio
 triton
 # https://github.com/chengzeyi/stable-fast --index-url https://download.pytorch.org/whl/cu121
-https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
+# https://github.com/chengzeyi/stable-fast/releases/download/v0.0.14/stable_fast-0.0.14+torch210cu121-cp310-cp310-manylinux2014_x86_64.whl
+https://github.com/chengzeyi/stable-fast/releases/download/v0.0.15.post1/stable_fast-0.0.15.post1+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl
server/wrapper.py CHANGED
@@ -8,7 +8,6 @@ import torch
 from diffusers import AutoencoderTiny, StableDiffusionPipeline
 
 from streamdiffusion import StreamDiffusion
-from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast
 from streamdiffusion.image_utils import postprocess_image
 
 
@@ -33,6 +32,7 @@ class StreamDiffusionWrapper:
         self.device = device
         self.dtype = dtype
         self.prompt = ""
+        self.batch_size = len(t_index_list)
 
         self.stream = self._load_model(
             model_id=model_id,
@@ -44,13 +44,18 @@
         self.safety_checker = None
         if safety_checker:
             from transformers import CLIPFeatureExtractor
-            from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+            from diffusers.pipelines.stable_diffusion.safety_checker import (
+                StableDiffusionSafetyChecker,
+            )
+
             self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-                "CompVis/stable-diffusion-safety-checker").to(self.device)
+                "CompVis/stable-diffusion-safety-checker"
+            ).to(self.device)
             self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
-                "openai/clip-vit-base-patch32")
-            self.nsfw_fallback_img = PIL.Image.new(
-                "RGB", (512, 512), (0, 0, 0))
+                "openai/clip-vit-base-patch32"
+            )
+            self.nsfw_fallback_img = PIL.Image.new("RGB", (512, 512), (0, 0, 0))
+            self.stream.prepare("")
 
     def _load_model(
         self,
@@ -61,13 +66,13 @@
         warmup: int,
     ):
         if os.path.exists(model_id):
-            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(model_id).to(
-                device=self.device, dtype=self.dtype
-            )
+            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
+                model_id
+            ).to(device=self.device, dtype=self.dtype)
         else:
-            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(model_id).to(
-                device=self.device, dtype=self.dtype
-            )
+            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
+                model_id
+            ).to(device=self.device, dtype=self.dtype)
 
         stream = StreamDiffusion(
             pipe=pipe,
@@ -77,8 +82,32 @@
         )
         stream.load_lcm_lora(lcm_lora_id)
         stream.fuse_lora()
-        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(device=pipe.device, dtype=pipe.dtype)
-        stream = accelerate_with_stable_fast(stream)
+        stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
+            device=pipe.device, dtype=pipe.dtype
+        )
+
+        try:
+            from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
+
+            stream = accelerate_with_tensorrt(
+                stream,
+                "engines",
+                max_batch_size=self.batch_size,
+                engine_build_options={"build_static_batch": False},
+            )
+            print("TensorRT acceleration enabled.")
+        except Exception:
+            print("TensorRT acceleration has failed. Trying to use Stable Fast.")
+            try:
+                from streamdiffusion.acceleration.sfast import (
+                    accelerate_with_stable_fast,
+                )
+
+                stream = accelerate_with_stable_fast(stream)
+                print("StableFast acceleration enabled.")
+            except Exception:
+                print("StableFast acceleration has failed. Using normal mode.")
+                pass
 
         stream.prepare(
             "",
@@ -99,37 +128,27 @@
 
         return stream
 
-    def __call__(self, prompt: str) -> List[PIL.Image.Image]:
-        self.stream.prepare("")
-
-        images = []
-        for i in range(9 + 3):
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
-
-            start.record()
-
-            if self.prompt != prompt:
-                self.stream.update_prompt(prompt)
-                self.prompt = prompt
-
-            x_output = self.stream.txt2img()
-            if i >= 3:
-                image = postprocess_image(x_output, output_type="pil")[0]
-                if self.safety_checker:
-                    safety_checker_input = self.feature_extractor(
-                        image, return_tensors="pt").to(self.device)
-                    _, has_nsfw_concept = self.safety_checker(
-                        images=x_output, clip_input=safety_checker_input.pixel_values.to(
-                            self.dtype)
-                    )
-                    image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
-                images.append(image)
-            end.record()
-
-            torch.cuda.synchronize()
+    def __call__(self, prompt: str) -> PIL.Image.Image:
+        if self.prompt != prompt:
+            self.stream.update_prompt(prompt)
+            self.prompt = prompt
+        for i in range(self.batch_size):
+            x_output = self.stream.txt2img()
+
+        x_output = self.stream.txt2img()
+        image = postprocess_image(x_output, output_type="pil")[0]
+
+        if self.safety_checker:
+            safety_checker_input = self.feature_extractor(
+                image, return_tensors="pt"
+            ).to(self.device)
+            _, has_nsfw_concept = self.safety_checker(
+                images=x_output,
+                clip_input=safety_checker_input.pixel_values.to(self.dtype),
+            )
+            image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
 
-        return images
+        return image
 
 
 if __name__ == "__main__":
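
The hard stable-fast import is gone from the module level; acceleration is now best effort, preferring TensorRT and degrading gracefully. The same pattern in isolation, a sketch assuming only what the hunk shows (each accelerate_* helper returns the wrapped stream or raises):

def accelerate(stream, batch_size: int):
    """Try TensorRT first, then stable-fast, else return the stream unmodified."""
    try:
        from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt

        # Engines are cached under ./engines; dynamic batch so batch_size can vary
        return accelerate_with_tensorrt(
            stream,
            "engines",
            max_batch_size=batch_size,
            engine_build_options={"build_static_batch": False},
        )
    except Exception:
        print("TensorRT acceleration has failed. Trying to use Stable Fast.")
    try:
        from streamdiffusion.acceleration.sfast import accelerate_with_stable_fast

        return accelerate_with_stable_fast(stream)
    except Exception:
        print("StableFast acceleration has failed. Using normal mode.")
        return stream

Note the new __call__ also loops batch_size calls to txt2img() before taking the frame it returns, apparently to flush StreamDiffusion's batched denoising queue so the output reflects the current prompt.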
view/.DS_Store ADDED
Binary file (6.15 kB)
 
view/src/App.tsx CHANGED
@@ -1,28 +1,75 @@
-import React, { useCallback, useEffect, useState } from "react";
-import { TextField, Grid, Paper } from "@mui/material";
+import React, { useCallback, useState } from "react";
+import { TextField, Grid } from "@mui/material";
 
 function App() {
   const [inputPrompt, setInputPrompt] = useState("");
-  const [images, setImages] = useState(Array(9).fill("images/white.jpg"));
+  const [lastPrompt, setLastPrompt] = useState("");
+  const [images, setImages] = useState(Array(16).fill("images/white.jpg"));
 
-  const fetchImages = useCallback(async () => {
-    try {
-      const response = await fetch("/api/predict", {
-        method: "POST",
-        headers: { 'Content-Type': 'application/json' },
-        body: JSON.stringify({ prompt: inputPrompt })
-      });
-      const data = await response.json();
-      const imageUrls = data.base64_images.map((base64: string) => `data:image/jpeg;base64,${base64}`);
-      setImages(imageUrls);
-    } catch (error) {
-      console.error("Error fetching images:", error);
+  const calculateEditDistance = (a: string, b: string) => {
+    if (a.length === 0) return b.length;
+    if (b.length === 0) return a.length;
+
+    const matrix = [];
+
+    for (let i = 0; i <= b.length; i++) {
+      matrix[i] = [i];
+    }
+    for (let i = 0; i <= a.length; i++) {
+      matrix[0][i] = i;
     }
-  }, [inputPrompt]);
+
+    for (let i = 1; i <= b.length; i++) {
+      for (let j = 1; j <= a.length; j++) {
+        if (b.charAt(i - 1) === a.charAt(j - 1)) {
+          matrix[i][j] = matrix[i - 1][j - 1];
+        } else {
+          matrix[i][j] = Math.min(
+            matrix[i - 1][j - 1] + 1,
+            Math.min(matrix[i][j - 1] + 1, matrix[i - 1][j] + 1)
+          );
+        }
+      }
+    }
+
+    return matrix[b.length][a.length];
+  };
+
+  const fetchImage = useCallback(
+    async (index: number) => {
+      try {
+        const response = await fetch("/api/predict", {
+          method: "POST",
+          headers: { "Content-Type": "application/json" },
+          body: JSON.stringify({ prompt: inputPrompt }),
+        });
+        const data = await response.json();
+        const imageUrl = `data:image/jpeg;base64,${data.base64_image}`;
+
+        setImages((prevImages) => {
+          const newImages = [...prevImages];
+          newImages[index] = imageUrl;
+          return newImages;
+        });
+      } catch (error) {
+        console.error("Error fetching image:", error);
+      }
+    },
+    [inputPrompt]
+  );
 
   const handlePromptChange = (event: React.ChangeEvent<HTMLInputElement>) => {
     setInputPrompt(event.target.value);
-    fetchImages();
+    const newPrompt = event.target.value;
+    const editDistance = calculateEditDistance(lastPrompt, newPrompt);
+
+    if (editDistance >= 2) {
+      setInputPrompt(newPrompt);
+      setLastPrompt(newPrompt);
+      for (let i = 0; i < 16; i++) {
+        fetchImage(i);
+      }
+    }
   };
 
   return (
@@ -48,20 +95,38 @@ function App() {
           flexDirection: "column",
         }}
       >
-        <Grid container spacing={2}>
-          {images.map((image, index) => (
-            <Grid item xs={4} key={index}>
-              <Paper style={{ padding: "10px", textAlign: "center" }}>
-                <img src={image} alt={`Generated ${index}`} style={{ maxWidth: "100%", maxHeight: "200px", borderRadius: "10px" }} />
-              </Paper>
-            </Grid>
-          ))}
-        </Grid>
+        <Grid
+          container
+          spacing={1}
+          style={{ maxWidth: "50%", maxHeight: "70%" }}
+        >
+          {images.map((image, index) => (
+            <Grid item xs={3} key={index}>
+              <img
+                src={image}
+                alt={`Generated ${index}`}
+                style={{
+                  maxWidth: "100%",
+                  maxHeight: "150px",
+                  borderRadius: "10px",
+                }}
+              />
+            </Grid>
+          ))}
+        </Grid>
         <TextField
          variant="outlined"
          value={inputPrompt}
          onChange={handlePromptChange}
-          style={{ marginBottom: "20px", marginTop: "20px", width: "640px", color: "#ffffff", borderColor: "#ffffff", borderRadius: "10px", backgroundColor: "#ffffff" }}
+          style={{
+            marginBottom: "20px",
+            marginTop: "20px",
+            width: "640px",
+            color: "#ffffff",
+            borderColor: "#ffffff",
+            borderRadius: "10px",
+            backgroundColor: "#ffffff",
+          }}
          placeholder="Enter a prompt"
        />
      </div>
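
The gating that main.py dropped reappears here: a prompt change only refreshes the 16-image grid once it sits at least 2 edits away from the last prompt sent. For comparison with the removed _calc_levenstein_distance, the same gate sketched in Python (the sample prompts are illustrative):

def edit_distance(a: str, b: str) -> int:
    # Wagner-Fischer with two rows; same result as calculateEditDistance in App.tsx
    if not a:
        return len(b)
    if not b:
        return len(a)
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + cost))
        prev = cur
    return prev[len(b)]

last_prompt = "a cat"
new_prompt = "a cat in a hat"
if edit_distance(last_prompt, new_prompt) >= 2:
    last_prompt = new_prompt  # then refetch the grid, as handlePromptChange does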