Update app.py
app.py
CHANGED
@@ -1,36 +1,56 @@
 import json
-import …
-…
+import os
+import zipfile
+from pathlib import Path
+import io
+from tempfile import NamedTemporaryFile
+
+from PIL import Image
+import gradio as gr
 import torch
 from torchvision.transforms import transforms
-import …
-…
+from torch.utils.data import Dataset, DataLoader
+import spaces
+
+torch.jit.script = lambda f: f
+# torch.cuda.amp.autocast(enabled=True)
+
+caption_ext = ".txt"
+exclude_tags = ("explicit", "questionable", "safe")
 
-model = torch.load('model.pth', map_location=torch.device('cpu'))
-model.eval()
 transform = transforms.Compose([
     transforms.Resize((384, 384)),
     transforms.ToTensor(),
-    transforms.Normalize(
-        mean=[
-            0.5,
-            0.5,
-            0.5,
-        ], std=[
-            0.5,
-            0.5,
-            0.5,
-        ])
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 ])
 
-…
-…
+class ZipImageDataset(Dataset):
+    def __init__(self, zip_file, dtype):
+        self.zip_file = zip_file
+        self.dtype = dtype
+        self.image_files = [file_info for file_info in zip_file.infolist() if file_info.filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
+
+    def __len__(self):
+        return len(self.image_files)
 
-…
-…
-…
-…
+    def __getitem__(self, index):
+        file_info = self.image_files[index]
+        with self.zip_file.open(file_info) as file:
+            image = Image.open(file).convert("RGB")
+        image = transform(image).to(self.dtype)
+        return {
+            "image": image,
+            "image_name": file_info.filename,
+        }
 
+model = torch.load("./model.pth", map_location=torch.device('cpu'))
+model.eval()
+
+with open("tags_9940.json", "r") as file:
+    tags = json.load(file)
+allowed_tags = sorted(tags) + ["explicit", "questionable", "safe"]
+
+@spaces.GPU(duration=5)
 def create_tags(image, threshold):
     img = image.convert('RGB')
     tensor = transform(img).unsqueeze(0)
@@ -46,22 +66,65 @@ def create_tags(image, threshold):
     for i in range(indices.size(0)):
         temp.append([allowed_tags[indices[i]], values[i].item()])
         tag_score[allowed_tags[indices[i]]] = values[i].item()
-    # temp = sorted(temp, key=lambda x: x[1], reverse=True)
-    # print("Before adding implicated tags, there are " + str(len(temp)) + " tags")
     temp = [t[0] for t in temp]
-    text_no_impl = " ".join(temp)
-    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    print(f"{current_datetime}: finished.")
+    text_no_impl = ", ".join(temp)
     return text_no_impl, tag_score
 
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
+@spaces.GPU(duration=180)
+def process_zip(zip_file, threshold):
+    with zipfile.ZipFile(zip_file.name) as zip_ref:
+        dataset = ZipImageDataset(zip_ref, next(model.parameters()).dtype)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=64,
+            shuffle=False,
+            num_workers=0,
+            pin_memory=True,
+            drop_last=False,
+        )
+        all_image_names = []
+        all_probabilities = []
+        with torch.no_grad():
+            for i, batch in enumerate(dataloader):
+                images = batch["image"]
+                with torch.autocast(device_type="cuda", dtype=torch.float16):
+                    outputs = model(images)
+                    probabilities = torch.nn.functional.sigmoid(outputs)
+                for image_name, prob in zip(batch["image_name"], probabilities):
+                    indices = torch.where(prob > threshold)[0]
+                    values = prob[indices]
+                    temp = []
+                    tag_score = dict()
+                    for j in range(indices.size(0)):
+                        temp.append([allowed_tags[indices[j]], values[j].item()])
+                        tag_score[allowed_tags[indices[j]]] = values[j].item()
+                    temp = [t[0] for t in temp]
+                    text_no_impl = ", ".join(temp)
+                    all_image_names.append(image_name)
+                    all_probabilities.append(text_no_impl)
+
+    temp_file = NamedTemporaryFile(delete=False, suffix=".zip")
+    with zipfile.ZipFile(temp_file, "w") as zip_ref:
+        for image_name, text_no_impl in zip(all_image_names, all_probabilities):
+            with zip_ref.open(image_name + caption_ext, "w") as file:
+                file.write(text_no_impl.encode())
+    temp_file.seek(0)
+    return temp_file.name
+
+with gr.Blocks() as demo:
+    with gr.Tab("Single Image"):
+        gr.Interface(
+            create_tags,
+            inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
+            outputs=[
+                gr.Textbox(label="Tag String"),
+                gr.Label(label="Tag Predictions", num_top_classes=200),
+            ],
+            allow_flagging="never",
+        )
+    with gr.Tab("Multiple Images"):
+        gr.Interface(fn=process_zip, inputs=[gr.File(label="Zip File", file_types=[".zip"]), gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Threshold")],
+                     outputs=gr.File(type="binary"))
+
+if __name__ == "__main__":
+    demo.launch()
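Note: both entry points are decorated with spaces.GPU, yet the checkpoint is loaded with map_location=torch.device('cpu') and the batches are never moved to a device, while torch.autocast(device_type="cuda", ...) assumes CUDA execution. Below is a minimal sketch of the usual ZeroGPU pattern of moving the model and each batch onto the GPU inside the decorated function; run_batches, its arguments, and the device handling are illustrative assumptions, not the committed code.

    import torch

    def run_batches(model, dataloader, device="cuda"):
        # Illustrative only: under ZeroGPU, CUDA becomes available inside the
        # @spaces.GPU-decorated call, so model and batch are moved there.
        model = model.to(device)
        results = []
        with torch.no_grad():
            for batch in dataloader:
                # Move the batch itself, not only the autocast context.
                images = batch["image"].to(device)
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    outputs = model(images)
                # torch.sigmoid is the non-deprecated spelling of the
                # torch.nn.functional.sigmoid call used in the diff.
                results.append(torch.sigmoid(outputs.float()).cpu())
        return torch.cat(results)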
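The caption archive is assembled with zipfile.ZipFile.open(name, "w"), which is valid for writing members on Python 3.6+, and the member name is formed by appending caption_ext to the full image filename, so photos/cat.png becomes photos/cat.png.txt. If the intent is the more common convention of replacing the image extension, pathlib (imported above but currently unused) does that. A self-contained sketch using only the standard library, with invented filenames:

    import io
    import zipfile
    from pathlib import Path

    captions = {"photos/cat.png": "tag_a, tag_b", "photos/dog.jpg": "tag_c"}

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as zf:
        for image_name, text in captions.items():
            # with_suffix swaps the image extension: photos/cat.png -> photos/cat.txt
            member = str(Path(image_name).with_suffix(".txt"))
            with zf.open(member, "w") as f:
                f.write(text.encode("utf-8"))

    with zipfile.ZipFile(buffer) as zf:
        print(zf.namelist())  # ['photos/cat.txt', 'photos/dog.txt']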
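For reference, the selection step shared by create_tags and process_zip reduces to thresholding sigmoid probabilities and mapping the surviving indices back through allowed_tags (exclude_tags is defined in this commit but not yet applied anywhere). A toy reproduction with invented tags and logits:

    import torch

    allowed_tags = ["tag_a", "tag_b", "tag_c", "explicit", "questionable", "safe"]
    logits = torch.tensor([2.0, -1.0, 0.5, -3.0, -2.0, 1.5])  # made-up model outputs

    threshold = 0.30
    probs = torch.sigmoid(logits)
    indices = torch.where(probs > threshold)[0]

    # Single-element integer tensors support Python list indexing, as in the diff.
    tag_score = {allowed_tags[i]: probs[i].item() for i in indices}
    text_no_impl = ", ".join(allowed_tags[i] for i in indices)
    print(text_no_impl)  # tag_a, tag_c, safe
    print(tag_score)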