radames commited on
Commit
d6fedfa
1 Parent(s): ff9325e

enable input_mode

Browse files
app-controlnet.py DELETED
@@ -1,322 +0,0 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import traceback
5
- from pydantic import BaseModel
6
-
7
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import (
10
- StreamingResponse,
11
- JSONResponse,
12
- HTMLResponse,
13
- FileResponse,
14
- )
15
-
16
- from diffusers import AutoencoderTiny, ControlNetModel
17
- from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
18
- from compel import Compel
19
- import torch
20
-
21
- from canny_gpu import SobelOperator
22
-
23
- # from controlnet_aux import OpenposeDetector
24
- # import cv2
25
-
26
- try:
27
- import intel_extension_for_pytorch as ipex
28
- except:
29
- pass
30
- from PIL import Image
31
- import numpy as np
32
- import gradio as gr
33
- import io
34
- import uuid
35
- import os
36
- import time
37
- import psutil
38
-
39
-
40
- MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
41
- TIMEOUT = float(os.environ.get("TIMEOUT", 0))
42
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
43
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
44
- WIDTH = 512
45
- HEIGHT = 512
46
- # disable tiny autoencoder for better quality speed tradeoff
47
- USE_TINY_AUTOENCODER = True
48
-
49
- # check if MPS is available OSX only M1/M2/M3 chips
50
- mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
51
- xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
52
- device = torch.device(
53
- "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
54
- )
55
-
56
- # change to torch.float16 to save GPU memory
57
- torch_dtype = torch.float16
58
-
59
- print(f"TIMEOUT: {TIMEOUT}")
60
- print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
61
- print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
62
- print(f"device: {device}")
63
-
64
- if mps_available:
65
- device = torch.device("mps")
66
- device = "cpu"
67
- torch_dtype = torch.float32
68
-
69
- controlnet_canny = ControlNetModel.from_pretrained(
70
- "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch_dtype
71
- ).to(device)
72
-
73
- canny_torch = SobelOperator(device=device)
74
- # controlnet_pose = ControlNetModel.from_pretrained(
75
- # "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype
76
- # ).to(device)
77
- # controlnet_depth = ControlNetModel.from_pretrained(
78
- # "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch_dtype
79
- # ).to(device)
80
-
81
-
82
- # pose_processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
83
-
84
- if SAFETY_CHECKER == "True":
85
- pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
86
- "SimianLuo/LCM_Dreamshaper_v7",
87
- controlnet=controlnet_canny,
88
- scheduler=None,
89
- )
90
- else:
91
- pipe = LatentConsistencyModelPipeline_controlnet.from_pretrained(
92
- "SimianLuo/LCM_Dreamshaper_v7",
93
- safety_checker=None,
94
- controlnet=controlnet_canny,
95
- scheduler=None,
96
- )
97
-
98
- if USE_TINY_AUTOENCODER:
99
- pipe.vae = AutoencoderTiny.from_pretrained(
100
- "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
101
- )
102
- pipe.set_progress_bar_config(disable=True)
103
- pipe.to(device=device, dtype=torch_dtype).to(device)
104
- pipe.unet.to(memory_format=torch.channels_last)
105
-
106
- if psutil.virtual_memory().total < 64 * 1024**3:
107
- pipe.enable_attention_slicing()
108
-
109
- compel_proc = Compel(
110
- tokenizer=pipe.tokenizer,
111
- text_encoder=pipe.text_encoder,
112
- truncate_long_prompts=False,
113
- )
114
- if TORCH_COMPILE:
115
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
116
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
117
-
118
- pipe(
119
- prompt="warmup",
120
- image=[Image.new("RGB", (768, 768))],
121
- control_image=[Image.new("RGB", (768, 768))],
122
- )
123
-
124
-
125
- user_queue_map = {}
126
-
127
-
128
- class InputParams(BaseModel):
129
- seed: int = 2159232
130
- prompt: str
131
- guidance_scale: float = 8.0
132
- strength: float = 0.5
133
- steps: int = 4
134
- lcm_steps: int = 50
135
- width: int = WIDTH
136
- height: int = HEIGHT
137
- controlnet_scale: float = 0.8
138
- controlnet_start: float = 0.0
139
- controlnet_end: float = 1.0
140
- canny_low_threshold: float = 0.31
141
- canny_high_threshold: float = 0.78
142
- debug_canny: bool = False
143
-
144
-
145
- def predict(
146
- input_image: Image.Image, params: InputParams, prompt_embeds: torch.Tensor = None
147
- ):
148
- generator = torch.manual_seed(params.seed)
149
-
150
- control_image = canny_torch(
151
- input_image, params.canny_low_threshold, params.canny_high_threshold
152
- )
153
- results = pipe(
154
- control_image=control_image,
155
- prompt_embeds=prompt_embeds,
156
- generator=generator,
157
- image=input_image,
158
- strength=params.strength,
159
- num_inference_steps=params.steps,
160
- guidance_scale=params.guidance_scale,
161
- width=params.width,
162
- height=params.height,
163
- lcm_origin_steps=params.lcm_steps,
164
- output_type="pil",
165
- controlnet_conditioning_scale=params.controlnet_scale,
166
- control_guidance_start=params.controlnet_start,
167
- control_guidance_end=params.controlnet_end,
168
- )
169
- nsfw_content_detected = (
170
- results.nsfw_content_detected[0]
171
- if "nsfw_content_detected" in results
172
- else False
173
- )
174
- if nsfw_content_detected:
175
- return None
176
- result_image = results.images[0]
177
- if params.debug_canny:
178
- # paste control_image on top of result_image
179
- w0, h0 = (200, 200)
180
- control_image = control_image.resize((w0, h0))
181
- w1, h1 = result_image.size
182
- result_image.paste(control_image, (w1 - w0, h1 - h0))
183
-
184
- return result_image
185
-
186
-
187
- app = FastAPI()
188
- app.add_middleware(
189
- CORSMiddleware,
190
- allow_origins=["*"],
191
- allow_credentials=True,
192
- allow_methods=["*"],
193
- allow_headers=["*"],
194
- )
195
-
196
-
197
- @app.websocket("/ws")
198
- async def websocket_endpoint(websocket: WebSocket):
199
- await websocket.accept()
200
- if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
201
- print("Server is full")
202
- await websocket.send_json({"status": "error", "message": "Server is full"})
203
- await websocket.close()
204
- return
205
-
206
- try:
207
- uid = str(uuid.uuid4())
208
- print(f"New user connected: {uid}")
209
- await websocket.send_json(
210
- {"status": "success", "message": "Connected", "userId": uid}
211
- )
212
- user_queue_map[uid] = {"queue": asyncio.Queue()}
213
- await websocket.send_json(
214
- {"status": "start", "message": "Start Streaming", "userId": uid}
215
- )
216
- await handle_websocket_data(websocket, uid)
217
- except WebSocketDisconnect as e:
218
- logging.error(f"WebSocket Error: {e}, {uid}")
219
- traceback.print_exc()
220
- finally:
221
- print(f"User disconnected: {uid}")
222
- queue_value = user_queue_map.pop(uid, None)
223
- queue = queue_value.get("queue", None)
224
- if queue:
225
- while not queue.empty():
226
- try:
227
- queue.get_nowait()
228
- except asyncio.QueueEmpty:
229
- continue
230
-
231
-
232
- @app.get("/queue_size")
233
- async def get_queue_size():
234
- queue_size = len(user_queue_map)
235
- return JSONResponse({"queue_size": queue_size})
236
-
237
-
238
- @app.get("/stream/{user_id}")
239
- async def stream(user_id: uuid.UUID):
240
- uid = str(user_id)
241
- try:
242
- user_queue = user_queue_map[uid]
243
- queue = user_queue["queue"]
244
-
245
- async def generate():
246
- last_prompt: str = None
247
- prompt_embeds: torch.Tensor = None
248
- while True:
249
- data = await queue.get()
250
- input_image = data["image"]
251
- params = data["params"]
252
- if input_image is None:
253
- continue
254
- # avoid recalculate prompt embeds
255
- if last_prompt != params.prompt:
256
- print("new prompt")
257
- prompt_embeds = compel_proc(params.prompt)
258
- last_prompt = params.prompt
259
-
260
- image = predict(
261
- input_image,
262
- params,
263
- prompt_embeds,
264
- )
265
- if image is None:
266
- continue
267
- frame_data = io.BytesIO()
268
- image.save(frame_data, format="JPEG")
269
- frame_data = frame_data.getvalue()
270
- if frame_data is not None and len(frame_data) > 0:
271
- yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
272
-
273
- await asyncio.sleep(1.0 / 120.0)
274
-
275
- return StreamingResponse(
276
- generate(), media_type="multipart/x-mixed-replace;boundary=frame"
277
- )
278
- except Exception as e:
279
- logging.error(f"Streaming Error: {e}, {user_queue_map}")
280
- traceback.print_exc()
281
- return HTTPException(status_code=404, detail="User not found")
282
-
283
-
284
- async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
285
- uid = str(user_id)
286
- user_queue = user_queue_map[uid]
287
- queue = user_queue["queue"]
288
- if not queue:
289
- return HTTPException(status_code=404, detail="User not found")
290
- last_time = time.time()
291
- try:
292
- while True:
293
- data = await websocket.receive_bytes()
294
- params = await websocket.receive_json()
295
- params = InputParams(**params)
296
- pil_image = Image.open(io.BytesIO(data))
297
-
298
- while not queue.empty():
299
- try:
300
- queue.get_nowait()
301
- except asyncio.QueueEmpty:
302
- continue
303
- await queue.put({"image": pil_image, "params": params})
304
- if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
305
- await websocket.send_json(
306
- {
307
- "status": "timeout",
308
- "message": "Your session has ended",
309
- "userId": uid,
310
- }
311
- )
312
- await websocket.close()
313
- return
314
-
315
- except Exception as e:
316
- logging.error(f"Error: {e}")
317
- traceback.print_exc()
318
-
319
-
320
- @app.get("/", response_class=HTMLResponse)
321
- async def root():
322
- return FileResponse("./static/controlnet.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app-txt2img.py DELETED
@@ -1,255 +0,0 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import traceback
5
- from pydantic import BaseModel
6
-
7
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import (
10
- StreamingResponse,
11
- JSONResponse,
12
- HTMLResponse,
13
- FileResponse,
14
- )
15
-
16
- from diffusers import DiffusionPipeline, AutoencoderTiny
17
- from compel import Compel
18
- import torch
19
-
20
- try:
21
- import intel_extension_for_pytorch as ipex
22
- except:
23
- pass
24
- from PIL import Image
25
- import numpy as np
26
- import gradio as gr
27
- import io
28
- import uuid
29
- import os
30
- import time
31
- import psutil
32
-
33
-
34
- MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
35
- TIMEOUT = float(os.environ.get("TIMEOUT", 0))
36
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
37
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
38
-
39
- WIDTH = 768
40
- HEIGHT = 768
41
- # disable tiny autoencoder for better quality speed tradeoff
42
- USE_TINY_AUTOENCODER = False
43
-
44
- # check if MPS is available OSX only M1/M2/M3 chips
45
- mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
46
- xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
47
- device = torch.device(
48
- "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
49
- )
50
- torch_device = device
51
- # change to torch.float16 to save GPU memory
52
- torch_dtype = torch.float32
53
-
54
- print(f"TIMEOUT: {TIMEOUT}")
55
- print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
56
- print(f"MAX_QUEUE_SIZE: {MAX_QUEUE_SIZE}")
57
- print(f"device: {device}")
58
-
59
- if mps_available:
60
- device = torch.device("mps")
61
- torch_device = "cpu"
62
- torch_dtype = torch.float32
63
-
64
- if SAFETY_CHECKER == "True":
65
- pipe = DiffusionPipeline.from_pretrained(
66
- "SimianLuo/LCM_Dreamshaper_v7",
67
- )
68
- else:
69
- pipe = DiffusionPipeline.from_pretrained(
70
- "SimianLuo/LCM_Dreamshaper_v7",
71
- safety_checker=None,
72
- )
73
- if USE_TINY_AUTOENCODER:
74
- pipe.vae = AutoencoderTiny.from_pretrained(
75
- "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
76
- )
77
- pipe.set_progress_bar_config(disable=True)
78
- pipe.to(device=torch_device, dtype=torch_dtype).to(device)
79
- pipe.unet.to(memory_format=torch.channels_last)
80
-
81
- # check if computer has less than 64GB of RAM using sys or os
82
- if psutil.virtual_memory().total < 64 * 1024**3:
83
- pipe.enable_attention_slicing()
84
-
85
- if TORCH_COMPILE:
86
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
87
- pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
88
-
89
- pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
90
-
91
- compel_proc = Compel(
92
- tokenizer=pipe.tokenizer,
93
- text_encoder=pipe.text_encoder,
94
- truncate_long_prompts=False,
95
- )
96
- user_queue_map = {}
97
-
98
-
99
- class InputParams(BaseModel):
100
- seed: int = 2159232
101
- prompt: str
102
- guidance_scale: float = 8.0
103
- strength: float = 0.5
104
- steps: int = 4
105
- lcm_steps: int = 50
106
- width: int = WIDTH
107
- height: int = HEIGHT
108
-
109
-
110
- def predict(params: InputParams):
111
- generator = torch.manual_seed(params.seed)
112
- prompt_embeds = compel_proc(params.prompt)
113
- # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
114
- results = pipe(
115
- prompt_embeds=prompt_embeds,
116
- generator=generator,
117
- num_inference_steps=params.steps,
118
- guidance_scale=params.guidance_scale,
119
- width=params.width,
120
- height=params.height,
121
- original_inference_steps=params.lcm_steps,
122
- output_type="pil",
123
- )
124
- nsfw_content_detected = (
125
- results.nsfw_content_detected[0]
126
- if "nsfw_content_detected" in results
127
- else False
128
- )
129
- if nsfw_content_detected:
130
- return None
131
- return results.images[0]
132
-
133
-
134
- app = FastAPI()
135
- app.add_middleware(
136
- CORSMiddleware,
137
- allow_origins=["*"],
138
- allow_credentials=True,
139
- allow_methods=["*"],
140
- allow_headers=["*"],
141
- )
142
-
143
-
144
- @app.websocket("/ws")
145
- async def websocket_endpoint(websocket: WebSocket):
146
- await websocket.accept()
147
- if MAX_QUEUE_SIZE > 0 and len(user_queue_map) >= MAX_QUEUE_SIZE:
148
- print("Server is full")
149
- await websocket.send_json({"status": "error", "message": "Server is full"})
150
- await websocket.close()
151
- return
152
-
153
- try:
154
- uid = str(uuid.uuid4())
155
- print(f"New user connected: {uid}")
156
- await websocket.send_json(
157
- {"status": "success", "message": "Connected", "userId": uid}
158
- )
159
- user_queue_map[uid] = {
160
- "queue": asyncio.Queue(),
161
- }
162
- await websocket.send_json(
163
- {"status": "start", "message": "Start Streaming", "userId": uid}
164
- )
165
- await handle_websocket_data(websocket, uid)
166
- except WebSocketDisconnect as e:
167
- logging.error(f"WebSocket Error: {e}, {uid}")
168
- traceback.print_exc()
169
- finally:
170
- print(f"User disconnected: {uid}")
171
- queue_value = user_queue_map.pop(uid, None)
172
- queue = queue_value.get("queue", None)
173
- if queue:
174
- while not queue.empty():
175
- try:
176
- queue.get_nowait()
177
- except asyncio.QueueEmpty:
178
- continue
179
-
180
-
181
- @app.get("/queue_size")
182
- async def get_queue_size():
183
- queue_size = len(user_queue_map)
184
- return JSONResponse({"queue_size": queue_size})
185
-
186
-
187
- @app.get("/stream/{user_id}")
188
- async def stream(user_id: uuid.UUID):
189
- uid = str(user_id)
190
- try:
191
- user_queue = user_queue_map[uid]
192
- queue = user_queue["queue"]
193
-
194
- async def generate():
195
- while True:
196
- params = await queue.get()
197
- if params is None:
198
- continue
199
-
200
- image = predict(params)
201
- if image is None:
202
- continue
203
- frame_data = io.BytesIO()
204
- image.save(frame_data, format="JPEG")
205
- frame_data = frame_data.getvalue()
206
- if frame_data is not None and len(frame_data) > 0:
207
- yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
208
-
209
- await asyncio.sleep(1.0 / 120.0)
210
-
211
- return StreamingResponse(
212
- generate(), media_type="multipart/x-mixed-replace;boundary=frame"
213
- )
214
- except Exception as e:
215
- logging.error(f"Streaming Error: {e}, {user_queue_map}")
216
- traceback.print_exc()
217
- return HTTPException(status_code=404, detail="User not found")
218
-
219
-
220
- async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
221
- uid = str(user_id)
222
- user_queue = user_queue_map[uid]
223
- queue = user_queue["queue"]
224
- if not queue:
225
- return HTTPException(status_code=404, detail="User not found")
226
- last_time = time.time()
227
- try:
228
- while True:
229
- params = await websocket.receive_json()
230
- params = InputParams(**params)
231
- while not queue.empty():
232
- try:
233
- queue.get_nowait()
234
- except asyncio.QueueEmpty:
235
- continue
236
- await queue.put(params)
237
- if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
238
- await websocket.send_json(
239
- {
240
- "status": "timeout",
241
- "message": "Your session has ended",
242
- "userId": uid,
243
- }
244
- )
245
- await websocket.close()
246
- return
247
-
248
- except Exception as e:
249
- logging.error(f"Error: {e}")
250
- traceback.print_exc()
251
-
252
-
253
- @app.get("/", response_class=HTMLResponse)
254
- async def root():
255
- return FileResponse("./static/txt2img.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_init.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
 
5
 
6
  import logging
7
  import traceback
@@ -11,8 +12,8 @@ import uuid
11
  from asyncio import Event, sleep
12
  import time
13
  from PIL import Image
14
- import io
15
  from types import SimpleNamespace
 
16
 
17
 
18
  def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
@@ -23,7 +24,6 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
- print("Init app", app)
27
 
28
  @app.websocket("/ws")
29
  async def websocket_endpoint(websocket: WebSocket):
@@ -41,7 +41,6 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
41
  {"status": "success", "message": "Connected", "userId": uid}
42
  )
43
  user_data_events[uid] = UserDataEvent()
44
- print(f"User data events: {user_data_events}")
45
  await websocket.send_json(
46
  {"status": "start", "message": "Start Streaming", "userId": uid}
47
  )
@@ -59,31 +58,27 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
59
  return JSONResponse({"queue_size": queue_size})
60
 
61
  @app.get("/stream/{user_id}")
62
- async def stream(user_id: uuid.UUID):
63
  uid = str(user_id)
64
  try:
65
 
66
  async def generate():
67
- last_prompt: str = None
68
  while True:
69
  data = await user_data_events[uid].wait_for_data()
70
  params = data["params"]
71
- # input_image = data["image"]
72
- # if input_image is None:
73
- # continue
74
  image = pipeline.predict(params)
75
  if image is None:
76
  continue
77
- frame_data = io.BytesIO()
78
- image.save(frame_data, format="JPEG")
79
- frame_data = frame_data.getvalue()
80
- if frame_data is not None and len(frame_data) > 0:
81
- yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
82
-
83
- await sleep(1.0 / 120.0)
84
 
85
  return StreamingResponse(
86
- generate(), media_type="multipart/x-mixed-replace;boundary=frame"
 
 
87
  )
88
  except Exception as e:
89
  logging.error(f"Streaming Error: {e}, {user_data_events}")
@@ -99,8 +94,9 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
99
  while True:
100
  params = await websocket.receive_json()
101
  params = pipeline.InputParams(**params)
 
102
  params = SimpleNamespace(**params.dict())
103
- if hasattr(params, "image"):
104
  image_data = await websocket.receive_bytes()
105
  pil_image = Image.open(io.BytesIO(image_data))
106
  params.image = pil_image
@@ -125,6 +121,12 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
125
  async def settings():
126
  info = pipeline.Info.schema()
127
  input_params = pipeline.InputParams.schema()
128
- return JSONResponse({"info": info, "input_params": input_params})
 
 
 
 
 
 
129
 
130
  app.mount("/", StaticFiles(directory="public", html=True), name="public")
 
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
5
+ from fastapi import Request
6
 
7
  import logging
8
  import traceback
 
12
  from asyncio import Event, sleep
13
  import time
14
  from PIL import Image
 
15
  from types import SimpleNamespace
16
+ from util import pil_to_frame, is_firefox
17
 
18
 
19
  def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
 
24
  allow_methods=["*"],
25
  allow_headers=["*"],
26
  )
 
