radames committed
Commit 249f661
1 Parent(s): 9796138

pass extra width height

Files changed (4):
  1. app-img2img.py +20 -18
  2. app-txt2img.py +19 -17
  3. img2img/index.html +3 -2
  4. requirements.txt +1 -0
app-img2img.py CHANGED
@@ -21,10 +21,11 @@ import os
 import time
 import psutil
 
-
 MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
 TIMEOUT = float(os.environ.get("TIMEOUT", 0))
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
+WIDTH = 512
+HEIGHT = 512
 
 # check if MPS is available OSX only M1/M2/M3 chips
 mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -56,7 +57,7 @@ else:
         custom_revision="main",
     )
 pipe.vae = AutoencoderTiny.from_pretrained(
-    "madebyollin/taesd", torch_dtype=torch.float16, use_safetensors=True
+    "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
 )
 pipe.set_progress_bar_config(disable=True)
 pipe.to(torch_device=torch_device, torch_dtype=torch_dtype).to(device)
@@ -77,18 +78,29 @@ compel_proc = Compel(
 user_queue_map = {}
 
 
-def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
-    generator = torch.manual_seed(seed)
-    prompt_embeds = compel_proc(prompt)
+class InputParams(BaseModel):
+    prompt: str
+    seed: int = 2159232
+    guidance_scale: float = 8.0
+    strength: float = 0.5
+    width: int = WIDTH
+    height: int = HEIGHT
+
+
+def predict(input_image: Image.Image, params: InputParams):
+    generator = torch.manual_seed(params.seed)
+    prompt_embeds = compel_proc(params.prompt)
     # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
     num_inference_steps = 3
     results = pipe(
         prompt_embeds=prompt_embeds,
         generator=generator,
         image=input_image,
-        strength=strength,
+        strength=params.strength,
         num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
+        guidance_scale=params.guidance_scale,
+        width=params.width,
+        height=params.height,
         lcm_origin_steps=50,
         output_type="pil",
     )
@@ -112,13 +124,6 @@ app.add_middleware(
 )
 
 
-class InputParams(BaseModel):
-    seed: int
-    prompt: str
-    strength: float
-    guidance_scale: float
-
-
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
@@ -177,10 +182,7 @@ async def stream(user_id: uuid.UUID):
 
             image = predict(
                 input_image,
-                params.prompt,
-                params.guidance_scale,
-                params.strength,
-                params.seed,
+                params,
             )
             if image is None:
                 continue
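The refactor above replaces the loose positional arguments of `predict` with a single pydantic `InputParams` model, so new fields such as `width` and `height` only need to be declared once and arrive with defaults. The following Python sketch shows how a decoded websocket payload could populate it, assuming the handler parses each message as a JSON object with the same field names (that parsing code is outside this diff; `payload` is an illustrative value, not part of the app):

# Illustration only: InputParams mirrors the model added to app-img2img.py above;
# the real pipe/compel_proc globals are not used here.
from PIL import Image
from pydantic import BaseModel

WIDTH = 512
HEIGHT = 512

class InputParams(BaseModel):
    prompt: str
    seed: int = 2159232
    guidance_scale: float = 8.0
    strength: float = 0.5
    width: int = WIDTH
    height: int = HEIGHT

# Hypothetical client payload; width/height override the 512x512 defaults,
# strength and guidance_scale fall back to them.
payload = {"prompt": "a portrait, oil painting", "seed": 42, "width": 768, "height": 768}
params = InputParams(**payload)
input_image = Image.new("RGB", (params.width, params.height))

print(params)  # in the app this pair would be handed to predict(input_image, params)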
app-txt2img.py CHANGED
@@ -25,7 +25,8 @@ import psutil
 MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
 TIMEOUT = float(os.environ.get("TIMEOUT", 0))
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
-
+WIDTH = 512
+HEIGHT = 512
 # check if MPS is available OSX only M1/M2/M3 chips
 mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -66,9 +67,9 @@ pipe.unet.to(memory_format=torch.channels_last)
 if psutil.virtual_memory().total < 64 * 1024**3:
     pipe.enable_attention_slicing()
 
-if not mps_available:
-    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
+# if not mps_available:
+#     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+#     pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
 
 compel_proc = Compel(
     tokenizer=pipe.tokenizer,
@@ -77,17 +78,25 @@ compel_proc = Compel(
 )
 user_queue_map = {}
 
-
-def predict(prompt, guidance_scale=8.0, seed=2159232):
-    generator = torch.manual_seed(seed)
-    prompt_embeds = compel_proc(prompt)
+class InputParams(BaseModel):
+    prompt: str
+    seed: int = 2159232
+    guidance_scale: float = 8.0
+    width: int = WIDTH
+    height: int = HEIGHT
+
+def predict(params: InputParams):
+    generator = torch.manual_seed(params.seed)
+    prompt_embeds = compel_proc(params.prompt)
     # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
     num_inference_steps = 8
     results = pipe(
         prompt_embeds=prompt_embeds,
         generator=generator,
         num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
+        guidance_scale=params.guidance_scale,
+        width=params.width,
+        height=params.height,
         lcm_origin_steps=50,
         output_type="pil",
     )
@@ -110,13 +119,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-class InputParams(BaseModel):
-    prompt: str
-    seed: int
-    guidance_scale: float
-
-
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
@@ -173,7 +175,7 @@ async def stream(user_id: uuid.UUID):
             if params is None:
                 continue
 
-            image = predict(params.prompt, params.guidance_scale, params.seed)
+            image = predict(params)
             if image is None:
                 continue
             frame_data = io.BytesIO()
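The same pattern in app-txt2img.py keeps the websocket contract backward compatible: clients that never send `width` or `height` still validate, because every new field on `InputParams` carries a default. A quick standalone sketch of that behaviour, with the field values copied from the diff above (pydantic only, no pipeline involved):

from pydantic import BaseModel

WIDTH = 512
HEIGHT = 512

class InputParams(BaseModel):
    prompt: str
    seed: int = 2159232
    guidance_scale: float = 8.0
    width: int = WIDTH
    height: int = HEIGHT

# Old-style payload without width/height still parses; the new fields default to 512.
old_style = InputParams(prompt="a cat", seed=7, guidance_scale=7.5)
assert (old_style.width, old_style.height) == (512, 512)

# New-style payload can override the size while leaving other fields at their defaults.
new_style = InputParams(prompt="a cat", width=768, height=768)
assert new_style.guidance_scale == 8.0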
img2img/index.html CHANGED
@@ -10,8 +10,9 @@
     <script src="https://cdn.jsdelivr.net/npm/[email protected]/piexif.min.js"></script>
     <script src="https://cdn.tailwindcss.com"></script>
     <script type="module">
-      const WIDTH = 768;
-      const HEIGHT = 768;
+      // you can change the size of the input image to 768x768 if you have a powerful GPU
+      const WIDTH = 512;
+      const HEIGHT = 512;
       const seedEl = document.querySelector("#seed");
       const promptEl = document.querySelector("#prompt");
       const guidanceEl = document.querySelector("#guidance-scale");
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 diffusers==0.21.4
 transformers==4.34.1
 gradio==3.50.2
+--extra-index-url https://download.pytorch.org/whl/cu121
 torch==2.1.0
 fastapi==0.104.0
 uvicorn==0.23.2
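The `--extra-index-url` line added to requirements.txt tells pip to also consult PyTorch's CUDA 12.1 wheel index when resolving `torch==2.1.0`. A quick way to confirm which build actually got installed after `pip install -r requirements.txt` (sketch; the printed values depend on your machine):

# Illustration only: checks the installed torch build and CUDA availability.
import torch

print(torch.__version__)          # e.g. "2.1.0+cu121" when a CUDA wheel was picked up
print(torch.version.cuda)         # CUDA version the wheel was built against, or None for CPU builds
print(torch.cuda.is_available())  # True only with a compatible GPU and driver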