drhead committed on
Commit
70e64be
1 Parent(s): 0cdffb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -118,6 +118,25 @@ model = timm.create_model(
118
  num_classes=9083,
119
  ) # type: VisionTransformer
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
122
  model.eval()
123
 
@@ -134,10 +153,9 @@ def create_tags(image, threshold):
134
  tensor = transform(img).unsqueeze(0)
135
 
136
  with torch.no_grad():
137
- logits = model(tensor)
138
- probabilities = torch.nn.functional.sigmoid(logits[0])
139
- indices = torch.where(probabilities > threshold)[0]
140
- values = probabilities[indices]
141
 
142
  temp = []
143
  tag_score = dict()
@@ -150,10 +168,10 @@ def create_tags(image, threshold):
150
 
151
  with gr.Blocks() as demo:
152
  gr.Markdown("""
153
- ## Joint Tagger Project: PILOT Demo
154
  This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
155
 
156
- This tagger is the result of joint efforts between members of the RedRocket team.
157
 
158
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
159
  """)
 
118
  num_classes=9083,
119
  ) # type: VisionTransformer
120
 
121
class GatedHead(torch.nn.Module):
    """Classification head that multiplies a sigmoid activation by a sigmoid gate.

    A single linear layer projects the input to ``2 * num_classes`` values.
    The first ``num_classes`` values go through ``act`` and the remaining
    ``num_classes`` through ``gate`` (both are sigmoids); the two halves are
    multiplied element-wise, so each output is a value between 0 and 1.

    Args:
        num_features: Size of the last dimension of the input tensor.
        num_classes: Number of outputs produced per input vector.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # One projection produces both halves. Attribute names are kept as-is
        # so state-dict keys ("linear.weight", "linear.bias") do not change.
        self.linear = torch.nn.Linear(num_features, num_classes * 2)

        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated per-class scores for ``x``.

        Args:
            x: Tensor whose last dimension has ``num_features`` entries; any
                number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_classes``.
        """
        x = self.linear(x)
        # Split along the LAST dimension rather than dimension 1, so unbatched
        # vectors and inputs with extra leading dimensions are handled the
        # same way as the usual (batch, features) case.
        activated = self.act(x[..., :self.num_classes])
        gated = self.gate(x[..., self.num_classes:])
        return activated * gated
137
+
138
# Replace the model's classifier with the gated head before loading weights.
# The input width is read from the existing layer's ``in_features`` rather
# than ``min(weight.shape)``, which is only correct while the class count is
# larger than the feature width.
# NOTE(review): assumes ``model.head`` is still a ``torch.nn.Linear`` at this
# point (it exposes ``.weight`` in the original line) - confirm.
model.head = GatedHead(model.head.in_features, 9083)
139
+
140
  safetensors.torch.load_model(model, "JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors")
141
  model.eval()
142
 
 
153
  tensor = transform(img).unsqueeze(0)
154
 
155
  with torch.no_grad():
156
+ probits = model(tensor)
157
+ indices = torch.where(probits > threshold)[0]
158
+ values = probits[indices]
 
159
 
160
  temp = []
161
  tag_score = dict()
 
168
 
169
  with gr.Blocks() as demo:
170
  gr.Markdown("""
171
+ ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
172
  This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
173
 
174
+ This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
175
 
176
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
177
  """)