# ram-tagger / app.py
# baixintech_zhangyiming_prod
# init
# cbb13b8
import gradio as gr
import gradio.components as grc
import onnxruntime
import numpy as np
from torchvision.transforms import Normalize, Compose, Resize, ToTensor
# Batch size constant. It is not referenced by any code in the visible part of
# this file; inference adds a single leading batch axis via ``unsqueeze(0)``.
# NOTE(review): presumably mirrors the "_b1" suffix of the ONNX model file
# name -- confirm before changing.
batch_size = 1
def convert_to_rgb(image):
    """Return ``image`` converted to "RGB" mode.

    Args:
        image: An object exposing a PIL-style ``convert(mode)`` method.

    Returns:
        Whatever ``image.convert("RGB")`` returns.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def get_transform(image_size=384):
    """Build the preprocessing pipeline applied to an image before inference.

    The pipeline runs, in order: RGB conversion, resize to a square of
    ``image_size`` x ``image_size``, conversion to a tensor, and per-channel
    mean/std normalisation.

    Args:
        image_size: Edge length used for both dimensions of the resize.

    Returns:
        A ``Compose`` object chaining the four steps.
    """
    target_size = (image_size, image_size)
    resize = Resize(target_size)
    to_tensor = ToTensor()
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [convert_to_rgb, resize, to_tensor, normalize]
    return Compose(steps)
def load_tag_list(tag_list_file):
    """Read a UTF-8 text file and return its lines as a NumPy array.

    Args:
        tag_list_file: Path of the text file, one tag per line.

    Returns:
        A ``numpy.ndarray`` holding the file's lines (line endings removed),
        in file order.
    """
    with open(tag_list_file, 'r', encoding="utf-8") as handle:
        lines = handle.read().splitlines()
    return np.array(lines)
def load_word_vocabulary(word_vocabulary_file):
    """Map every word of a vocabulary file to the index of its line.

    Each line of the UTF-8 file is split on commas; every resulting piece is
    mapped to the zero-based number of the line it came from. When a word
    occurs on several lines, the last line wins.

    Args:
        word_vocabulary_file: Path of the vocabulary text file.

    Returns:
        A dict from word to line index.
    """
    with open(word_vocabulary_file, 'r', encoding="utf-8") as handle:
        lines = handle.read().splitlines()

    word2idx = {}
    for line_number, line in enumerate(lines):
        for word in line.split(','):
            word2idx[word] = line_number
    return word2idx
# --- Module-level setup: everything below runs once, at import time. ---
from huggingface_hub import hf_hub_download
# Fetch the ONNX model file from the "Inf009/ram-tagger" model repo into the
# local "resources" directory.
hf_hub_download(repo_id="Inf009/ram-tagger", repo_type="model", local_dir="resources", filename="ram_swin_large_14m_b1.onnx", local_dir_use_symlinks=False)
# Open an ONNX Runtime session on the downloaded model, requesting the CUDA
# execution provider.
ort_session = onnxruntime.InferenceSession("resources/ram_swin_large_14m_b1.onnx", providers=["CUDAExecutionProvider"])
# Shared state used by inference_by_image_pil and rerank_tags.
transform = get_transform()
# NOTE(review): only the .onnx file is downloaded above; the two text files
# below must already exist under resources/ -- confirm they ship with the app.
tag_list = load_tag_list("resources/ram_tag_list.txt")
word_index = load_word_vocabulary("resources/word_vocabulary_english.txt")
def inference_by_image_pil(image):
    """Tag an image with the ONNX model and return the tags as one string.

    Args:
        image: The image to tag, in a form accepted by the module-level
            ``transform`` (the Gradio input is configured with
            ``type='pil'``), or ``None`` when no image was supplied.

    Returns:
        The selected tags, ordered by ``rerank_tags`` and joined with commas.
        An empty string when ``image`` is ``None`` or no tag is selected.
    """
    # Gradio hands the callback None when the form is submitted without an
    # image; there is nothing to tag, so return early instead of failing
    # inside the transform with an AttributeError.
    if image is None:
        return ""

    # Preprocess and add a leading batch axis of size 1 before handing the
    # data to ONNX Runtime as a NumPy array.
    image_arrays = transform(image).unsqueeze(0).numpy()
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: image_arrays}
    ort_outs = ort_session.run(None, ort_inputs)
    # First output, first batch item: positions whose value equals 1 are the
    # selected tags. argwhere yields one row per hit, hence the squeeze below.
    index = np.argwhere(ort_outs[0][0] == 1)
    token = tag_list[index].squeeze(axis=1).tolist()
    token = rerank_tags(token)
    return ",".join(token)
def rerank_tags(tags):
    """Order tags by their index in the module-level ``word_index`` mapping.

    Tags are returned in ascending order of ``word_index[tag]``; tags that
    share an index keep the relative order they had in ``tags``. Tags that are
    missing from ``word_index`` are kept and placed after all known tags, in
    their original relative order, instead of raising ``KeyError``.

    Args:
        tags: List of tag strings.

    Returns:
        A new list containing the same tags in vocabulary order.
    """
    # A stable sort keyed on the vocabulary index produces exactly the order
    # that bucketing by index would, without allocating one bucket per
    # vocabulary line on every call, and it also works when the vocabulary
    # is empty.
    unknown_rank = float("inf")  # sorts after every real (integer) index
    return sorted(tags, key=lambda tag: word_index.get(tag, unknown_rank))
# Gradio UI: one image input (delivered to the callback as a PIL image) mapped
# to a single text output by inference_by_image_pil.
app = gr.Interface(fn=inference_by_image_pil, inputs=grc.Image(type='pil'),
outputs=grc.Text(), title="RAM Tagger",
description="A tagger for images.")
# Launched at module level (there is no __main__ guard), so importing this
# file also starts the app.
app.launch()