yucornetto commited on
Commit
51a2c42
1 Parent(s): 578e7fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -9,6 +9,7 @@ import time
9
  import argparse
10
  import demo_util
11
  import os
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
@@ -42,8 +43,12 @@ titok_tokenizer = titok_tokenizer.to(device)
42
  titok_generator = titok_generator.to(device)
43
 
44
 
 
45
  def demo_infer(guidance_scale, randomize_temperature, num_sample_steps,
46
  class_label, seed):
 
 
 
47
  n = 4
48
  class_labels = [class_label for _ in range(n)]
49
  torch.manual_seed(seed)
 
9
  import argparse
10
  import demo_util
11
  import os
12
+ import spaces
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
 
43
  titok_generator = titok_generator.to(device)
44
 
45
 
46
+ @spaces.GPU
47
  def demo_infer(guidance_scale, randomize_temperature, num_sample_steps,
48
  class_label, seed):
49
+ device = "cuda" if torch.cuda.is_available() else "cpu"
50
+ titok_tokenizer = titok_tokenizer.to(device)
51
+ titok_generator = titok_generator.to(device)
52
  n = 4
53
  class_labels = [class_label for _ in range(n)]
54
  torch.manual_seed(seed)