cdnuts commited on
Commit
9555522
1 Parent(s): f1a50c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -38
app.py CHANGED
@@ -1,36 +1,56 @@
1
  import json
2
- import random
3
- random.seed(999)
 
 
 
 
 
 
4
  import torch
5
  from torchvision.transforms import transforms
6
- import gradio as gr
7
- from datetime import datetime
 
 
 
 
 
 
8
 
9
- model = torch.load('model.pth', map_location=torch.device('cpu'))
10
- model.eval()
11
  transform = transforms.Compose([
12
  transforms.Resize((384, 384)),
13
  transforms.ToTensor(),
14
- transforms.Normalize(
15
- mean=[
16
- 0.5,
17
- 0.5,
18
- 0.5,
19
- ], std=[
20
- 0.5,
21
- 0.5,
22
- 0.5,
23
- ])
24
  ])
25
 
26
- with open("tags_9940.json", "r") as file:
27
- allowed_tags = json.load(file)
 
 
 
 
 
 
28
 
29
- allowed_tags = sorted(allowed_tags)
30
- allowed_tags.append("explicit")
31
- allowed_tags.append("questionable")
32
- allowed_tags.append("safe")
 
 
 
 
 
33
 
 
 
 
 
 
 
 
 
34
  def create_tags(image, threshold):
35
  img = image.convert('RGB')
36
  tensor = transform(img).unsqueeze(0)
@@ -46,22 +66,65 @@ def create_tags(image, threshold):
46
  for i in range(indices.size(0)):
47
  temp.append([allowed_tags[indices[i]], values[i].item()])
48
  tag_score[allowed_tags[indices[i]]] = values[i].item()
49
- # temp = sorted(temp, key=lambda x: x[1], reverse=True)
50
- # print("Before adding implicated tags, there are " + str(len(temp)) + " tags")
51
  temp = [t[0] for t in temp]
52
- text_no_impl = " ".join(temp)
53
- current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
54
- print(f"{current_datetime}: finished.")
55
  return text_no_impl, tag_score
56
 
57
- demo = gr.Interface(
58
- create_tags,
59
- inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
60
- outputs=[
61
- gr.Textbox(label="Tag String"),
62
- gr.Label(label="Tag Predictions", num_top_classes=200),
63
- ],
64
- allow_flagging="never",
65
- )
66
-
67
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+ import os
3
+ import zipfile
4
+ from pathlib import Path
5
+ import io
6
+ from tempfile import NamedTemporaryFile
7
+
8
+ from PIL import Image
9
+ import gradio as gr
10
  import torch
11
  from torchvision.transforms import transforms
12
+ from torch.utils.data import Dataset, DataLoader
13
+ import spaces
14
+
15
+ torch.jit.script = lambda f: f
16
+ # torch.cuda.amp.autocast(enabled=True)
17
+
18
+ caption_ext = ".txt"
19
+ exclude_tags = ("explicit", "questionable", "safe")
20
 
 
 
21
  transform = transforms.Compose([
22
  transforms.Resize((384, 384)),
23
  transforms.ToTensor(),
24
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 
 
 
 
 
 
 
 
 
25
  ])
26
 
27
+ class ZipImageDataset(Dataset):
28
+ def __init__(self, zip_file, dtype):
29
+ self.zip_file = zip_file
30
+ self.dtype = dtype
31
+ self.image_files = [file_info for file_info in zip_file.infolist() if file_info.filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
32
+
33
+ def __len__(self):
34
+ return len(self.image_files)
35
 
36
+ def __getitem__(self, index):
37
+ file_info = self.image_files[index]
38
+ with self.zip_file.open(file_info) as file:
39
+ image = Image.open(file).convert("RGB")
40
+ image = transform(image).to(self.dtype)
41
+ return {
42
+ "image": image,
43
+ "image_name": file_info.filename,
44
+ }
45
 
46
+ model = torch.load("./model.pth", map_location=torch.device('cpu'))
47
+ model.eval()
48
+
49
+ with open("tags_9940.json", "r") as file:
50
+ tags = json.load(file)
51
+ allowed_tags = sorted(tags) + ["explicit", "questionable", "safe"]
52
+
53
+ @spaces.GPU(duration=5)
54
  def create_tags(image, threshold):
55
  img = image.convert('RGB')
56
  tensor = transform(img).unsqueeze(0)
 
66
  for i in range(indices.size(0)):
67
  temp.append([allowed_tags[indices[i]], values[i].item()])
68
  tag_score[allowed_tags[indices[i]]] = values[i].item()
 
 
69
  temp = [t[0] for t in temp]
70
+ text_no_impl = ", ".join(temp)
 
 
71
  return text_no_impl, tag_score
72
 
73
+ @spaces.GPU(duration=180)
74
+ def process_zip(zip_file, threshold):
75
+ with zipfile.ZipFile(zip_file.name) as zip_ref:
76
+ dataset = ZipImageDataset(zip_ref, next(model.parameters()).dtype)
77
+ dataloader = DataLoader(
78
+ dataset,
79
+ batch_size=64,
80
+ shuffle=False,
81
+ num_workers=0,
82
+ pin_memory=True,
83
+ drop_last=False,
84
+ )
85
+ all_image_names = []
86
+ all_probabilities = []
87
+ with torch.no_grad():
88
+ for i, batch in enumerate(dataloader):
89
+ images = batch["image"]
90
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
91
+ outputs = model(images)
92
+ probabilities = torch.nn.functional.sigmoid(outputs)
93
+ for image_name, prob in zip(batch["image_name"], probabilities):
94
+ indices = torch.where(prob > threshold)[0]
95
+ values = prob[indices]
96
+ temp = []
97
+ tag_score = dict()
98
+ for j in range(indices.size(0)):
99
+ temp.append([allowed_tags[indices[j]], values[j].item()])
100
+ tag_score[allowed_tags[indices[j]]] = values[j].item()
101
+ temp = [t[0] for t in temp]
102
+ text_no_impl = ", ".join(temp)
103
+ all_image_names.append(image_name)
104
+ all_probabilities.append(text_no_impl)
105
+
106
+ temp_file = NamedTemporaryFile(delete=False, suffix=".zip")
107
+ with zipfile.ZipFile(temp_file, "w") as zip_ref:
108
+ for image_name, text_no_impl in zip(all_image_names, all_probabilities):
109
+ with zip_ref.open(image_name + caption_ext, "w") as file:
110
+ file.write(text_no_impl.encode())
111
+ temp_file.seek(0)
112
+ return temp_file.name
113
+
114
+ with gr.Blocks() as demo:
115
+ with gr.Tab("Single Image"):
116
+ gr.Interface(
117
+ create_tags,
118
+ inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
119
+ outputs=[
120
+ gr.Textbox(label="Tag String"),
121
+ gr.Label(label="Tag Predictions", num_top_classes=200),
122
+ ],
123
+ allow_flagging="never",
124
+ )
125
+ with gr.Tab("Multiple Images"):
126
+ gr.Interface(fn=process_zip, inputs=[gr.File(label="Zip File", file_types=[".zip"]), gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Threshold")],
127
+ outputs=gr.File(type="binary"))
128
+
129
+ if __name__ == "__main__":
130
+ demo.launch()