lnyan committed on
Commit
f1d2416
·
1 Parent(s): 8ad6ef6
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -12,7 +12,7 @@ from jax import Array as Tensor
12
  from transformers import (FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel,
13
  T5Tokenizer)
14
 
15
-
16
  class HFEmbedder(nnx.Module):
17
  def __init__(self, version: str, max_length: int, **hf_kwargs):
18
  self.is_clip = version.startswith("openai")
@@ -29,7 +29,7 @@ class HFEmbedder(nnx.Module):
29
  self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False,**hf_kwargs)
30
  self.hf_module._is_initialized = True
31
  import jax
32
- self.hf_module.params = jax.tree_map(lambda x: jax.device_put(x, jax.devices("cuda")[0]), params)
33
  # if dtype==jnp.bfloat16:
34
 
35
  def tokenize(self, text: list[str]) -> Tensor:
@@ -107,10 +107,11 @@ def b64(txt,vec):
107
  encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
108
  return encoded
109
 
110
- t5,clip=load_encoders()
111
 
112
  @spaces.GPU(duration=20)
113
  def convert(prompt):
 
114
  if isinstance(prompt, str):
115
  prompt = [prompt]
116
  txt = t5.tokenize(prompt)
@@ -126,11 +127,12 @@ def _to_embed(t5, clip, txt, vec):
126
 
127
  to_embed=jax.jit(_to_embed)
128
 
129
- t5_tuple=nnx.split(t5)
130
- clip_tuple=nnx.split(clip)
131
 
132
  @spaces.GPU(duration=120)
133
  def compile(prompt):
 
134
  if isinstance(prompt, str):
135
  prompt = [prompt]
136
  txt = t5.tokenize(prompt)
@@ -138,6 +140,17 @@ def compile(prompt):
138
  text,vec=to_embed(t5_tuple,clip_tuple,txt,vec)
139
  return b64(txt,vec)
140
 
 
 
 
 
 
 
 
 
 
 
 
141
  with gr.Blocks() as demo:
142
  gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
143
  with gr.Row():
@@ -145,9 +158,10 @@ with gr.Blocks() as demo:
145
  prompt = gr.Textbox(label="prompt")
146
  convert_btn = gr.Button(value="Convert")
147
  compile_btn = gr.Button(value="Compile")
 
148
  with gr.Column():
149
  output = gr.Textbox(label="output")
150
-
151
  convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
152
  compile_btn.click(compile, inputs=prompt, outputs=output, api_name="compile")
153
 
 
12
  from transformers import (FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel,
13
  T5Tokenizer)
14
 
15
+ models = {}
16
  class HFEmbedder(nnx.Module):
17
  def __init__(self, version: str, max_length: int, **hf_kwargs):
18
  self.is_clip = version.startswith("openai")
 
29
  self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False,**hf_kwargs)
30
  self.hf_module._is_initialized = True
31
  import jax
32
+ self.hf_module.params = jax.tree.map(lambda x: jax.device_put(x, jax.devices("cuda")[0]), params)
33
  # if dtype==jnp.bfloat16:
34
 
35
  def tokenize(self, text: list[str]) -> Tensor:
 
107
  encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
108
  return encoded
109
 
110
+ # t5,clip=load_encoders()
111
 
112
  @spaces.GPU(duration=20)
113
  def convert(prompt):
114
+ t5,clip=models["t5"],models["clip"]
115
  if isinstance(prompt, str):
116
  prompt = [prompt]
117
  txt = t5.tokenize(prompt)
 
127
 
128
  to_embed=jax.jit(_to_embed)
129
 
130
+ # t5_tuple=nnx.split(t5)
131
+ # clip_tuple=nnx.split(clip)
132
 
133
  @spaces.GPU(duration=120)
134
  def compile(prompt):
135
+ t5,clip,t5_tuple,clip_tuple=models["t5"],models["clip"],models["t5_tuple"],models["clip_tuple"]
136
  if isinstance(prompt, str):
137
  prompt = [prompt]
138
  txt = t5.tokenize(prompt)
 
140
  text,vec=to_embed(t5_tuple,clip_tuple,txt,vec)
141
  return b64(txt,vec)
142
 
143
+ @spaces.GPU(duration=120)
144
+ def load(prompt):
145
+ is_schnell = True
146
+ t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
147
+ clip = load_clip("cuda")
148
+ models["t5"]=t5
149
+ models["clip"]=clip
150
+ models["t5_tuple"]=nnx.split(t5)
151
+ models["clip_tuple"]=nnx.split(clip)
152
+ return "Loaded"
153
+
154
  with gr.Blocks() as demo:
155
  gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
156
  with gr.Row():
 
158
  prompt = gr.Textbox(label="prompt")
159
  convert_btn = gr.Button(value="Convert")
160
  compile_btn = gr.Button(value="Compile")
161
+ load_btn = gr.Button(value="Load")
162
  with gr.Column():
163
  output = gr.Textbox(label="output")
164
+ load_btn.click(load, inputs=prompt, outputs=output, api_name="load")
165
  convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")
166
  compile_btn.click(compile, inputs=prompt, outputs=output, api_name="compile")
167