27
 
28
  @app.websocket("/ws")
29
  async def websocket_endpoint(websocket: WebSocket):
 
41
  {"status": "success", "message": "Connected", "userId": uid}
42
  )
43
  user_data_events[uid] = UserDataEvent()
 
44
  await websocket.send_json(
45
  {"status": "start", "message": "Start Streaming", "userId": uid}
46
  )
 
58
  return JSONResponse({"queue_size": queue_size})
59
 
60
  @app.get("/stream/{user_id}")
61
+ async def stream(user_id: uuid.UUID, request: Request):
62
  uid = str(user_id)
63
  try:
64
 
65
  async def generate():
 
66
  while True:
67
  data = await user_data_events[uid].wait_for_data()
68
  params = data["params"]
 
 
 
69
  image = pipeline.predict(params)
70
  if image is None:
71
  continue
72
+ frame = pil_to_frame(image)
73
+ yield frame
74
+ # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
75
+ if not is_firefox(request.headers["user-agent"]):
76
+ yield frame
 
 
77
 
78
  return StreamingResponse(
79
+ generate(),
80
+ media_type="multipart/x-mixed-replace;boundary=frame",
81
+ headers={"Cache-Control": "no-cache"},
82
  )
83
  except Exception as e:
84
  logging.error(f"Streaming Error: {e}, {user_data_events}")
 
