p1atdev committed on
Commit
e212637
·
1 Parent(s): 53b46fb

feat: add code

Browse files
Files changed (3) hide show
  1. app.py +120 -0
  2. modeling_siglip.py +57 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from transformers import (
5
+ AutoProcessor,
6
+ )
7
+
8
+ from PIL import Image
9
+ import gradio as gr
10
+
11
+ from modeling_siglip import SiglipForImageClassification
12
+
# Identifier passed to SiglipForImageClassification.from_pretrained below.
MODEL_NAME = "p1atdev/siglip-tagger-test-3"
# Identifier passed to AutoProcessor.from_pretrained below; note that the
# processor comes from a different checkpoint than the classifier weights.
PROCESSOR_NAME = "google/siglip-so400m-patch14-384"

# Both objects are created once at import time and shared by every request.
model = SiglipForImageClassification.from_pretrained(
    MODEL_NAME,
)
# model = torch.compile(model)
processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
21
+
22
+
def compose_text(results: dict[str, float], threshold: float = 0.3) -> str:
    """Join the tags scoring above *threshold* into a comma-separated string.

    Args:
        results: Mapping of tag name to score.
        threshold: Tags are kept only when their score is strictly greater
            than this value.

    Returns:
        The kept tag names ordered by descending score and joined with
        ``", "``; an empty string when no tag passes the threshold.
    """
    # Filter first so only the surviving tags are sorted.
    kept = [(tag, score) for tag, score in results.items() if score > threshold]
    # list.sort is stable (also with reverse=True), so tags with equal scores
    # keep their original insertion order.
    kept.sort(key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, _ in kept)
31
+
32
+
@torch.no_grad()
def predict_tags(image: Image.Image, threshold: float):
    """Run the tagger on one image and return its tags.

    Args:
        image: The input picture. The UI passes ``None`` when nothing has
            been uploaded, which is rejected with a user-visible error.
        threshold: Minimum score (exclusive) for a tag to appear in the
            returned text.

    Returns:
        A ``(text, scores)`` tuple: ``text`` is the comma-separated tag string
        built by ``compose_text`` and ``scores`` maps every tag whose clamped
        score is greater than zero to that score.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    inputs = processor(images=image, return_tensors="pt")

    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()

    # The raw outputs are clamped to [0, 1] and used directly as scores.
    # torch.clamp keeps this a tensor operation; the cast to float32 keeps it
    # working when the model runs in a reduced-precision dtype.
    scores = torch.clamp(logits.float(), 0.0, 1.0)

    results = {}

    for prediction in scores:
        # One tolist() per row instead of two .item() calls per element.
        for i, prob in enumerate(prediction.tolist()):
            if prob > 0:
                results[model.config.id2label[i]] = prob

    return compose_text(results, threshold), results
49
+
50
+
# Custom stylesheet handed to gr.Blocks(css=...). The "sticky" class is
# attached to the input row via elem_classes in demo().
css = """\
.sticky {
position: sticky;
top: 16px;
}

.gradio-container {
overflow: clip;
}
"""
61
+
62
+
# Introductory Markdown rendered at the top of the page. Kept flush-left at
# module level so the string carries no accidental leading indentation.
_INTRO_MARKDOWN = """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.

Models:
- (soon)

Example images by NovelAI and niji・journey.

"""


def demo():
    """Build the Gradio interface and launch it.

    The page has two columns: the left one holds the image input, the
    threshold slider, the start button and the example images; the right one
    shows the composed tag text and the per-tag scores. Clicking the button
    calls ``predict_tags`` with the image and the slider value.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed;
    # confirm the placement of the button, the examples and ui.launch().
    with gr.Blocks(css=css) as ui:
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=[["./sample.jpg"], ["./sample2.webp"]],
                            inputs=[input_img],
                            cache_examples=False,
                        )

            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        # Event wiring must happen inside the Blocks context.
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )
117
+
118
+
# Start the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo()
modeling_siglip.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
7
+ from transformers.utils import ModelOutput
8
+
9
+
@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Container returned by ``SiglipForImageClassification.forward``."""

    # Always None as produced by forward() in this file.
    loss: torch.FloatTensor | None = None
    # Output of the classifier head applied to the pooled features.
    logits: torch.FloatTensor | None = None
    # pooler_output taken from the vision model's output.
    pooler_output: torch.FloatTensor | None = None
    # hidden_states / attentions are copied from the vision model's output.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
17
+
18
+
class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision model followed by a linear classification head."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config,
    ):
        """Create the vision model and the classifier head from *config*.

        Args:
            config: Configuration providing ``num_labels`` and
                ``hidden_size`` (a ``SiglipVisionConfig`` per
                ``config_class``).
        """
        super().__init__(config)

        self.num_labels = config.num_labels
        # Submodule attribute names ("siglip", "classifier") become part of
        # the parameter names in the state dict, so they must not be renamed.
        self.siglip = SiglipVisionModel(config)

        # Classifier head; falls back to a pass-through when the config
        # declares no labels.
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self, pixel_values: torch.FloatTensor, labels: torch.LongTensor | None = None
    ):
        """Classify images from their pixel values.

        Args:
            pixel_values: Input passed unchanged to the vision model.
            labels: Accepted for interface compatibility but currently
                unused; no loss is computed.

        Returns:
            A ``SiglipForImageClassifierOutput`` whose ``loss`` is always
            ``None``.
        """
        outputs = self.siglip(pixel_values)
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        # NOTE(review): labels is ignored, so this model cannot be trained
        # through forward() as written.
        loss = None

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ accelerate
3
+ transformers==4.37.2