prithivMLmods committed
Commit 85f0aba
1 Parent(s): 0c1b8f7

Update app.py

Files changed (1)
  1. app.py +13 -15
app.py CHANGED
@@ -2,7 +2,6 @@
 import os
 import random
 import uuid
-import json
 import gradio as gr
 import numpy as np
 from PIL import Image
@@ -49,7 +48,7 @@ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-def load_model(model_id):
+def load_and_prepare_model(model_id):
     pipe = StableDiffusionXLPipeline.from_pretrained(
         model_id,
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -66,8 +65,8 @@ def load_model(model_id):
 
     return pipe
 
-current_model_id = MODEL_OPTIONS["RealVisXL_V4.0_Lightning"]
-pipe = load_model(current_model_id)
+# Preload and compile both models
+models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -97,9 +96,8 @@ def generate(
     num_images: int = 1,
     progress=gr.Progress(track_tqdm=True),
 ):
-    global pipe
-    if model_choice != current_model_id:
-        pipe = load_model(MODEL_OPTIONS[model_choice])
+    global models
+    pipe = models[model_choice]
 
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
@@ -131,14 +129,7 @@
 
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
     gr.Markdown(DESCRIPTIONx)
-
-    with gr.Group():
     with gr.Row():
-        model_choice = gr.Dropdown(
-            label="Model",
-            choices=list(MODEL_OPTIONS.keys()),
-            value="RealVisXL_V4.0_Lightning"
-        )
         prompt = gr.Text(
             label="Prompt",
             show_label=False,
@@ -149,6 +140,13 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
         run_button = gr.Button("Run", scale=0)
     result = gr.Gallery(label="Result", columns=1, show_label=False)
 
+    with gr.Row():
+        model_choice = gr.Dropdown(
+            label="Model",
+            choices=list(MODEL_OPTIONS.keys()),
+            value="RealVisXL_V4.0_Lightning"
+        )
+
     with gr.Accordion("Advanced options", open=False, visible=False):
         num_images = gr.Slider(
            label="Number of Images",
@@ -250,4 +248,4 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
     gr.Markdown("⚠️ users are accountable for the content they generate and are responsible for ensuring it meets appropriate ethical standards.")
 
 if __name__ == "__main__":
-    demo.queue(max_size=40).launch()
+    demo.queue(max_size=40).launch()
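
For context on the change above: the commit replaces on-demand model loading with eager preloading, so generate() only does a dictionary lookup when the user switches models. Below is a minimal standalone sketch of that pattern, not the repo's exact code; the MODEL_OPTIONS ids and the generate() signature shown here are placeholders for illustration.

import torch
from diffusers import StableDiffusionXLPipeline

# Hypothetical label -> model-id mapping; the real MODEL_OPTIONS lives outside this diff.
MODEL_OPTIONS = {
    "RealVisXL_V4.0_Lightning": "SG161222/RealVisXL_V4.0_Lightning",
    "RealVisXL_V4.0": "SG161222/RealVisXL_V4.0",
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_and_prepare_model(model_id):
    # Half precision on GPU, full precision on CPU, then move to the target device.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    return pipe.to(device)

# Eager preload: every pipeline is built once at startup and kept in memory.
models = {name: load_and_prepare_model(mid) for name, mid in MODEL_OPTIONS.items()}

def generate(model_choice: str, prompt: str, seed: int = 0, num_images: int = 1):
    # Switching models is now a dict lookup instead of a reload from disk.
    pipe = models[model_choice]
    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe(prompt, generator=generator, num_images_per_prompt=num_images).images

The trade-off is memory: every entry in MODEL_OPTIONS stays resident at once, which is workable for a couple of SDXL checkpoints on a large GPU but would not scale to many models.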