# [page-scrape residue, not program text] Spaces: Runtime error / Runtime error
# -------------------------------------------------------- | |
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language | |
# Copyright (c) 2022 Microsoft | |
# Licensed under The MIT License [see LICENSE for details] | |
# Written by Xueyan Zou ([email protected]), Jianwei Yang ([email protected]) | |
# -------------------------------------------------------- | |
import os | |
# Install the detectron2 fork at start-up, before the project imports below
# (presumably they depend on it -- TODO confirm).  pip is invoked through the
# interpreter that is running this script rather than whatever `python`
# happens to be first on PATH, and with an argument list instead of a shell
# string.
import subprocess
import sys

_pip_status = subprocess.call(
    [
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/MaureenZOU/detectron2-xyz.git",
    ]
)
if _pip_status != 0:
    # Still best effort: the package may already be installed, so only warn.
    print(
        "warning: pip install of detectron2-xyz exited with status {}".format(_pip_status),
        file=sys.stderr,
    )
import gradio as gr | |
import torch | |
import argparse | |
from xdecoder.BaseModel import BaseModel | |
from xdecoder import build_model | |
from utils.distributed import init_distributed | |
from utils.arguments import load_opt_from_config_files | |
from tasks import * | |
def parse_option(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            exactly as before; passing a list lets tests, notebooks or other
            launchers call this without touching the process arguments.

    Returns:
        argparse.Namespace with a single attribute, ``conf_files``: the path
        of the config file to load.
    """
    parser = argparse.ArgumentParser('X-Decoder All-in-One Demo', add_help=False)
    parser.add_argument(
        '--conf_files',
        default="configs/xdecoder/svlp_focalt_lang.yaml",
        metavar="FILE",
        help='path to config file',
    )
    return parser.parse_args(argv)
# build args
# Read the command line, load the options from the selected config file and
# pass them through init_distributed, which returns the options used below.
args = parse_option()
opt = init_distributed(load_opt_from_config_files(args.conf_files))
# META DATA
# Both checkpoints live in the working directory and are fetched on first start.
pretrained_pth_last = "xdecoder_focalt_last.pt"
pretrained_pth_novg = "xdecoder_focalt_last_novg.pt"

_CHECKPOINT_BASE_URL = "https://projects4jw.blob.core.windows.net/x-decoder/release/"

for _ckpt in (pretrained_pth_last, pretrained_pth_novg):
    if os.path.exists(_ckpt):
        continue
    # wget stores the file under the last URL path component, i.e. `_ckpt`.
    _wget_status = os.system("wget {}".format(_CHECKPOINT_BASE_URL + _ckpt))
    if _wget_status != 0:
        # Drop whatever was written so the next start retries the download
        # instead of treating a partial file as a valid checkpoint.
        if os.path.exists(_ckpt):
            os.remove(_ckpt)
        raise RuntimeError(
            "Failed to download checkpoint {!r} (wget exit status {})".format(_ckpt, _wget_status)
        )
# build model
def _load_xdecoder(checkpoint_path):
    """Build a model from `opt`, load `checkpoint_path`, switch to eval mode and move to CUDA."""
    wrapper = BaseModel(opt, build_model(opt))
    return wrapper.from_pretrained(checkpoint_path).eval().cuda()


model_last = _load_xdecoder(pretrained_pth_last)
model_cap = _load_xdecoder(pretrained_pth_novg)

# Call get_text_embeddings once per model with a placeholder vocabulary
# (presumably this primes the language encoder's stored text embeddings --
# TODO confirm against lang_encoder).  No gradients are needed for this.
with torch.no_grad():
    for _net in (model_last, model_cap):
        _net.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            ["background", "background"], is_eval=True
        )
# inference model
def inference(image, task, *args, **kwargs):
    """Run the demo task selected in the UI on one image.

    Args:
        image: Input image; it must provide ``convert("RGB")`` (the UI is
            configured to pass PIL images).
        task: Name of the task to run; must be one of the keys of the
            dispatch table below.
        *args, **kwargs: Forwarded unchanged to the task function (from the
            UI these are the text box values).

    Returns:
        Whatever the selected task function returns.

    Raises:
        ValueError: If ``task`` is not a known task name.  Previously such a
            value fell through every branch and the function silently
            returned ``None``.
    """
    image = image.convert("RGB")

    # Zero-argument thunks: nothing is evaluated until the chosen one is
    # called, and each closes over the converted image and extra arguments.
    handlers = {
        'Referring Editing': lambda: referring_inpainting(model_last, image, *args, **kwargs),
        'Referring Segmentation': lambda: referring_segmentation(model_last, image, *args, **kwargs),
        'Open Vocabulary Semantic Segmentation': lambda: open_semseg(model_last, image, *args, **kwargs),
        'Open Vocabulary Panoptic Segmentation': lambda: open_panoseg(model_last, image, *args, **kwargs),
        'Open Vocabulary Instance Segmentation': lambda: open_instseg(model_last, image, *args, **kwargs),
        'Image Captioning': lambda: image_captioning(model_cap, image, *args, **kwargs),
        'Referring Captioning (Beta)': lambda: referring_captioning([model_last, model_cap], image, *args, **kwargs),
        'Text Retrieval': lambda: text_retrieval(model_cap, image, *args, **kwargs),
        'Image/Region Retrieval': lambda: region_retrieval([model_cap, model_last], image, *args, **kwargs),
    }

    run_task = handlers.get(task)
    if run_task is None:
        raise ValueError(
            f"Unknown task {task!r}; expected one of: {', '.join(handlers)}"
        )

    # Every task runs under CUDA autocast with float16.
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        return run_task()
# launch app
title = "X-Decoder All-in-One Demo"
description = """<p style='text-align: center'> <a href='https://x-decoder-vl.github.io/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/pdf/2212.11270.pdf' target='_blank'>Paper</a> | <a href='https://github.com/microsoft/X-Decoder' target='_blank'>Github Repo</a> | <a href='https://youtu.be/wYp6vmyolqE' target='_blank'>Video</a> </p>
<p>Skip the queue by duplicating this space and upgrading to GPU in settings</p>
<a href="https://huggingface.co/spaces/xdecoder/Demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
"""
article = "The Demo is Run on X-Decoder (Focal-T)."

# Task names offered in the UI.  Each one must be spelled exactly as
# `inference` expects it, because the selected string is passed through as-is.
task_choices = [
    "Referring Segmentation",
    "Referring Editing",
    'Open Vocabulary Semantic Segmentation',
    'Open Vocabulary Instance Segmentation',
    "Open Vocabulary Panoptic Segmentation",
    "Image Captioning",
    "Text Retrieval",
    "Image/Region Retrieval",
    "Referring Captioning (Beta)",
]

# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces;
# confirm they still exist in the gradio version this Space pins.
inputs = [
    gr.inputs.Image(type='pil'),
    # The default must be one of `task_choices`; the previous value
    # ("OpenVocab Semantic Segmentation") matched none of them.
    gr.inputs.Radio(
        choices=task_choices,
        type="value",
        default="Open Vocabulary Semantic Segmentation",
        label="Task",
    ),
    gr.Textbox(label="xdecoder_text"),
    gr.Textbox(label="inpainting_text"),
    gr.Textbox(label="task_description"),
]

# Each example row supplies one value per input component, in order:
# image path, task, xdecoder_text, inpainting_text, task_description.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=[
        gr.outputs.Image(
            type="pil",
            label="segmentation results"),
        gr.Textbox(label="text results"),
        gr.outputs.Image(
            type="pil",
            label="editing results"),
    ],
    examples=[
        ["./images/fruit.jpg", "Referring Segmentation", "The larger watermelon.,The front white flower.,White tea pot.,Flower bunch.,white vase.,The peach on the left.,The brown knife.,The handkerchief.", '', 'Format: s,s,s'],
        ["./images/apples.jpg", "Referring Editing", "the green apple", 'a red apple', 'x-decoder + ldm (inference takes ~20s), use inpainting_text "clean and empty scene" for image inpainting'],
        ["./images/horse.png", "Referring Editing", "the sky", 'a mountain', 'x-decoder + ldm (inference takes ~20s), use inpainting_text "clean and empty scene" for image inpainting'],
        ["./images/animals.png", "Open Vocabulary Semantic Segmentation", "zebra,antelope,giraffe,ostrich,sky,water,grass,sand,tree", '', 'Format: x,x,x'],
        ["./images/owls.jpeg", "Open Vocabulary Instance Segmentation", "owl", '', 'Format: y,y,y'],
        ["./images/mountain.jpeg", "Image Captioning", "", '', ''],
        ["./images/rose.webp", "Text Retrieval", "lily,rose,peoney,tulip", '', 'Format: s,s,s'],
        ["./images/region_retrieval.png", "Image/Region Retrieval", "The tangerine on the plate.", '', 'Please describe the object in a detailed way (80 images in the pool).'],
        ["./images/landscape.jpg", "Referring Captioning (Beta)", "cloud", '', 'Please fill in a noun/noun phrase. (may start with a/the)'],
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging='never',
    cache_examples=True,
).launch()