amildravid4292 commited on
Commit
4e07682
·
verified ·
1 Parent(s): 82639a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -27,15 +27,15 @@ device = "cuda:0"
27
  generator = torch.Generator(device=device)
28
 
29
 
30
- models_path = snapshot_download(repo_id="Snapchat/w2w/files")
31
-
32
- mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
33
- std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
34
- v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
35
- proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
36
- df = torch.load(f"{models_path}/identity_df.pt")
37
- weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
38
- pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)
39
 
40
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
41
  global network
 
27
  generator = torch.Generator(device=device)
28
 
29
 
30
+ models_path = snapshot_download(repo_id="Snapchat/w2w")
31
+
32
+ mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
33
+ std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
34
+ v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
35
+ proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
36
+ df = torch.load(f"{models_path}/files/identity_df.pt")
37
+ weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
38
+ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
39
 
40
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
41
  global network