yucornetto commited on
Commit
995325f
1 Parent(s): 51a2c42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -44,19 +44,21 @@ titok_generator = titok_generator.to(device)
44
 
45
 
46
  @spaces.GPU
47
- def demo_infer(guidance_scale, randomize_temperature, num_sample_steps,
 
 
48
  class_label, seed):
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
- titok_tokenizer = titok_tokenizer.to(device)
51
- titok_generator = titok_generator.to(device)
52
  n = 4
53
  class_labels = [class_label for _ in range(n)]
54
  torch.manual_seed(seed)
55
  torch.cuda.manual_seed(seed)
56
  t1 = time.time()
57
  generated_image = demo_util.sample_fn(
58
- generator=titok_generator,
59
- tokenizer=titok_tokenizer,
60
  labels=class_labels,
61
  guidance_scale=guidance_scale,
62
  randomize_temperature=randomize_temperature,
@@ -90,6 +92,7 @@ with gr.Blocks() as demo:
90
  with gr.Column():
91
  output = gr.Gallery(label='Generated Images', height=700)
92
  button.click(demo_infer, inputs=[
 
93
  guidance_scale, randomize_temperature, num_sample_steps,
94
  i1k_class, seed],
95
  outputs=[output])
 
44
 
45
 
46
  @spaces.GPU
47
+ def demo_infer(tokenizer,
48
+ generator,
49
+ guidance_scale, randomize_temperature, num_sample_steps,
50
  class_label, seed):
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ tokenizer = tokenizer.to(device)
53
+ generator = generator.to(device)
54
  n = 4
55
  class_labels = [class_label for _ in range(n)]
56
  torch.manual_seed(seed)
57
  torch.cuda.manual_seed(seed)
58
  t1 = time.time()
59
  generated_image = demo_util.sample_fn(
60
+ generator=generator,
61
+ tokenizer=tokenizer,
62
  labels=class_labels,
63
  guidance_scale=guidance_scale,
64
  randomize_temperature=randomize_temperature,
 
92
  with gr.Column():
93
  output = gr.Gallery(label='Generated Images', height=700)
94
  button.click(demo_infer, inputs=[
95
+ titok_tokenizer, titok_generator,
96
  guidance_scale, randomize_temperature, num_sample_steps,
97
  i1k_class, seed],
98
  outputs=[output])