yucornetto commited on
Commit
a331dda
1 Parent(s): 15a2a80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -45,15 +45,15 @@ def load_model():
45
  titok_generator = titok_generator.to(device)
46
  return titok_tokenizer, titok_generator
47
 
 
48
 
49
  @spaces.GPU
50
- def demo_infer_(tokenizer,
51
- generator,
52
  guidance_scale, randomize_temperature, num_sample_steps,
53
  class_label, seed):
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
- tokenizer = tokenizer.to(device)
56
- generator = generator.to(device)
57
  n = 4
58
  class_labels = [class_label for _ in range(n)]
59
  torch.manual_seed(seed)
@@ -73,8 +73,6 @@ def demo_infer_(tokenizer,
73
  samples = [Image.fromarray(sample) for sample in generated_image]
74
  return samples
75
 
76
- titok_tokenizer, titok_generator = load_model()
77
- demo_infer = partial(demo_infer_, tokenizer=titok_tokenizer, generator=titok_generator)
78
  with gr.Blocks() as demo:
79
  gr.Markdown("<h1 style='text-align: center'>An Image is Worth 32 Tokens for Reconstruction and Generation</h1>")
80
 
 
45
  titok_generator = titok_generator.to(device)
46
  return titok_tokenizer, titok_generator
47
 
48
+ titok_tokenizer, titok_generator = load_model()
49
 
50
  @spaces.GPU
51
+ def demo_infer(
 
52
  guidance_scale, randomize_temperature, num_sample_steps,
53
  class_label, seed):
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ tokenizer = titok_tokenizer.to(device)
56
+ generator = titok_generator.to(device)
57
  n = 4
58
  class_labels = [class_label for _ in range(n)]
59
  torch.manual_seed(seed)
 
73
  samples = [Image.fromarray(sample) for sample in generated_image]
74
  return samples
75
 
 
 
76
  with gr.Blocks() as demo:
77
  gr.Markdown("<h1 style='text-align: center'>An Image is Worth 32 Tokens for Reconstruction and Generation</h1>")
78