owiedotch commited on
Commit
0582ce0
1 Parent(s): c686a8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -129,30 +129,39 @@ def process(
129
  shape = (1, 4, height // 8, width // 8)
130
  x_T = torch.randn(shape, device=device, dtype=torch.float32)
131
 
 
 
 
 
 
 
 
 
 
 
132
  if not tile_diffusion and not tile_vae:
133
  samples = sampler.sample_ccsr(
134
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
135
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
136
  cfg_scale=cfg_scale,
137
  color_fix_type="adain" if use_color_fix else "none"
138
  )
139
  else:
140
  if tile_vae:
141
- # Remove this line as ControlLDM doesn't have _init_tiled_vae method
142
- # model._init_tiled_vae(encoder_tile_size=vae_encoder_tile_size, decoder_tile_size=vae_decoder_tile_size)
143
  pass
144
  if tile_diffusion:
145
  samples = sampler.sample_with_tile_ccsr(
146
  tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
147
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
148
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
149
  cfg_scale=cfg_scale,
150
  color_fix_type="adain" if use_color_fix else "none"
151
  )
152
  else:
153
  samples = sampler.sample_ccsr(
154
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
155
- positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
156
  cfg_scale=cfg_scale,
157
  color_fix_type="adain" if use_color_fix else "none"
158
  )
 
129
  shape = (1, 4, height // 8, width // 8)
130
  x_T = torch.randn(shape, device=device, dtype=torch.float32)
131
 
132
+ # Modify the get_learned_conditioning method to handle the attention mask issue
133
+ def modified_get_learned_conditioning(model, prompt):
134
+ tokens = model.cond_stage_model.tokenizer.encode(prompt)
135
+ tokens = torch.LongTensor(tokens).to(model.device).unsqueeze(0)
136
+ encoder_hidden_states = model.cond_stage_model.transformer(input_ids=tokens).last_hidden_state
137
+ return encoder_hidden_states
138
+
139
+ cond = modified_get_learned_conditioning(model, positive_prompt)
140
+ uncond = modified_get_learned_conditioning(model, negative_prompt)
141
+
142
  if not tile_diffusion and not tile_vae:
143
  samples = sampler.sample_ccsr(
144
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
145
+ positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
146
  cfg_scale=cfg_scale,
147
  color_fix_type="adain" if use_color_fix else "none"
148
  )
149
  else:
150
  if tile_vae:
151
+ # Note: Tiled VAE is not implemented in this version
 
152
  pass
153
  if tile_diffusion:
154
  samples = sampler.sample_with_tile_ccsr(
155
  tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
156
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
157
+ positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
158
  cfg_scale=cfg_scale,
159
  color_fix_type="adain" if use_color_fix else "none"
160
  )
161
  else:
162
  samples = sampler.sample_ccsr(
163
  steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
164
+ positive_prompt=cond, negative_prompt=uncond, x_T=x_T,
165
  cfg_scale=cfg_scale,
166
  color_fix_type="adain" if use_color_fix else "none"
167
  )