94
  while True:
95
  params = await websocket.receive_json()
96
  params = pipeline.InputParams(**params)
97
+ info = pipeline.Info()
98
  params = SimpleNamespace(**params.dict())
99
+ if info.input_mode == "image":
100
  image_data = await websocket.receive_bytes()
101
  pil_image = Image.open(io.BytesIO(image_data))
102
  params.image = pil_image
 
121
  async def settings():
122
  info = pipeline.Info.schema()
123
  input_params = pipeline.InputParams.schema()
124
+ return JSONResponse(
125
+ {
126
+ "info": info,
127
+ "input_params": input_params,
128
+ "max_queue_size": args.max_queue_size,
129
+ }
130
+ )
131
 
132
  app.mount("/", StaticFiles(directory="public", html=True), name="public")
frontend/src/lib/components/ImagePlayer.svelte CHANGED
@@ -3,7 +3,10 @@
3
  import { onFrameChangeStore } from '$lib/mediaStream';
4
  import { PUBLIC_BASE_URL } from '$env/static/public';
5
 
6
- $: streamId = $lcmLiveState.streamId;
 
 
 
7
  </script>
8
 
9
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
@@ -14,19 +17,6 @@
14
  <div class="aspect-square w-full rounded-lg" />
15
  {/if}
16
  <div class="absolute left-0 top-0 aspect-square w-1/4">
17
- <div class="relative z-10 aspect-square w-full object-cover">
18
- <slot />
19
- </div>
20
- <svg
21
- xmlns="http://www.w3.org/2000/svg"
22
- viewBox="0 0 448 448"
23
- width="100"
24
- class="absolute top-0 z-0 w-full p-4 opacity-20"
25
- >
26
- <path
27
- fill="currentColor"
28
- d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z"
29
- />
30
- </svg>
31
  </div>
32
  </div>
 
3
  import { onFrameChangeStore } from '$lib/mediaStream';
4
  import { PUBLIC_BASE_URL } from '$env/static/public';
5
 
6
+ $: streamId = $lcmLiveState?.streamId;
7
+ $: {
8
+ console.log('streamId', streamId);
9
+ }
10
  </script>
11
 
12
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
 
17
  <div class="aspect-square w-full rounded-lg" />
18
  {/if}
19
  <div class="absolute left-0 top-0 aspect-square w-1/4">
20
+ <slot />
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  </div>
22
  </div>
frontend/src/lib/components/InputRange.svelte CHANGED
@@ -8,7 +8,7 @@
8
  });
9
  </script>
10
 
11
- <div class="grid grid-cols-4 items-center gap-3">
12
  <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
13
  <input
14
  class="col-span-2 h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-300 dark:bg-gray-500"
 
8
  });
9
  </script>
10
 
11
+ <div class="grid max-w-md grid-cols-4 items-center gap-3">
12
  <label class="text-sm font-medium" for={params.id}>{params?.title}</label>
13
  <input
14
  class="col-span-2 h-2 w-full cursor-pointer appearance-none rounded-lg bg-gray-300 dark:bg-gray-500"
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -6,9 +6,9 @@
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
8
  import Checkbox from './Checkbox.svelte';
 
9
 
10
  export let pipelineParams: FieldProps[];
11
- export let pipelineValues = {} as any;
12
 
13
  $: advanceOptions = pipelineParams?.filter((e) => e?.hide == true);
14
  $: featuredOptions = pipelineParams?.filter((e) => e?.hide !== true);
