yucornetto commited on
Commit
2307701
1 Parent(s): a331dda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -31,9 +31,9 @@ parser.add_argument("--temperature", type=float, default=1.0, help="temperature
31
  args = parser.parse_args()
32
 
33
 
34
- @spaces.GPU
35
  def load_model():
36
- device = "cuda" if torch.cuda.is_available() else "cpu"
37
  config = demo_util.get_config("configs/titok_l32.yaml")
38
  print(config)
39
  titok_tokenizer = demo_util.get_titok_tokenizer(config)
@@ -51,9 +51,9 @@ titok_tokenizer, titok_generator = load_model()
51
  def demo_infer(
52
  guidance_scale, randomize_temperature, num_sample_steps,
53
  class_label, seed):
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
- tokenizer = titok_tokenizer.to(device)
56
- generator = titok_generator.to(device)
57
  n = 4
58
  class_labels = [class_label for _ in range(n)]
59
  torch.manual_seed(seed)
@@ -83,7 +83,7 @@ with gr.Blocks() as demo:
83
  with gr.Row():
84
  i1k_class = gr.Dropdown(
85
  list(imagenet_idx2classname.values()),
86
- value='macaw',
87
  type="index", label='ImageNet-1K Class'
88
  )
89
  guidance_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=3.5, label='Classifier-free Guidance Scale')
 
31
  args = parser.parse_args()
32
 
33
 
34
+ # @spaces.GPU
35
  def load_model():
36
+ device = "cuda" #if torch.cuda.is_available() else "cpu"
37
  config = demo_util.get_config("configs/titok_l32.yaml")
38
  print(config)
39
  titok_tokenizer = demo_util.get_titok_tokenizer(config)
 
51
  def demo_infer(
52
  guidance_scale, randomize_temperature, num_sample_steps,
53
  class_label, seed):
54
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ tokenizer = titok_tokenizer #.to(device)
56
+ generator = titok_generator #.to(device)
57
  n = 4
58
  class_labels = [class_label for _ in range(n)]
59
  torch.manual_seed(seed)
 
83
  with gr.Row():
84
  i1k_class = gr.Dropdown(
85
  list(imagenet_idx2classname.values()),
86
+ value='Chihuahua',
87
  type="index", label='ImageNet-1K Class'
88
  )
89
  guidance_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=3.5, label='Classifier-free Guidance Scale')