amildravid4292 commited on
Commit
f72cee5
·
verified ·
1 Parent(s): 216fb80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -12,14 +12,15 @@ from utils import load_models, save_model_w2w, save_model_for_diffusers
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
- #global device
16
- #global generator
17
- #global unet
18
- #global vae
19
- #global text_encoder
20
- #global tokenizer
21
- #global noise_scheduler
22
  device = "cuda:0"
 
23
 
24
  models_path = snapshot_download(repo_id="Snapchat/w2w")
25
 
@@ -31,30 +32,33 @@ df = torch.load(f"{models_path}/identity_df.pt")
31
  weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
32
 
33
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
34
- network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
35
- #global network
36
-
37
- #def sample_model():
38
- # global unet
39
- # del unet
40
- # global network
41
- # unet, _, _, _, _ = load_models(device)
42
-
43
- def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
44
- #global device
45
- #global generator
46
- #global unet
47
- #global vae
48
- #global text_encoder
49
- #global tokenizer
50
- #global noise_scheduler
51
- generator = torch.Generator(device=device).manual_seed(seed)
 
 
52
  latents = torch.randn(
53
  (1, unet.in_channels, 512 // 8, 512 // 8),
54
  generator = generator,
55
  device = device
56
  ).bfloat16()
57
 
 
58
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
59
 
60
  text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
@@ -87,8 +91,12 @@ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
87
 
88
  return [image]
89
 
90
- with gr.Blocks() as demo:
 
 
 
91
  gr.Markdown("# <em>weights2weights</em> Demo")
 
92
  with gr.Row():
93
  with gr.Column():
94
  files = gr.Files(
@@ -106,9 +114,9 @@ with gr.Blocks() as demo:
106
  placeholder="sks person",
107
  value="sks person")
108
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
109
- seed = gr.Number(value=5, label="Seed", interactive=True)
110
  cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
111
- steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
112
 
113
 
114
  submit = gr.Button("Submit")
@@ -116,10 +124,15 @@ with gr.Blocks() as demo:
116
  with gr.Column():
117
  gallery = gr.Gallery(label="Generated Images")
118
 
119
- #sample.click(fn=sample_model)
120
 
121
  submit.click(fn=inference,
122
  inputs=[prompt, negative_prompt, cfg, steps, seed],
123
  outputs=gallery)
124
 
125
- demo.launch(share=True)
 
 
 
 
 
 
12
  from sampling import sample_weights
13
  from huggingface_hub import snapshot_download
14
 
15
+ global device
16
+ global generator
17
+ global unet
18
+ global vae
19
+ global text_encoder
20
+ global tokenizer
21
+ global noise_scheduler
22
  device = "cuda:0"
23
+ generator = torch.Generator(device=device)
24
 
25
  models_path = snapshot_download(repo_id="Snapchat/w2w")
26
 
 
32
  weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
33
 
34
  unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
35
+ global network
36
+
37
+ def sample_model():
38
+ global unet
39
+ del unet
40
+ global network
41
+ unet, _, _, _, _ = load_models(device)
42
+ network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
43
+
44
+
45
+ @torch.no_grad()
46
+ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
47
+ global device
48
+ global generator
49
+ global unet
50
+ global vae
51
+ global text_encoder
52
+ global tokenizer
53
+ global noise_scheduler
54
+ generator = generator.manual_seed(seed)
55
  latents = torch.randn(
56
  (1, unet.in_channels, 512 // 8, 512 // 8),
57
  generator = generator,
58
  device = device
59
  ).bfloat16()
60
 
61
+
62
  text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
63
 
64
  text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
 
91
 
92
  return [image]
93
 
94
+
95
+
96
+ css = ''
97
+ with gr.Blocks(css=css) as demo:
98
  gr.Markdown("# <em>weights2weights</em> Demo")
99
+ gr.Markdown("Demo for the [h94/IP-Adapter-FaceID model](https://huggingface.co/h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license")
100
  with gr.Row():
101
  with gr.Column():
102
  files = gr.Files(
 
114
  placeholder="sks person",
115
  value="sks person")
116
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
117
+ seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
118
  cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
119
+ steps = gr.Slider(label="Inference Steps", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)
120
 
121
 
122
  submit = gr.Button("Submit")
 
124
  with gr.Column():
125
  gallery = gr.Gallery(label="Generated Images")
126
 
127
+ sample.click(fn=sample_model)
128
 
129
  submit.click(fn=inference,
130
  inputs=[prompt, negative_prompt, cfg, steps, seed],
131
  outputs=gallery)
132
 
133
+
134
+
135
+
136
+
137
+
138
+ demo.launch(share=True)