---
library_name: transformers
license: apache-2.0
base_model: SmilingWolf/wd-swinv2-tagger-v3
inference: false
tags:
- wd-v14-tagger
---
# WD SwinV2 Tagger v3 with 🤗 transformers
Converted from [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) to the 🤗 transformers format.
## Example
```python
from PIL import Image
import torch
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
)

MODEL_NAME = "p1atdev/wd-swinv2-tagger-v3-hf"

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
)
model.eval()
# trust_remote_code=True allows custom processing code from the model repo to run.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

image = Image.open("sample.webp")
inputs = processor.preprocess(image, return_tensors="pt")

# Move the inputs to the model's device and cast them to the model's dtype (bfloat16).
with torch.inference_mode():
    outputs = model(**inputs.to(model.device, model.dtype))

# Multi-label tagging: a sigmoid gives an independent probability for each tag.
probs = torch.sigmoid(outputs.logits[0])

# Map each probability to its tag name (cast from bfloat16 to float32),
# then sort from most to least likely.
results = {model.config.id2label[i]: prob.float() for i, prob in enumerate(probs)}
results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

print(results)  # rating tags and character tags are also included
# {'1girl': tensor(0.9968),
#  'solo': tensor(0.9584),
#  'dress': tensor(0.9418),
#  'hat': tensor(0.9264),
#  'sitting': tensor(0.9178),
#  'looking_up': tensor(0.8978),
#  'short_hair': tensor(0.8243),
#  'sky': tensor(0.7846),
#  'outdoors': tensor(0.7676),
#  'rating:general': tensor(0.7562),
#  ...
```
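Because rating tags are mixed into the same sorted dictionary as ordinary tags, it is often convenient to pull them apart. The sketch below picks the single most likely rating and keeps the remaining tags at or above a cutoff. It assumes rating labels start with `rating:` (as `rating:general` does in the sample output above). The helper name `split_tags` and the default threshold `0.35` are hypothetical choices, not part of this model or of transformers. Character tags are left mixed in, since their label format is not documented here.

```python
from typing import Dict, List, Tuple

# Hypothetical default cutoff, not taken from this model card; tune it for your data.
DEFAULT_THRESHOLD = 0.35


def split_tags(
    probs: Dict[str, float],
    threshold: float = DEFAULT_THRESHOLD,
    rating_prefix: str = "rating:",
) -> Tuple[str, List[Tuple[str, float]]]:
    """Return (best rating tag, other tags with probability >= threshold).

    Assumes rating labels begin with `rating_prefix`. The rating is chosen by
    argmax (a design choice of this sketch), and the other tags by threshold.
    """
    ratings = {k: v for k, v in probs.items() if k.startswith(rating_prefix)}
    rating = max(ratings, key=ratings.get) if ratings else ""
    tags = [
        (k, v)
        for k, v in probs.items()
        if not k.startswith(rating_prefix) and v >= threshold
    ]
    # Highest-confidence tags first.
    tags.sort(key=lambda kv: kv[1], reverse=True)
    return rating, tags


# Illustrative input: a few values copied from the sample output above.
sample = {
    "1girl": 0.9968,
    "solo": 0.9584,
    "dress": 0.9418,
    "hat": 0.9264,
    "rating:general": 0.7562,
}
rating, tags = split_tags(sample, threshold=0.93)
print(rating)  # rating:general
print(", ".join(name for name, _ in tags))  # 1girl, solo, dress
```

With the model output, pass plain floats, e.g. `split_tags({k: v.item() for k, v in results.items()})`.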