Commit fd16ff8
Parent(s): ab8308a
update space
app.py CHANGED
@@ -257,7 +257,8 @@ def main():
     elif 'cyber' in instruction:
         e_type = 'cyber'
 
-    model = models[e_type]
+    model = models[e_type]
+    model = model.to(device)
     # model = load_model('text300M', device=device)
     # with torch.no_grad():
     #     new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=model.wrapped.input_proj.weight.dtype)
@@ -280,7 +281,7 @@ def main():
     latent = latent.to(device)
     text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
     print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
-    ref_latent = latent.clone().unsqueeze(0)
+    ref_latent = latent.clone().unsqueeze(0).to(device)
    t_1 = torch.randint(noise_start_t_e_type, noise_start_t_e_type + 1, (1,), device=device).long()
 
     noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
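
Both hunks serve the same purpose: make sure the selected model and the latent sit on the same device before diffusion.q_sample runs. A minimal, self-contained sketch of that device-placement pattern follows; DummyModel-style names and tensor shapes are illustrative stand-ins, not the app's actual objects.

# Minimal device-placement sketch; the Linear layer and the 1024-dim tensor are
# illustrative stand-ins for the app's models[e_type] and its cached latent.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(1024, 1024)   # stand-in for models[e_type]
model = model.to(device)        # mirrors `model = model.to(device)` in the first hunk

latent = torch.randn(1024)      # stand-in for a latent loaded on CPU
ref_latent = latent.clone().unsqueeze(0).to(device)  # mirrors the second hunk

# PyTorch requires the input tensor and the model's parameters to be on the
# same device; mixing CPU and CUDA tensors raises a RuntimeError at call time.
out = model(ref_latent)
assert out.device == next(model.parameters()).device

The same constraint applies to the other tensors fed into sampling, which is why t_1 in the diff is already created with device=device.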