commit 1cac669 (1 parent: 77e955b)
Author: adymaharana

    restart
app.py CHANGED
@@ -1,4 +1,4 @@
-import os, torch
+import os, sys, torch
 import gradio as gr
 import torchvision.utils as vutils
 import torchvision.transforms as transforms
@@ -68,6 +68,7 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
 def main(args):
     #device = 'cuda:0'
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    #device = torch.device('cpu')

     model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'

@@ -77,7 +78,7 @@ def main(args):
     #if not os.path.exists("./ckpt/25.pth"):
     #    gdown.download(model_url, quiet=False, use_cookies=False, output="./ckpt/25.pth")
     #    print("Downloaded checkpoint")
-    assert os.path.exists("./ckpt/25.pth")
+    #assert os.path.exists("./ckpt/25.pth")
     gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")

     if args.debug:
@@ -102,6 +103,9 @@ def main(args):
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
     )

+    #torch.save(model, './ckpt/checkpoint.pt')
+    #sys.exit()
+
     def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
                 supercondition=False):

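Taken together, the app.py changes relax the GPU and checkpoint assumptions so the Space can boot on CPU-only hardware. Below is a minimal sketch of the same pattern, assuming only the `./ckpt/25.pth` path from the diff; the surrounding logic is illustrative, not the app's actual code:

```python
import os
import torch

# Pick CUDA when available, otherwise fall back to CPU -- the same
# pattern the commit keeps in main().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ckpt_path = './ckpt/25.pth'  # checkpoint path used in app.py

if os.path.exists(ckpt_path):
    # map_location='cpu' lets a checkpoint saved on GPU load on a
    # CPU-only machine; tensors can be moved to `device` afterwards.
    state = torch.load(ckpt_path, map_location=torch.device('cpu'))
    print('Loaded checkpoint with keys:', list(state.keys())[:5])
else:
    # The commit comments out the hard assert, so a missing checkpoint
    # no longer crashes the Space at startup.
    print('Checkpoint not found; expecting it to be fetched elsewhere.')
```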
dalle/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/dalle/__pycache__/__init__.cpython-38.pyc and b/dalle/__pycache__/__init__.cpython-38.pyc differ
dalle/models/__init__.py CHANGED
@@ -23,6 +23,7 @@ from ..utils.utils import save_image
 from .tokenizer import build_tokenizer
 import numpy as np
 from .stage2.layers import CrossAttentionLayer
+from huggingface_hub import hf_hub_download

 _MODELS = {
     'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
@@ -1191,7 +1192,9 @@ class StoryDalle(Dalle):
         print("Loaded tokenizer from finetuned checkpoint")
         print(model.cross_attention_idxs)
         print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
+
         # model.from_ckpt(args.model_name_or_path)
+
         try:
             model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
         except KeyError:
@@ -1248,9 +1251,9 @@ class StoryDalle(Dalle):
         #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}

         with torch.no_grad():
-            with autocast(enabled=False):
-                codes = self.stage1.get_codes(images).detach()
-                src_codes = self.stage1.get_codes(src_images).detach()
+            #with autocast(enabled=False):
+            codes = self.stage1.get_codes(images).detach()
+            src_codes = self.stage1.get_codes(src_images).detach()

         B, C, H, W = images.shape

@@ -1310,8 +1313,8 @@ class StoryDalle(Dalle):
         # Check if the encoding works as intended
         # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

-        tokens = tokens.to(device)
-        source = source.to(device)
+        #tokens = tokens.to(device)
+        #source = source.to(device)

         # print(tokens.shape, sent_embeds.shape, prompt.shape)
         B, L, _ = sent_embeds.shape
@@ -1322,8 +1325,8 @@ class StoryDalle(Dalle):
             prompt = sent_embeds
         pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')

-        with autocast(enabled=False):
-            src_codes = self.stage1.get_codes(source).detach()
+        #with autocast(enabled=False):
+        src_codes = self.stage1.get_codes(source).detach()
         src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
         print(tokens.shape, src_codes.shape, prompt.shape)
         if self.config.story.condition:
@@ -1378,8 +1381,8 @@ class StoryDalle(Dalle):
         # Check if the encoding works as intended
         # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

-        tokens = tokens.to(device)
-        source = source.to(device)
+        #tokens = tokens.to(device)
+        #source = source.to(device)

         # print(tokens.shape, sent_embeds.shape, prompt.shape)
         B, L, _ = sent_embeds.shape
@@ -1389,10 +1392,10 @@ class StoryDalle(Dalle):
         else:
             prompt = sent_embeds
         pos_enc_prompt = get_positional_encoding(
-            torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(self.device), mode='1d')
+            torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(tokens.device), mode='1d')

-        with autocast(enabled=False):
-            src_codes = self.stage1.get_codes(source).detach()
+        #with autocast(enabled=False):
+        src_codes = self.stage1.get_codes(source).detach()

         # repeat inputs to adjust to n_candidates and story length
         src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
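The only new import in dalle/models/__init__.py is `hf_hub_download`, which suggests the checkpoint is meant to be fetched from the Hugging Face Hub rather than the Google Drive URL handled by gdown. The diff never actually calls it, so the following is a hypothetical sketch; the `repo_id` and `filename` below are invented placeholders:

```python
import torch
from huggingface_hub import hf_hub_download

# Hypothetical repo_id/filename -- the commit only adds the import and
# does not show where the checkpoint is actually hosted.
ckpt_path = hf_hub_download(repo_id="user/storydalle-demo", filename="25.pth")

# Load on CPU first, mirroring the map_location already used in the
# load_state_dict call above.
state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
```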
dalle/models/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ
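Note that every removed `with autocast(enabled=False):` block wrapped a `stage1.get_codes(...)` call; dropping the context and dedenting its body is consistent with the move to CPU execution, where a CUDA autocast context serves no purpose. If mixed precision were still wanted when a GPU is present, one alternative (a sketch of a different approach, not what this commit does) is to gate the context on the device type:

```python
import torch
from contextlib import nullcontext

def maybe_autocast(device: torch.device, enabled: bool = True):
    """Return an autocast context on CUDA, and a no-op context on CPU.

    Illustrative helper only -- the commit simply removes the autocast
    blocks instead of guarding them like this.
    """
    if device.type == 'cuda':
        return torch.cuda.amp.autocast(enabled=enabled)
    return nullcontext()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(2, 3, device=device)

with torch.no_grad():
    with maybe_autocast(device, enabled=False):
        y = (x * 2).detach()  # stands in for stage1.get_codes(...)
print(y.shape)
```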