amildravid4292 commited on
Commit
ac24ff3
·
verified ·
1 Parent(s): bf723db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -19,15 +19,15 @@ from lora_w2w import LoRAw2w
19
  from huggingface_hub import snapshot_download
20
  import spaces
21
 
22
- global device
23
- global generator
24
- global unet
25
- global vae
26
- global text_encoder
27
- global tokenizer
28
- global noise_scheduler
29
- global network
30
- device = "cuda"
31
  #generator = torch.Generator(device=device)
32
 
33
  models_path = snapshot_download(repo_id="Snapchat/w2w")
@@ -43,10 +43,10 @@ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=to
43
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
44
 
45
 
46
- global young
47
- global pointy
48
- global wavy
49
- global thick
50
 
51
  young = get_direction(df, "Young", pinverse, 1000, device)
52
  young = debias(young, "Male", df, pinverse, device)
 
19
  from huggingface_hub import snapshot_download
20
  import spaces
21
 
22
+
23
+ gr.State(generator)
24
+ gr.State(unet)
25
+ gr.State(vae)
26
+ gr.State(text_encoder)
27
+ gr.State(tokenizer)
28
+ gr.State(noise_scheduler)
29
+ gr.State(network)
30
+ device = gr.State("cuda")
31
  #generator = torch.Generator(device=device)
32
 
33
  models_path = snapshot_download(repo_id="Snapchat/w2w")
 
43
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
44
 
45
 
46
+ gr.State(young)
47
+ gr.State(pointy)
48
+ gr.State(wavy)
49
+ gr.State(thick)
50
 
51
  young = get_direction(df, "Young", pinverse, 1000, device)
52
  young = debias(young, "Male", df, pinverse, device)