@@ -18,13 +18,13 @@
18
  {#if featuredOptions}
19
  {#each featuredOptions as params}
20
  {#if params.field === FieldType.range}
21
- <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
22
  {:else if params.field === FieldType.seed}
23
- <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
24
  {:else if params.field === FieldType.textarea}
25
- <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
26
  {:else if params.field === FieldType.checkbox}
27
- <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
28
  {/if}
29
  {/each}
30
  {/if}
@@ -32,17 +32,19 @@
32
 
33
  <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
- <div class="grid grid-cols-1 items-center gap-3 sm:grid-cols-2">
 
 
36
  {#if advanceOptions}
37
  {#each advanceOptions as params}
38
  {#if params.field === FieldType.range}
39
- <InputRange {params} bind:value={pipelineValues[params.id]}></InputRange>
40
  {:else if params.field === FieldType.seed}
41
- <SeedInput bind:value={pipelineValues[params.id]}></SeedInput>
42
  {:else if params.field === FieldType.textarea}
43
- <TextArea {params} bind:value={pipelineValues[params.id]}></TextArea>
44
  {:else if params.field === FieldType.checkbox}
45
- <Checkbox {params} bind:value={pipelineValues[params.id]}></Checkbox>
46
  {/if}
47
  {/each}
48
  {/if}
 
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
8
  import Checkbox from './Checkbox.svelte';
9
+ import { pipelineValues } from '$lib/store';
10
 
11
  export let pipelineParams: FieldProps[];
 
12
 
13
  $: advanceOptions = pipelineParams?.filter((e) => e?.hide == true);
14
  $: featuredOptions = pipelineParams?.filter((e) => e?.hide !== true);
 
18
  {#if featuredOptions}
19
  {#each featuredOptions as params}
20
  {#if params.field === FieldType.range}
21
+ <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
22
  {:else if params.field === FieldType.seed}
23
+ <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
24
  {:else if params.field === FieldType.textarea}
25
+ <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
  {:else if params.field === FieldType.checkbox}
27
+ <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
28
  {/if}
29
  {/each}
30
  {/if}
 
32
 
33
  <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
+ <div
36
+ class="grid grid-cols-1 items-center gap-3 {pipelineValues.length > 5 ? 'sm:grid-cols-2' : ''}"
37
+ >
38
  {#if advanceOptions}
39
  {#each advanceOptions as params}
40
  {#if params.field === FieldType.range}
41
+ <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
42
  {:else if params.field === FieldType.seed}
43
+ <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
44
  {:else if params.field === FieldType.textarea}
45
+ <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
  {:else if params.field === FieldType.checkbox}
47
+ <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
48
  {/if}
49
  {/each}
50
  {/if}
frontend/src/lib/components/VideoInput.svelte CHANGED
@@ -62,12 +62,25 @@
62
  }
63
  </script>
64
 
65
- <video
66
- class="aspect-square w-full object-cover"
67
- bind:this={videoEl}
68
- playsinline
69
- autoplay
70
- muted
71
- loop
72
- use:srcObject={mediaStream}
73
- ></video>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  }
63
  </script>
64
 
65
+ <div class="relative z-10 aspect-square w-full object-cover">
66
+ <video
67
+ class="aspect-square w-full object-cover"
68
+ bind:this={videoEl}
69
+ playsinline
70
+ autoplay
71
+ muted
72
+ loop
73
+ use:srcObject={mediaStream}
74
+ ></video>
75
+ </div>
76
+ <svg
77
+ xmlns="http://www.w3.org/2000/svg"
78
+ viewBox="0 0 448 448"
79
+ width="100"
80
+ class="absolute top-0 z-0 w-full p-4 opacity-20"
81
+ >
82
+ <path
83
+ fill="currentColor"
84
+ d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z"
85
+ />
86
+ </svg>
frontend/src/lib/lcmLive.ts CHANGED
@@ -1,5 +1,5 @@
1
  import { writable } from 'svelte/store';
2
- import { PUBLIC_BASE_URL, PUBLIC_WSS_URL } from '$env/static/public';
3
 
4
  export const isStreaming = writable(false);
5
  export const isLCMRunning = writable(false);
@@ -26,55 +26,75 @@ export const lcmLiveState = writable(initialState);
26
  let websocket: WebSocket | null = null;
27
  export const lcmLiveActions = {
28
  async start() {
 
29
 
30
- isLCMRunning.set(true);
31
- try {
32
- const websocketURL = PUBLIC_WSS_URL ? PUBLIC_WSS_URL : `${window.location.protocol === "https:" ? "wss" : "ws"
33
- }:${window.location.host}/ws`;
34
 
35
- websocket = new WebSocket(websocketURL);
36
- websocket.onopen = () => {
37
- console.log("Connected to websocket");
38
- };
39
- websocket.onclose = () => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  lcmLiveState.update((state) => ({
41
  ...state,
42
- status: LCMLiveStatus.DISCONNECTED
 
43
  }));
44
- console.log("Disconnected from websocket");
45
- isLCMRunning.set(false);
46
- };
47
- websocket.onerror = (err) => {
48
- console.error(err);
49
- };
50
- websocket.onmessage = (event) => {
51
- const data = JSON.parse(event.data);
52
- console.log("WS: ", data);
53
- switch (data.status) {
54
- case "success":
55
- break;
56
- case "start":
57
- const streamId = data.userId;
58
- lcmLiveState.update((state) => ({
59
- ...state,
60
- status: LCMLiveStatus.CONNECTED,
61
- streamId: streamId,
62
- }));
63
- break;
64
- case "timeout":
65
- console.log("timeout");
66
- case "error":
67
- console.log(data.message);
68
- isLCMRunning.set(false);
69
- }
70
- };
71
- lcmLiveState.update((state) => ({
72
- ...state,
73
- }));
74
- } catch (err) {
75
- console.error(err);
76
- isLCMRunning.set(false);
77
- }
78
  },
79
  send(data: Blob | { [key: string]: any }) {
80
  if (websocket && websocket.readyState === WebSocket.OPEN) {
 
1
  import { writable } from 'svelte/store';
2
+ import { PUBLIC_WSS_URL } from '$env/static/public';
3
 
4
  export const isStreaming = writable(false);
5
  export const isLCMRunning = writable(false);
 
26
  let websocket: WebSocket | null = null;
27
  export const lcmLiveActions = {
28
  async start() {
29
+ return new Promise((resolve, reject) => {
30
 
31
+ try {
32
+ const websocketURL = PUBLIC_WSS_URL ? PUBLIC_WSS_URL : `${window.location.protocol === "https:" ? "wss" : "ws"
33
+ }:${window.location.host}/ws`;
 
34
 
35
+ websocket = new WebSocket(websocketURL);
36
+ websocket.onopen = () => {
37
+ console.log("Connected to websocket");
38
+ };
39
+ websocket.onclose = () => {
40
+ lcmLiveState.update((state) => ({
41
+ ...state,
42
+ status: LCMLiveStatus.DISCONNECTED
43
+ }));
44
+ console.log("Disconnected from websocket");
45
+ isLCMRunning.set(false);
46
+ };
47
+ websocket.onerror = (err) => {
48
+ console.error(err);
49
+ };
50
+ websocket.onmessage = (event) => {
51
+ const data = JSON.parse(event.data);
52
+ console.log("WS: ", data);
53
+ switch (data.status) {
54
+ case "success":
55
+ break;
56
+ case "start":
57
+ const streamId = data.userId;
58
+ lcmLiveState.update((state) => ({
59
+ ...state,
60
+ status: LCMLiveStatus.CONNECTED,
61
+ streamId: streamId,
62
+ }));
63
+ isLCMRunning.set(true);
64
+ resolve(streamId);
65
+ break;
66
+ case "timeout":
67
+ console.log("timeout");
68
+ isLCMRunning.set(false);
69
+ lcmLiveState.update((state) => ({
70
+ ...state,
71
+ status: LCMLiveStatus.DISCONNECTED,
72
+ streamId: null,
73
+ }));
74
+ reject("timeout");
75
+ case "error":
76
+ console.log(data.message);
77
+ isLCMRunning.set(false);
78
+ lcmLiveState.update((state) => ({
79
+ ...state,
80
+ status: LCMLiveStatus.DISCONNECTED,
81
+ streamId: null,
82
+ }));
83
+ reject(data.message);
84
+ }
85
+ };
86
+
87
+ } catch (err) {
88
+ console.error(err);
89
+ isLCMRunning.set(false);
90
  lcmLiveState.update((state) => ({
91
  ...state,
92
+ status: LCMLiveStatus.DISCONNECTED,
93
+ streamId: null,
94
  }));
95
+ reject(err);
96
+ }
97
+ });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  },
99
  send(data: Blob | { [key: string]: any }) {
100
  if (websocket && websocket.readyState === WebSocket.OPEN) {
frontend/src/lib/store.ts ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ import { writable, type Writable } from 'svelte/store';
3
+
4
+ export const pipelineValues = writable({});
frontend/src/lib/types.ts CHANGED
@@ -4,6 +4,11 @@ export const enum FieldType {
4
  textarea = "textarea",
5
  checkbox = "checkbox",
6
  }
 
 
 
 
 
7
 
8
  export interface FieldProps {
9
  default: number | string;
@@ -19,5 +24,7 @@ export interface FieldProps {
19
  export interface PipelineInfo {
20
  name: string;
21
  description: string;
22
- mode: string;
 
 
23
  }
 
4
  textarea = "textarea",
5
  checkbox = "checkbox",
6
  }
7
+ export const enum PipelineMode {
8
+ image = "image",
9
+ video = "video",
10
+ text = "text",
11
+ }
12
 
13
  export interface FieldProps {
14
  default: number | string;
 
24
  export interface PipelineInfo {
25
  name: string;
26
  description: string;
27
+ input_mode: {
28
+ default: PipelineMode;
29
+ }
30
  }
frontend/src/routes/+page.svelte CHANGED
@@ -2,6 +2,7 @@
2
  import { onMount } from 'svelte';
3
  import { PUBLIC_BASE_URL } from '$env/static/public';
4
  import type { FieldProps, PipelineInfo } from '$lib/types';
 
5
  import ImagePlayer from '$lib/components/ImagePlayer.svelte';
6
  import VideoInput from '$lib/components/VideoInput.svelte';
7
  import Button from '$lib/components/Button.svelte';
@@ -14,10 +15,12 @@
14
  isMediaStreaming,
15
  onFrameChangeStore
16
  } from '$lib/mediaStream';
 
17
 
18
  let pipelineParams: FieldProps[];
19
  let pipelineInfo: PipelineInfo;
20
- let pipelineValues = {};
 
21
 
22
  onMount(() => {
23
  getSettings();
@@ -27,89 +30,73 @@
27
  const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
28
  pipelineParams = Object.values(settings.input_params.properties);
29
  pipelineInfo = settings.info.properties;
 
 
30
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
31
  console.log('PARAMS', pipelineParams);
32
  console.log('SETTINGS', pipelineInfo);
33
  }
 
34
 
35
- // $: {
36
- // console.log('isLCMRunning', $isLCMRunning);
37
- // }
38
- // $: {
39
- // console.log('lcmLiveState', $lcmLiveState);
40
- // }
41
- // $: {
42
- // console.log('mediaStreamState', $mediaStreamState);
43
- // }
44
- // $: if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
45
- // lcmLiveActions.send(pipelineValues);
46
- // }
47
- onFrameChangeStore.subscribe(async (frame) => {
48
- if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
49
- lcmLiveActions.send(pipelineValues);
50
- lcmLiveActions.send(frame.blob);
51
  }
52
- });
53
- let startBt: Button;
54
- let stopBt: Button;
55
- let snapShotBt: Button;
56
 
 
 
 
 
 
 
57
  async function toggleLcmLive() {
58
  if (!$isLCMRunning) {
59
- await mediaStreamActions.enumerateDevices();
60
- await mediaStreamActions.start();
61
- lcmLiveActions.start();
 
 
62
  } else {
63
- mediaStreamActions.stop();
 
 
64
  lcmLiveActions.stop();
65
  }
66
  }
67
- async function startLcmLive() {
68
- try {
69
- $isLCMRunning = true;
70
- // const res = await lcmLive.start();
71
- $isLCMRunning = false;
72
- // if (res.status === "timeout")
73
- // toggleMessage("success")
74
- } catch (err) {
75
- console.log(err);
76
- // toggleMessage("error")
77
- $isLCMRunning = false;
78
- }
79
- }
80
- async function stopLcmLive() {
81
- // await lcmLive.stop();
82
- $isLCMRunning = false;
83
- }
84
  </script>
85
 
86
  <div class="fixed right-2 top-2 max-w-xs rounded-lg p-4 text-center text-sm font-bold" id="error" />
87
  <main class="container mx-auto flex max-w-4xl flex-col gap-3 px-4 py-4">
88
  <article class="flex- mx-auto max-w-xl text-center">
89
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
90
- <p class="text-sm">
91
  This demo showcases
92
  <a
93
- href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7"
94
  target="_blank"
95
- class="text-blue-500 underline hover:no-underline">LCM</a
96
  >
97
  Image to Image pipeline using
98
  <a
99
- href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
100
  target="_blank"
101
  class="text-blue-500 underline hover:no-underline">Diffusers</a
102
  > with a MJPEG stream server.
103
  </p>
104
- <p class="text-sm">
105
- There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU,
106
- affecting real-time performance. Maximum queue size is 4.
107
- <a
108
- href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
109
- target="_blank"
110
- class="text-blue-500 underline hover:no-underline">Duplicate</a
111
- > and run it on your own GPU.
112
- </p>
 
 
113
  </article>
114
  {#if pipelineParams}
115
  <header>
@@ -122,7 +109,7 @@
122
  > syntax.
123
  </p>
124
  </header>
125
- <PipelineOptions {pipelineParams} bind:pipelineValues></PipelineOptions>
126
  <div class="flex gap-3">
127
  <Button on:click={toggleLcmLive}>
128
  {#if $isLCMRunning}
@@ -135,7 +122,9 @@
135
  </div>
136
 
137
  <ImagePlayer>
138
- <VideoInput></VideoInput>
 
 
139
  </ImagePlayer>
140
  {:else}
141
  <!-- loading -->
 
2
  import { onMount } from 'svelte';
3
  import { PUBLIC_BASE_URL } from '$env/static/public';
4
  import type { FieldProps, PipelineInfo } from '$lib/types';
5
+ import { PipelineMode } from '$lib/types';
6
  import ImagePlayer from '$lib/components/ImagePlayer.svelte';
7
  import VideoInput from '$lib/components/VideoInput.svelte';
8
  import Button from '$lib/components/Button.svelte';
 
15
  isMediaStreaming,
16
  onFrameChangeStore
17
  } from '$lib/mediaStream';
18
+ import { pipelineValues } from '$lib/store';
19
 
20
  let pipelineParams: FieldProps[];
21
  let pipelineInfo: PipelineInfo;
22
+ let isImageMode: boolean = false;
23
+ let maxQueueSize: number = 0;
24
 
25
  onMount(() => {
26
  getSettings();
 
30
  const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
31
  pipelineParams = Object.values(settings.input_params.properties);
32
  pipelineInfo = settings.info.properties;
33
+ isImageMode = pipelineInfo.input_mode.default === PipelineMode.image;
34
+ maxQueueSize = settings.max_queue_size;
35
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
36
  console.log('PARAMS', pipelineParams);
37
  console.log('SETTINGS', pipelineInfo);
38
  }
39
+ console.log('isImageMode', isImageMode);
40
 
41
+ // send Webcam stream to LCM if image mode
42
+ $: {
43
+ if (isImageMode && $lcmLiveState.status === LCMLiveStatus.CONNECTED) {
44
+ lcmLiveActions.send($pipelineValues);
45
+ lcmLiveActions.send($onFrameChangeStore.blob);
 
 
 
 
 
 
 
 
 
 
 
46
  }
47
+ }
 
 
 
48
 
49
+ // send Webcam stream to LCM
50
+ $: {
51
+ if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
52
+ lcmLiveActions.send($pipelineValues);
53
+ }
54
+ }
55
  async function toggleLcmLive() {
56
  if (!$isLCMRunning) {
57
+ if (isImageMode) {
58
+ await mediaStreamActions.enumerateDevices();
59
+ await mediaStreamActions.start();
60
+ }
61
+ await lcmLiveActions.start();
62
  } else {
63
+ if (isImageMode) {
64
+ mediaStreamActions.stop();
65
+ }
66
  lcmLiveActions.stop();
67
  }
68
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  </script>
70
 
71
  <div class="fixed right-2 top-2 max-w-xs rounded-lg p-4 text-center text-sm font-bold" id="error" />
72
  <main class="container mx-auto flex max-w-4xl flex-col gap-3 px-4 py-4">
73
  <article class="flex- mx-auto max-w-xl text-center">
74
  <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
75
+ <p class="py-2 text-sm">
76
  This demo showcases
77
  <a
78
+ href="https://huggingface.co/blog/lcm_lora"
79
  target="_blank"
80
+ class="text-blue-500 underline hover:no-underline">LCM LoRA</a
81
  >
82
  Image to Image pipeline using
83
  <a
84
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/lcm#performing-inference-with-lcm"
85
  target="_blank"
86
  class="text-blue-500 underline hover:no-underline">Diffusers</a
87
  > with a MJPEG stream server.
88
  </p>
89
+ {#if maxQueueSize > 0}
90
+ <p class="text-sm">
91
+ There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU,
92
+ affecting real-time performance. Maximum queue size is {maxQueueSize}.
93
+ <a
94
+ href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
95
+ target="_blank"
96
+ class="text-blue-500 underline hover:no-underline">Duplicate</a
97
+ > and run it on your own GPU.
98
+ </p>
99
+ {/if}
100
  </article>
101
  {#if pipelineParams}
102
  <header>
 
109
  > syntax.
110
  </p>
111
  </header>
112
+ <PipelineOptions {pipelineParams}></PipelineOptions>
113
  <div class="flex gap-3">
114
  <Button on:click={toggleLcmLive}>
115
  {#if $isLCMRunning}
 
122
  </div>
123
 
124
  <ImagePlayer>
125
+ {#if isImageMode}
126
+ <VideoInput></VideoInput>
127
+ {/if}
128
  </ImagePlayer>
129
  {:else}
130
  <!-- loading -->
pipelines/controlnet.py CHANGED
@@ -28,6 +28,7 @@ class Pipeline:
28
  class Info(BaseModel):
29
  name: str = "txt2img"
30
  description: str = "Generates an image from a text prompt"
 
31
 
32
  class InputParams(BaseModel):
33
  prompt: str = Field(
@@ -125,7 +126,6 @@ class Pipeline:
125
  hide=True,
126
  id="debug_canny",
127
  )
128
- image: bool = True
129
 
130
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
131
  controlnet_canny = ControlNetModel.from_pretrained(
 
28
  class Info(BaseModel):
29
  name: str = "txt2img"
30
  description: str = "Generates an image from a text prompt"
31
+ input_mode: str = "image"
32
 
33
  class InputParams(BaseModel):
34
  prompt: str = Field(
 
126
  hide=True,
127
  id="debug_canny",
128
  )
 
129
 
130
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
131
  controlnet_canny = ControlNetModel.from_pretrained(
pipelines/txt2img.py CHANGED
@@ -22,6 +22,7 @@ class Pipeline:
22
  class Info(BaseModel):
23
  name: str = "txt2img"
24
  description: str = "Generates an image from a text prompt"
 
25
 
26
  class InputParams(BaseModel):
27
  prompt: str = Field(
@@ -52,9 +53,6 @@ class Pipeline:
52
  hide=True,
53
  id="guidance_scale",
54
  )
55
- image: bool = Field(
56
- True, title="Image", field="checkbox", hide=True, id="image"
57
- )
58
 
59
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
60
  if args.safety_checker:
 
22
  class Info(BaseModel):
23
  name: str = "txt2img"
24
  description: str = "Generates an image from a text prompt"
25
+ input_mode: str = "text"
26
 
27
  class InputParams(BaseModel):
28
  prompt: str = Field(
 
53
  hide=True,
54
  id="guidance_scale",
55
  )
 
 
 
56
 
57
  def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
58
  if args.safety_checker:
static/controlnet.html DELETED
@@ -1,427 +0,0 @@
1
- <!doctype html>
2
- <html>
3
-
4
- <head>
5
- <meta charset="UTF-8">
6
- <title>Real-Time Latent Consistency Model ControlNet</title>
7
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
- <script
9
- src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/piexif.min.js"></script>
11
- <script src="https://cdn.tailwindcss.com"></script>
12
- <style type="text/tailwindcss">
13
- .button {
14
- @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
15
- }
16
- </style>
17
- <script type="module">
18
- const getValue = (id) => {
19
- const el = document.querySelector(`${id}`)
20
- if (el.type === "checkbox")
21
- return el.checked;
22
- return el.value;
23
- }
24
- const startBtn = document.querySelector("#start");
25
- const stopBtn = document.querySelector("#stop");
26
- const videoEl = document.querySelector("#webcam");
27
- const imageEl = document.querySelector("#player");
28
- const queueSizeEl = document.querySelector("#queue_size");
29
- const errorEl = document.querySelector("#error");
30
- const snapBtn = document.querySelector("#snap");
31
- const webcamsEl = document.querySelector("#webcams");
32
-
33
- function LCMLive(webcamVideo, liveImage) {
34
- let websocket;
35
-
36
- async function start() {
37
- return new Promise((resolve, reject) => {
38
- const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
39
- }:${window.location.host}/ws`;
40
-
41
- const socket = new WebSocket(websocketURL);
42
- socket.onopen = () => {
43
- console.log("Connected to websocket");
44
- };
45
- socket.onclose = () => {
46
- console.log("Disconnected from websocket");
47
- stop();
48
- resolve({ "status": "disconnected" });
49
- };
50
- socket.onerror = (err) => {
51
- console.error(err);
52
- reject(err);
53
- };
54
- socket.onmessage = (event) => {
55
- const data = JSON.parse(event.data);
56
- switch (data.status) {
57
- case "success":
58
- break;
59
- case "start":
60
- const userId = data.userId;
61
- initVideoStream(userId);
62
- break;
63
- case "timeout":
64
- stop();
65
- resolve({ "status": "timeout" });
66
- case "error":
67
- stop();
68
- reject(data.message);
69
-
70
- }
71
- };
72
- websocket = socket;
73
- })
74
- }
75
- function switchCamera() {
76
- const constraints = {
77
- audio: false,
78
- video: { width: 1024, height: 1024, deviceId: mediaDevices[webcamsEl.value].deviceId }
79
- };
80
- navigator.mediaDevices
81
- .getUserMedia(constraints)
82
- .then((mediaStream) => {
83
- webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
84
- webcamVideo.srcObject = mediaStream;
85
- webcamVideo.onloadedmetadata = () => {
86
- webcamVideo.play();
87
- webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
88
- };
89
- })
90
- .catch((err) => {
91
- console.error(`${err.name}: ${err.message}`);
92
- });
93
- }
94
-
95
- async function videoTimeUpdateHandler() {
96
- const dimension = getValue("input[name=dimension]:checked");
97
- const [WIDTH, HEIGHT] = JSON.parse(dimension);
98
-
99
- const canvas = new OffscreenCanvas(WIDTH, HEIGHT);
100
- const videoW = webcamVideo.videoWidth;
101
- const videoH = webcamVideo.videoHeight;
102
- const aspectRatio = WIDTH / HEIGHT;
103
-
104
- const ctx = canvas.getContext("2d");
105
- ctx.drawImage(webcamVideo, videoW / 2 - videoH * aspectRatio / 2, 0, videoH * aspectRatio, videoH, 0, 0, WIDTH, HEIGHT)
106
- const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
107
- websocket.send(blob);
108
- websocket.send(JSON.stringify({
109
- "seed": getValue("#seed"),
110
- "prompt": getValue("#prompt"),
111
- "guidance_scale": getValue("#guidance-scale"),
112
- "strength": getValue("#strength"),
113
- "steps": getValue("#steps"),
114
- "lcm_steps": getValue("#lcm_steps"),
115
- "width": WIDTH,
116
- "height": HEIGHT,
117
- "controlnet_scale": getValue("#controlnet_scale"),
118
- "controlnet_start": getValue("#controlnet_start"),
119
- "controlnet_end": getValue("#controlnet_end"),
120
- "canny_low_threshold": getValue("#canny_low_threshold"),
121
- "canny_high_threshold": getValue("#canny_high_threshold"),
122
- "debug_canny": getValue("#debug_canny")
123
- }));
124
- }
125
- let mediaDevices = [];
126
- async function initVideoStream(userId) {
127
- liveImage.src = `/stream/${userId}`;
128
- await navigator.mediaDevices.enumerateDevices()
129
- .then(devices => {
130
- const cameras = devices.filter(device => device.kind === 'videoinput');
131
- mediaDevices = cameras;
132
- webcamsEl.innerHTML = "";
133
- cameras.forEach((camera, index) => {
134
- const option = document.createElement("option");
135
- option.value = index;
136
- option.innerText = camera.label;
137
- webcamsEl.appendChild(option);
138
- option.selected = index === 0;
139
- });
140
- webcamsEl.addEventListener("change", switchCamera);
141
- })
142
- .catch(err => {
143
- console.error(err);
144
- });
145
- const constraints = {
146
- audio: false,
147
- video: { width: 1024, height: 1024, deviceId: mediaDevices[0].deviceId }
148
- };
149
- navigator.mediaDevices
150
- .getUserMedia(constraints)
151
- .then((mediaStream) => {
152
- webcamVideo.srcObject = mediaStream;
153
- webcamVideo.onloadedmetadata = () => {
154
- webcamVideo.play();
155
- webcamVideo.addEventListener("timeupdate", videoTimeUpdateHandler);
156
- };
157
- })
158
- .catch((err) => {
159
- console.error(`${err.name}: ${err.message}`);
160
- });
161
- }
162
-
163
-
164
- async function stop() {
165
- websocket.close();
166
- navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
167
- mediaStream.getTracks().forEach((track) => track.stop());
168
- });
169
- webcamVideo.removeEventListener("timeupdate", videoTimeUpdateHandler);
170
- webcamsEl.removeEventListener("change", switchCamera);
171
- webcamVideo.srcObject = null;
172
- }
173
- return {
174
- start,
175
- stop
176
- }
177
- }
178
- function toggleMessage(type) {
179
- errorEl.hidden = false;
180
- errorEl.scrollIntoView();
181
- switch (type) {
182
- case "error":
183
- errorEl.innerText = "To many users are using the same GPU, please try again later.";
184
- errorEl.classList.toggle("bg-red-300", "text-red-900");
185
- break;
186
- case "success":
187
- errorEl.innerText = "Your session has ended, please start a new one.";
188
- errorEl.classList.toggle("bg-green-300", "text-green-900");
189
- break;
190
- }
191
- setTimeout(() => {
192
- errorEl.hidden = true;
193
- }, 2000);
194
- }
195
- function snapImage() {
196
- try {
197
- const zeroth = {};
198
- const exif = {};
199
- const gps = {};
200
- zeroth[piexif.ImageIFD.Make] = "LCM Image-to-Image ControNet";
201
- zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | strength: ${getValue("#strength")} | controlnet_start: ${getValue("#controlnet_start")} | controlnet_end: ${getValue("#controlnet_end")} | lcm_steps: ${getValue("#lcm_steps")} | steps: ${getValue("#steps")}`;
202
- zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
203
- exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
204
-
205
- const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
206
- const exifBytes = piexif.dump(exifObj);
207
-
208
- const canvas = document.createElement("canvas");
209
- canvas.width = imageEl.naturalWidth;
210
- canvas.height = imageEl.naturalHeight;
211
- const ctx = canvas.getContext("2d");
212
- ctx.drawImage(imageEl, 0, 0);
213
- const dataURL = canvas.toDataURL("image/jpeg");
214
- const withExif = piexif.insert(exifBytes, dataURL);
215
-
216
- const a = document.createElement("a");
217
- a.href = withExif;
218
- a.download = `lcm_txt_2_img${Date.now()}.png`;
219
- a.click();
220
- } catch (err) {
221
- console.log(err);
222
- }
223
- }
224
-
225
-
226
- const lcmLive = LCMLive(videoEl, imageEl);
227
- startBtn.addEventListener("click", async () => {
228
- try {
229
- startBtn.disabled = true;
230
- snapBtn.disabled = false;
231
- const res = await lcmLive.start();
232
- startBtn.disabled = false;
233
- if (res.status === "timeout")
234
- toggleMessage("success")
235
- } catch (err) {
236
- console.log(err);
237
- toggleMessage("error")
238
- startBtn.disabled = false;
239
- }
240
- });
241
- stopBtn.addEventListener("click", () => {
242
- lcmLive.stop();
243
- });
244
- window.addEventListener("beforeunload", () => {
245
- lcmLive.stop();
246
- });
247
- snapBtn.addEventListener("click", snapImage);
248
- setInterval(() =>
249
- fetch("/queue_size")
250
- .then((res) => res.json())
251
- .then((data) => {
252
- queueSizeEl.innerText = data.queue_size;
253
- })
254
- .catch((err) => {
255
- console.log(err);
256
- })
257
- , 5000);
258
- </script>
259
- </head>
260
-
261
- <body class="text-black dark:bg-gray-900 dark:text-white">
262
- <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
263
- </div>
264
- <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
265
- <article class="text-center max-w-xl mx-auto">
266
- <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
267
- <h2 class="text-2xl font-bold mb-4">ControlNet</h2>
268
- <p class="text-sm">
269
- This demo showcases
270
- <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
271
- class="text-blue-500 underline hover:no-underline">LCM</a> Image to Image pipeline
272
- using
273
- <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
274
- target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with a MJPEG
275
- stream server.
276
- </p>
277
- <p class="text-sm">
278
- There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
279
- real-time performance. Maximum queue size is 4. <a
280
- href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
281
- target="_blank" class="text-blue-500 underline hover:no-underline">Duplicate</a> and run it on your
282
- own GPU.
283
- </p>
284
- </article>
285
- <div>
286
- <h2 class="font-medium">Prompt</h2>
287
- <p class="text-sm text-gray-500">
288
- Change the prompt to generate different images, accepts <a
289
- href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
290
- class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
291
- </p>
292
- <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
293
- <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
294
- title="Prompt, this is an example, feel free to modify"
295
- placeholder="Add your prompt here...">Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
296
- </div>
297
- </div>
298
- <div class="">
299
- <details>
300
- <summary class="font-medium cursor-pointer">Advanced Options</summary>
301
- <div class="grid grid-cols-3 sm:grid-cols-6 items-center gap-3 py-3">
302
- <label for="webcams" class="text-sm font-medium">Camera Options: </label>
303
- <select id="webcams" class="text-sm border-2 border-gray-500 rounded-md font-light dark:text-black">
304
- </select>
305
- <div></div>
306
- <label class="text-sm font-medium " for="steps">Inference Steps
307
- </label>
308
- <input type="range" id="steps" name="steps" min="1" max="20" value="4"
309
- oninput="this.nextElementSibling.value = Number(this.value)">
310
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
311
- 4</output>
312
- <!-- -->
313
- <label class="text-sm font-medium" for="lcm_steps">LCM Inference Steps
314
- </label>
315
- <input type="range" id="lcm_steps" name="lcm_steps" min="2" max="60" value="50"
316
- oninput="this.nextElementSibling.value = Number(this.value)">
317
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
318
- 50</output>
319
- <!-- -->
320
- <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
321
- </label>
322
- <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="30" step="0.001"
323
- value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
324
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
325
- 8.0</output>
326
- <!-- -->
327
- <label class="text-sm font-medium" for="strength">Strength</label>
328
- <input type="range" id="strength" name="strength" min="0.1" max="1" step="0.001" value="0.50"
329
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
330
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
331
- 0.5</output>
332
- <!-- -->
333
- <label class="text-sm font-medium" for="controlnet_scale">ControlNet Condition Scale</label>
334
- <input type="range" id="controlnet_scale" name="controlnet_scale" min="0.0" max="1" step="0.001"
335
- value="0.80" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
336
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
337
- 0.8</output>
338
- <!-- -->
339
- <label class="text-sm font-medium" for="controlnet_start">ControlNet Guidance Start</label>
340
- <input type="range" id="controlnet_start" name="controlnet_start" min="0.0" max="1.0" step="0.001"
341
- value="0.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
342
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
343
- 0.0</output>
344
- <!-- -->
345
- <label class="text-sm font-medium" for="controlnet_end">ControlNet Guidance End</label>
346
- <input type="range" id="controlnet_end" name="controlnet_end" min="0.0" max="1.0" step="0.001"
347
- value="1.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
348
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
349
- 1.0</output>
350
- <!-- -->
351
- <label class="text-sm font-medium" for="canny_low_threshold">Canny Low Threshold</label>
352
- <input type="range" id="canny_low_threshold" name="canny_low_threshold" min="0.0" max="1.0"
353
- step="0.001" value="0.1"
354
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
355
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
356
- 0.1</output>
357
- <!-- -->
358
- <label class="text-sm font-medium" for="canny_high_threshold">Canny High Threshold</label>
359
- <input type="range" id="canny_high_threshold" name="canny_high_threshold" min="0.0" max="1.0"
360
- step="0.001" value="0.2"
361
- oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
362
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
363
- 0.2</output>
364
- <!-- -->
365
- <label class="text-sm font-medium" for="seed">Seed</label>
366
- <input type="number" id="seed" name="seed" value="299792458"
367
- class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
368
- <button
369
- onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
370
- class="button">
371
- Rand
372
- </button>
373
- <!-- -->
374
- <!-- -->
375
- <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
376
- <div class="col-span-2 flex gap-2">
377
- <div class="flex gap-1">
378
- <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
379
- class="cursor-pointer">
380
- <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
381
- </div>
382
- <div class="flex gap-1">
383
- <input type="radio" id="dimension768" name="dimension" value="[768,768]"
384
- lass="cursor-pointer">
385
- <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
386
- </div>
387
- </div>
388
- <!-- -->
389
- <!-- -->
390
- <label class="text-sm font-medium" for="debug_canny">Debug Canny</label>
391
- <div class="col-span-2 flex gap-2">
392
- <input type="checkbox" id="debug_canny" name="debug_canny" class="cursor-pointer">
393
- <label for="debug_canny" class="text-sm cursor-pointer"></label>
394
- </div>
395
- <div></div>
396
- <!-- -->
397
- </div>
398
- </details>
399
- </div>
400
- <div class="flex gap-3">
401
- <button id="start" class="button">
402
- Start
403
- </button>
404
- <button id="stop" class="button">
405
- Stop
406
- </button>
407
- <button id="snap" disabled class="button ml-auto">
408
- Snapshot
409
- </button>
410
- </div>
411
- <div class="relative rounded-lg border border-slate-300 overflow-hidden">
412
- <img id="player" class="w-full aspect-square rounded-lg"
413
- src="">
414
- <div class="absolute top-0 left-0 w-1/4 aspect-square">
415
- <video id="webcam" class="w-full aspect-square relative z-10 object-cover" playsinline autoplay muted
416
- loop></video>
417
- <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 448" width="100"
418
- class="w-full p-4 absolute top-0 opacity-20 z-0">
419
- <path fill="currentColor"
420
- d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z" />
421
- </svg>
422
- </div>
423
- </div>
424
- </main>
425
- </body>
426
-
427
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
static/txt2img.html DELETED
@@ -1,304 +0,0 @@
1
- <!doctype html>
2
- <html>
3
-
4
- <head>
5
- <meta charset="UTF-8">
6
- <title>Real-Time Latent Consistency Model</title>
7
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
- <script
9
- src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
10
- <script src="https://cdn.jsdelivr.net/npm/[email protected]/piexif.min.js"></script>
11
- <script src="https://cdn.tailwindcss.com"></script>
12
- <style type="text/tailwindcss">
13
- .button {
14
- @apply bg-gray-700 hover:bg-gray-800 text-white font-normal p-2 rounded disabled:bg-gray-300 dark:disabled:bg-gray-700 disabled:cursor-not-allowed dark:disabled:text-black
15
- }
16
- </style>
17
- <script type="module">
18
- const getValue = (id) => {
19
- const el = document.querySelector(`${id}`)
20
- if (el.type === "checkbox")
21
- return el.checked;
22
- return el.value;
23
- }
24
- const startBtn = document.querySelector("#start");
25
- const stopBtn = document.querySelector("#stop");
26
- const videoEl = document.querySelector("#webcam");
27
- const imageEl = document.querySelector("#player");
28
- const queueSizeEl = document.querySelector("#queue_size");
29
- const errorEl = document.querySelector("#error");
30
- const snapBtn = document.querySelector("#snap");
31
- const paramsEl = document.querySelector("#params");
32
- const promptEl = document.querySelector("#prompt");
33
- paramsEl.addEventListener("submit", (e) => e.preventDefault());
34
- function LCMLive(promptEl, paramsEl, liveImage) {
35
- let websocket;
36
-
37
- async function start() {
38
- return new Promise((resolve, reject) => {
39
- const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
40
- }:${window.location.host}/ws`;
41
-
42
- const socket = new WebSocket(websocketURL);
43
- socket.onopen = () => {
44
- console.log("Connected to websocket");
45
- };
46
- socket.onclose = () => {
47
- console.log("Disconnected from websocket");
48
- stop();
49
- resolve({ "status": "disconnected" });
50
- };
51
- socket.onerror = (err) => {
52
- console.error(err);
53
- reject(err);
54
- };
55
- socket.onmessage = (event) => {
56
- const data = JSON.parse(event.data);
57
- switch (data.status) {
58
- case "success":
59
- break;
60
- case "start":
61
- const userId = data.userId;
62
- initPromptStream(userId);
63
- break;
64
- case "timeout":
65
- stop();
66
- resolve({ "status": "timeout" });
67
- case "error":
68
- stop();
69
- reject(data.message);
70
- }
71
- };
72
- websocket = socket;
73
- })
74
- }
75
-
76
- async function promptUpdateStream(e) {
77
- const dimension = getValue("input[name=dimension]:checked");
78
- const [WIDTH, HEIGHT] = JSON.parse(dimension);
79
- websocket.send(JSON.stringify({
80
- "seed": getValue("#seed"),
81
- "prompt": getValue("#prompt"),
82
- "guidance_scale": getValue("#guidance-scale"),
83
- "steps": getValue("#steps"),
84
- "lcm_steps": getValue("#lcm_steps"),
85
- "width": WIDTH,
86
- "height": HEIGHT,
87
- }));
88
- }
89
- function debouceInput(fn, delay) {
90
- let timer;
91
- return function (...args) {
92
- clearTimeout(timer);
93
- timer = setTimeout(() => {
94
- fn(...args);
95
- }, delay);
96
- }
97
- }
98
- const debouncedInput = debouceInput(promptUpdateStream, 200);
99
- function initPromptStream(userId) {
100
- liveImage.src = `/stream/${userId}`;
101
- paramsEl.addEventListener("change", debouncedInput);
102
- promptEl.addEventListener("input", debouncedInput);
103
- }
104
-
105
- async function stop() {
106
- websocket.close();
107
- paramsEl.removeEventListener("change", debouncedInput);
108
- promptEl.removeEventListener("input", debouncedInput);
109
- }
110
- return {
111
- start,
112
- stop
113
- }
114
- }
115
- function toggleMessage(type) {
116
- errorEl.hidden = false;
117
- errorEl.scrollIntoView();
118
- switch (type) {
119
- case "error":
120
- errorEl.innerText = "To many users are using the same GPU, please try again later.";
121
- errorEl.classList.toggle("bg-red-300", "text-red-900");
122
- break;
123
- case "success":
124
- errorEl.innerText = "Your session has ended, please start a new one.";
125
- errorEl.classList.toggle("bg-green-300", "text-green-900");
126
- break;
127
- }
128
- setTimeout(() => {
129
- errorEl.hidden = true;
130
- }, 2000);
131
- }
132
- function snapImage() {
133
- try {
134
- const zeroth = {};
135
- const exif = {};
136
- const gps = {};
137
- zeroth[piexif.ImageIFD.Make] = "LCM Text-to-Image";
138
- zeroth[piexif.ImageIFD.ImageDescription] = `prompt: ${getValue("#prompt")} | seed: ${getValue("#seed")} | guidance_scale: ${getValue("#guidance-scale")} | lcm_steps: ${getValue("#lcm_steps")} | steps: ${getValue("#steps")}`;
139
- zeroth[piexif.ImageIFD.Software] = "https://github.com/radames/Real-Time-Latent-Consistency-Model";
140
-
141
- exif[piexif.ExifIFD.DateTimeOriginal] = new Date().toISOString();
142
-
143
- const exifObj = { "0th": zeroth, "Exif": exif, "GPS": gps };
144
- const exifBytes = piexif.dump(exifObj);
145
-
146
- const canvas = document.createElement("canvas");
147
- canvas.width = imageEl.naturalWidth;
148
- canvas.height = imageEl.naturalHeight;
149
- const ctx = canvas.getContext("2d");
150
- ctx.drawImage(imageEl, 0, 0);
151
- const dataURL = canvas.toDataURL("image/jpeg");
152
- const withExif = piexif.insert(exifBytes, dataURL);
153
-
154
- const a = document.createElement("a");
155
- a.href = withExif;
156
- a.download = `lcm_txt_2_img${Date.now()}.png`;
157
- a.click();
158
- } catch (err) {
159
- console.log(err);
160
- }
161
- }
162
-
163
-
164
- const lcmLive = LCMLive(promptEl, paramsEl, imageEl);
165
- startBtn.addEventListener("click", async () => {
166
- try {
167
- startBtn.disabled = true;
168
- snapBtn.disabled = false;
169
- const res = await lcmLive.start();
170
- startBtn.disabled = false;
171
- if (res.status === "timeout")
172
- toggleMessage("success")
173
- } catch (err) {
174
- console.log(err);
175
- toggleMessage("error")
176
- startBtn.disabled = false;
177
- }
178
- });
179
- stopBtn.addEventListener("click", () => {
180
- lcmLive.stop();
181
- });
182
- window.addEventListener("beforeunload", () => {
183
- lcmLive.stop();
184
- });
185
- snapBtn.addEventListener("click", snapImage);
186
- setInterval(() =>
187
- fetch("/queue_size")
188
- .then((res) => res.json())
189
- .then((data) => {
190
- queueSizeEl.innerText = data.queue_size;
191
- })
192
- .catch((err) => {
193
- console.log(err);
194
- })
195
- , 5000);
196
- </script>
197
- </head>
198
-
199
- <body class="text-black dark:bg-gray-900 dark:text-white">
200
- <div class="fixed right-2 top-2 p-4 font-bold text-sm rounded-lg max-w-xs text-center" id="error">
201
- </div>  
202
- <main class="container mx-auto px-4 py-4 max-w-4xl flex flex-col gap-4">
203
- <article class="text-center max-w-xl mx-auto">
204
- <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
205
- <h2 class="text-2xl font-bold mb-4">Text to Image</h2>
206
- <p class="text-sm">
207
- This demo showcases
208
- <a href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7" target="_blank"
209
- class="text-blue-500 underline hover:no-underline">LCM</a> Text to Image model
210
- using
211
- <a href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
212
- target="_blank" class="text-blue-500 underline hover:no-underline">Diffusers</a> with a MJPEG
213
- stream server.
214
- </p>
215
- <p class="text-sm">
216
- There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
217
- real-time performance. Maximum queue size is 10. <a
218
- href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
219
- target="_blank" class="text-blue-500 underline hover:no-underline">Duplicate</a> and run it on your
220
- own GPU.
221
- </p>
222
- </article>
223
- <div>
224
- <h2 class="font-medium">Prompt</h2>
225
- <p class="text-sm text-gray-500 dark:text-gray-400">
226
- Start your session and type your prompt here, accepts
227
- <a href="https://github.com/damian0815/compel/blob/main/doc/syntax.md" target="_blank"
228
- class="text-blue-500 underline hover:no-underline">Compel</a> syntax.
229
- </p>
230
- <div class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
231
- <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 outline-none dark:text-black"
232
- title=" Start your session and type your prompt here, you can see the result in real-time."
233
- placeholder="Add your prompt here...">Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5, cinematic, masterpiece</textarea>
234
- </div>
235
-
236
- </div>
237
- <div class="">
238
- <details>
239
- <summary class="font-medium cursor-pointer">Advanced Options</summary>
240
- <form class="grid grid-cols-3 items-center gap-3 py-3" id="params" action="">
241
- <label class="text-sm font-medium" for="dimension">Image Dimensions</label>
242
- <div class="col-span-2 flex gap-2">
243
- <div class="flex gap-1">
244
- <input type="radio" id="dimension512" name="dimension" value="[512,512]" checked
245
- class="cursor-pointer">
246
- <label for="dimension512" class="text-sm cursor-pointer">512x512</label>
247
- </div>
248
- <div class="flex gap-1">
249
- <input type="radio" id="dimension768" name="dimension" value="[768,768]"
250
- lass="cursor-pointer">
251
- <label for="dimension768" class="text-sm cursor-pointer">768x768</label>
252
- </div>
253
- </div>
254
- <!-- -->
255
- <label class="text-sm font-medium " for="steps">Inference Steps
256
- </label>
257
- <input type="range" id="steps" name="steps" min="1" max="20" value="4"
258
- oninput="this.nextElementSibling.value = Number(this.value)">
259
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
260
- 4</output>
261
- <!-- -->
262
- <label class="text-sm font-medium" for="lcm_steps">LCM Inference Steps
263
- </label>
264
- <input type="range" id="lcm_steps" name="lcm_steps" min="2" max="60" value="50"
265
- oninput="this.nextElementSibling.value = Number(this.value)">
266
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
267
- 50</output>
268
- <!-- -->
269
- <label class="text-sm font-medium" for="guidance-scale">Guidance Scale
270
- </label>
271
- <input type="range" id="guidance-scale" name="guidance-scale" min="0" max="30" step="0.001"
272
- value="8.0" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)">
273
- <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
274
- 8.0</output>
275
- <!-- -->
276
- <label class="text-sm font-medium" for="seed">Seed</label>
277
- <input type="number" id="seed" name="seed" value="299792458"
278
- class="font-light border border-gray-700 text-right rounded-md p-2 dark:text-black">
279
- <button class="button" onclick="document.querySelector('#seed').value = Math.floor(Math.random() * 1000000000); document.querySelector('#params').dispatchEvent(new Event('change'))">
280
- Rand
281
- </button>
282
- <!-- -->
283
- </form>
284
- </details>
285
- </div>
286
- <div class="flex gap-3">
287
- <button id="start" class="button">
288
- Start
289
- </button>
290
- <button id="stop" class="button">
291
- Stop
292
- </button>
293
- <button id="snap" disabled class="button ml-auto">
294
- Snapshot
295
- </button>
296
- </div>
297
- <div class="relative rounded-lg border border-slate-300 overflow-hidden">
298
- <img id="player" class="w-full aspect-square rounded-lg"
299
- src="">
300
- </div>
301
- </main>
302
- </body>
303
-
304
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util.py CHANGED
@@ -2,6 +2,8 @@ from importlib import import_module
2
  from types import ModuleType
3
  from typing import Dict, Any
4
  from pydantic import BaseModel as PydanticBaseModel, Field
 
 
5
 
6
 
7
  def get_pipeline_class(pipeline_name: str) -> ModuleType:
@@ -16,3 +18,20 @@ def get_pipeline_class(pipeline_name: str) -> ModuleType:
16
  raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.")
17
 
18
  return pipeline_class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from types import ModuleType
3
  from typing import Dict, Any
4
  from pydantic import BaseModel as PydanticBaseModel, Field
5
+ from PIL import Image
6
+ import io
7
 
8
 
9
  def get_pipeline_class(pipeline_name: str) -> ModuleType:
 
18
  raise ValueError(f"'Pipeline' class not found in module '{pipeline_name}'.")
19
 
20
  return pipeline_class
21
+
22
+
23
+ def pil_to_frame(image: Image.Image) -> bytes:
24
+ frame_data = io.BytesIO()
25
+ image.save(frame_data, format="JPEG")
26
+ frame_data = frame_data.getvalue()
27
+ return (
28
+ b"--frame\r\n"
29
+ + b"Content-Type: image/jpeg\r\n"
30
+ + f"Content-Length: {len(frame_data)}\r\n\r\n".encode()
31
+ + frame_data
32
+ + b"\r\n"
33
+ )
34
+
35
+
36
+ def is_firefox(user_agent: str) -> bool:
37
+ return "Firefox" in user_agent