Spaces: Runtime error
Ahsen Khaliq committed
Commit 2b08e86
Parent(s): 28b55ac

Update app.py

app.py CHANGED
@@ -1,8 +1,6 @@
 import os
-os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
 os.system("git clone https://github.com/openai/CLIP")
 os.system("pip install -e ./CLIP")
-os.system("pip install einops ninja scipy numpy Pillow tqdm imageio-ffmpeg imageio")
 import sys
 sys.path.append('./CLIP')
 import io
@@ -105,65 +103,71 @@ zs = torch.randn([10000, G.mapping.z_dim], device=device)
 w_stds = G.mapping(zs, None).std(0)


-def inference(text,steps,image):
-
-
-  if image:
-    target = embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
-  else:
+def inference(text,steps,image,mode):
+  if mode == "CLIP+StyleGAN3":
+    all_frames = []
     target = clip_model.embed_text(text)
-  [old lines 115-166: remainder of the previous inference body, not captured in this view]
+    if image:
+      target = embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
+    else:
+      target = clip_model.embed_text(text)
+    steps = steps
+    #seed = 2
+    seed = -1
+    if seed == -1:
+      seed = np.random.randint(0,2**32 - 1)
+    tf = Compose([
+      Resize(224),
+      lambda x: torch.clamp((x+1)/2,min=0,max=1),
+    ])
+    torch.manual_seed(seed)
+    timestring = time.strftime('%Y%m%d%H%M%S')
+    with torch.no_grad():
+      qs = []
+      losses = []
+      for _ in range(8):
+        q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
+        images = G.synthesis(q * w_stds + G.mapping.w_avg)
+        embeds = embed_image(images.add(1).div(2))
+        loss = spherical_dist_loss(embeds, target).mean(0)
+        i = torch.argmin(loss)
+        qs.append(q[i])
+        losses.append(loss[i])
+      qs = torch.stack(qs)
+      losses = torch.stack(losses)
+      print(losses)
+      print(losses.shape, qs.shape)
+      i = torch.argmin(losses)
+      q = qs[i].unsqueeze(0)
+      q.requires_grad_()
+    q_ema = q
+    opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
+    loop = tqdm(range(steps))
+    for i in loop:
+      opt.zero_grad()
+      w = q * w_stds
+      image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
+      embed = embed_image(image.add(1).div(2))
+      loss = spherical_dist_loss(embed, target).mean()
+      loss.backward()
+      opt.step()
+      loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
+      q_ema = q_ema * 0.9 + q * 0.1
+      image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
+      pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
+      all_frames.append(pil_image)
+      #os.makedirs(f'samples/{timestring}', exist_ok=True)
+      #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
+    writer = imageio.get_writer('test.mp4', fps=15)
+    for im in all_frames:
+      writer.append_data(np.array(im))
+    writer.close()
+    return pil_image, "test.mp4"
+  else:
+    os.system("python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \
+      --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl")
+    img = Image.new("RGB", (800, 1280), (255, 255, 255))
+    return img, "lerp.mp4"


 title = "StyleGAN3+CLIP"
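The new inference body relies on two helpers defined earlier in app.py, outside this hunk: embed_image (a CLIP image embedding, typically computed over augmented cutouts) and spherical_dist_loss. For orientation only, the spherical distance loss commonly used in these CLIP-guidance scripts looks like the sketch below; the exact definition in this file is not shown in the diff.

import torch.nn.functional as F

def spherical_dist_loss(x, y):
    # Squared geodesic distance between L2-normalized embeddings on the unit sphere.
    # Minimizing it pulls the generated image's CLIP embedding toward the target embedding.
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)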
@@ -172,7 +176,7 @@ article = "<p style='text-align: center'><a href='https://colab.research.google.
 examples = [['mario',150,None]]
 gr.Interface(
     inference,
-    ["text",gr.inputs.Slider(minimum=50, maximum=200, step=1, default=150, label="steps"),gr.inputs.Image(type="pil", label="Image (Optional)", optional=True)],
+    ["text",gr.inputs.Slider(minimum=50, maximum=200, step=1, default=150, label="steps"),gr.inputs.Image(type="pil", label="Image (Optional)", optional=True),gradio.inputs.Radio(choices["CLIP+StyleGAN3","Stylegan3 interpolation"] type="value", default="CLIP+StyleGAN3", label="mode")],
     [gr.outputs.Image(type="pil", label="Output"),"playable_video"],
     title=title,
     description=description,