# Hugging Face Spaces app (page status when captured: Running)
import os | |
from PIL import Image | |
import numpy as np | |
import torch | |
from transformers import ( | |
AutoImageProcessor, | |
) | |
import gradio as gr | |
from modeling_siglip import SiglipForImageClassification | |
# Optional read token for the Hugging Face Hub; None when the variable is unset.
HF_TOKEN = os.getenv("HF_READ_TOKEN")

# Example inputs offered by gr.Examples (one single-image row per example).
EXAMPLES = [
    ["./images/sample.jpg"],
    ["./images/sample2.webp"],
]

# Registry of selectable models, keyed by the name shown in the UI radio.
# Each entry starts with the Hub repo id; the loading loop below adds the
# "model" and "processor" objects to the same dict.
model_maps: dict[str, dict] = dict(
    test2={"repo": "p1atdev/siglip-tagger-test-2"},
    test3={"repo": "p1atdev/siglip-tagger-test-3"},
    # test4={"repo": "p1atdev/siglip-tagger-test-4"},
)
# Load every registered checkpoint at import time: the classifier in
# bfloat16 and its matching image processor, stored next to the repo id.
for key in model_maps:
    entry = model_maps[key]
    repo_id = entry["repo"]
    entry["model"] = SiglipForImageClassification.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    entry["processor"] = AutoImageProcessor.from_pretrained(repo_id, token=HF_TOKEN)
# Markdown shown at the top of the UI: a short description, then one link
# per model registered in model_maps, then the example-image credit.
README_MD = (
    """\
## SigLIP Tagger Test 3

An experimental model for tagging danbooru tags of images using SigLIP.

Model(s):
"""
    + "\n".join(
        # Link each repo id to its model page on the Hugging Face Hub.
        # (The previous host "huggingface.co." carried a stray trailing dot.)
        f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
        for value in model_maps.values()
    )
    + "\n"
    + """
Example images by NovelAI and niji・journey.
"""
)
def compose_text(results: dict[str, float], threshold: float = 0.3):
    """Render the tags whose score exceeds ``threshold`` as one string.

    Args:
        results: Mapping of tag name to score.
        threshold: Exclusive lower bound a score must beat to be kept.

    Returns:
        The kept tag names joined with ``", "``, highest score first; tags
        with equal scores keep their order from ``results`` (stable sort).
        An empty string when no score passes.
    """
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)
def predict_tags(image: Image.Image, model_name: str, threshold: float):
    """Tag ``image`` with the model registered under ``model_name``.

    Args:
        image: Input picture from the Gradio image widget, or ``None`` when
            the widget is empty.
        model_name: Key into ``model_maps`` selecting the model/processor.
        threshold: Exclusive minimum score for a tag to appear in the
            returned text (the score dict is not filtered by it).

    Returns:
        ``(text, scores)``: ``text`` is the comma-separated tags scoring above
        ``threshold``, highest first; ``scores`` maps every label whose
        clipped score is positive to that score. ``(None, None)`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None, None

    entry = model_maps[model_name]
    model = entry["model"]
    processor = entry["processor"]

    inputs = processor(image, return_tensors="pt")
    # Pure inference: without no_grad the forward pass builds an autograd
    # graph (keeping activations alive) that is immediately thrown away.
    with torch.no_grad():
        logits = model(**inputs.to(model.device, model.dtype)).logits
    # Move to CPU as float32, then clip into [0, 1] with torch itself rather
    # than relying on np.clip's dispatch onto a torch tensor.
    # NOTE(review): no sigmoid is applied here; this assumes the model's
    # outputs are already probability-like — confirm against modeling_siglip.
    scores = torch.clamp(logits.cpu().float(), 0.0, 1.0)

    id2label = model.config.id2label
    results = {}
    for prediction in scores:
        # One bulk tensor->list conversion instead of an .item() per label.
        for i, prob in enumerate(prediction.tolist()):
            if prob > 0:
                results[id2label[i]] = prob

    return compose_text(results, threshold), results
# Extra stylesheet passed to gr.Blocks: the "sticky" class (applied to the
# input row via elem_classes) pins it 16px from the top while scrolling, and
# the Gradio container's overflow is set to "clip".
css = "\n".join(
    [
        ".sticky {",
        "position: sticky;",
        "top: 16px;",
        "}",
        ".gradio-container {",
        "overflow: clip;",
        "}",
        "",
    ]
)
def demo():
    """Build the Gradio tagging UI and launch it.

    Layout: a left column whose inner row carries the ``sticky`` CSS class and
    holds the image input, the model radio, the threshold slider, the Start
    button and the example images; a right column holds the tag text and the
    per-tag score label. Clicking Start calls ``predict_tags`` with the image,
    the selected model name and the threshold.
    """
    with gr.Blocks(css=css) as ui:
        gr.Markdown(README_MD)

        with gr.Row():
            with gr.Column():
                # Row styled by the "sticky" rule in the css string.
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            model_name_radio = gr.Radio(
                                label="Model",
                                choices=list(model_maps.keys()),
                                # NOTE(review): assumes "test3" is still a
                                # key of model_maps — confirm if the registry
                                # changes.
                                value="test3",
                            )
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[input_img],
                            cache_examples=False,
                        )

            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        # Event wiring must be registered inside the Blocks context.
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, model_name_radio, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )
# Launch the UI only when run as a script. Importing the module still runs
# the top-level model loading.
if __name__ == "__main__":
    demo()