# Hugging Face Spaces page banner captured with this source: "Spaces: Sleeping" (not program code).
import gradio as gr | |
import gradio.components as grc | |
import onnxruntime | |
import numpy as np | |
from torchvision.transforms import Normalize, Compose, Resize, ToTensor | |
# Batch size of the exported model input. Not referenced elsewhere in this
# file; the "_b1" suffix of the ONNX filename suggests the export is fixed to
# a batch of 1 (assumption — confirm against the export).
batch_size: int = 1
def convert_to_rgb(image):
    """Return *image* converted to the "RGB" mode via its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image
def get_transform(image_size=384):
    """Build the image preprocessing pipeline applied before ONNX inference.

    Steps, in order: convert to RGB, resize to a square of ``image_size``
    pixels, convert to a tensor, then normalize each channel.

    Args:
        image_size: side length of the square the image is resized to.

    Returns:
        A torchvision ``Compose`` callable.
    """
    # Per-channel normalization statistics. These are the widely used ImageNet
    # values — presumably what the model was trained with (confirm).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        convert_to_rgb,
        Resize((image_size, image_size)),
        ToTensor(),
        Normalize(mean=channel_mean, std=channel_std),
    ]
    return Compose(steps)
def load_tag_list(tag_list_file):
    """Read a UTF-8 text file with one tag per line into a NumPy array.

    Args:
        tag_list_file: path to the tag list file.

    Returns:
        ``np.ndarray`` of tag strings, in file order.
    """
    with open(tag_list_file, encoding="utf-8") as handle:
        lines = handle.read().splitlines()
    return np.array(lines)
def load_word_vocabulary(word_vocabulary_file):
    """Map every word of a comma-separated vocabulary file to its line index.

    Each line of the UTF-8 file is one group of words separated by ','.
    Every word on line ``i`` (0-based) maps to ``i``. Words are not stripped,
    so whitespace around commas is kept as part of the word. If a word occurs
    on several lines, the last occurrence wins.

    Args:
        word_vocabulary_file: path to the vocabulary file.

    Returns:
        dict mapping word -> line index.
    """
    with open(word_vocabulary_file, 'r', encoding="utf-8") as f:
        lines = f.read().splitlines()
    word2idx = {}
    for group_index, line in enumerate(lines):
        for word in line.split(','):
            # Plain assignment: a repeated word is overwritten by its later line.
            word2idx[word] = group_index
    return word2idx
from huggingface_hub import hf_hub_download | |
# Download the ONNX export of the RAM tagger into ./resources and open it.
# hf_hub_download returns the local path of the file, so use that instead of
# re-spelling the location by hand (the two could silently drift apart).
model_path = hf_hub_download(
    repo_id="Inf009/ram-tagger",
    repo_type="model",
    local_dir="resources",
    filename="ram_swin_large_14m_b1.onnx",
    local_dir_use_symlinks=False,
)
# Providers are listed in decreasing priority: use CUDA when available and
# fall back to CPU explicitly so the app can also start on CPU-only hardware.
ort_session = onnxruntime.InferenceSession(
    model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
transform = get_transform()
# NOTE(review): these two files are not downloaded above — presumably they are
# committed alongside the app in ./resources; confirm.
tag_list = load_tag_list("resources/ram_tag_list.txt")
word_index = load_word_vocabulary("resources/word_vocabulary_english.txt")
def inference_by_image_pil(image):
    """Tag a PIL image with the RAM ONNX model.

    Args:
        image: PIL image from the Gradio ``Image(type='pil')`` input, or
            ``None`` when the user submits without an image.

    Returns:
        The predicted tags joined with "," (no spaces), ordered by
        ``rerank_tags``; an empty string when no image is given.
    """
    if image is None:
        # Without this guard convert_to_rgb would fail on None.convert(...).
        return ""
    # Preprocess, add a leading batch dimension, and hand ORT a NumPy array.
    image_arrays = transform(image).unsqueeze(0).numpy()
    ort_inputs = {ort_session.get_inputs()[0].name: image_arrays}
    ort_outs = ort_session.run(None, ort_inputs)
    # First output, first batch item; entries equal to 1 mark predicted tags.
    # NOTE(review): assumes that row is a 1-D 0/1 vector aligned with
    # tag_list (the previous squeeze(axis=1) required 1-D) — confirm.
    predicted = np.flatnonzero(ort_outs[0][0] == 1)
    token = tag_list[predicted].tolist()
    token = rerank_tags(token)
    return ",".join(token)
def rerank_tags(tags, vocabulary=None):
    """Order *tags* by their group index in the word vocabulary.

    Tags with a lower vocabulary index come first. Tags sharing an index keep
    their input order, because ``sorted`` is stable. This is the same result
    as bucketing by index and concatenating the buckets, without allocating
    one bucket per vocabulary group on every call.

    Args:
        tags: iterable of tag strings.
        vocabulary: mapping word -> group index. Defaults to the module-level
            ``word_index``.

    Returns:
        A new list with the same tags, reordered.

    Raises:
        KeyError: if a tag is not present in the vocabulary.
    """
    if vocabulary is None:
        vocabulary = word_index
    return sorted(tags, key=lambda tag: vocabulary[tag])
# Web UI: upload one image, get back the comma-separated tag string.
app = gr.Interface(
    fn=inference_by_image_pil,
    inputs=grc.Image(type='pil'),  # hand the upload to fn as a PIL image
    outputs=grc.Text(),
    title="RAM Tagger",
    description="A tagger for images.",
)
app.launch()