# transfiner / app.py
# (source page header: lkeab · "update app" · commit c953948 · 2.76 kB)
#try:
# import detectron2
#except:
import os
# Install the SysCV/transfiner package at startup; it presumably provides the
# `detectron2` modules imported below (cf. the commented-out
# `import detectron2` guard). Invoke pip through the running interpreter
# (`python -m pip`) instead of whatever `pip` is first on PATH, and pass an
# argument list so no shell is involved. Like the original os.system call this
# is best-effort: a failed install is reported but does not abort here.
import subprocess
import sys

_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "git+https://github.com/SysCV/transfiner.git"],
)
if _pip_result.returncode != 0:
    print(
        f"WARNING: pip install of transfiner exited with code "
        f"{_pip_result.returncode}",
        file=sys.stderr,
    )
from matplotlib.pyplot import axis
import gradio as gr
import requests
import numpy as np
from torch import nn
import requests
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Build the detectron2 config for the Mask Transfiner R50-FPN model and the
# predictor used by `inference`.
model_name = './configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x_4gpu_transfiner.yaml'
cfg = get_cfg()
# Layer the Transfiner-specific settings from the YAML on top of the defaults.
cfg.merge_from_file(model_name)
# Test-time detection score threshold for this model.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# Trained weights come from a local checkpoint file, not the model zoo.
cfg.MODEL.WEIGHTS = './output_3x_transfiner_r50.pth'
# Fall back to CPU inference when no CUDA device is available.
if not torch.cuda.is_available():
    cfg.MODEL.DEVICE = 'cpu'
predictor = DefaultPredictor(cfg)
def inference(image):
    """Run Mask Transfiner on an uploaded image and render its predictions.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        An RGB numpy array of the input, resized to 1024x1024, with the
        predicted instances drawn on it by detectron2's ``Visualizer``.
    """
    # convert("RGB") guards against RGBA / grayscale / palette uploads so the
    # array below always has three channels. The fixed 1024x1024 resize does
    # not preserve the aspect ratio.
    image = image.convert("RGB").resize((1024, 1024))
    img = np.asarray(image)  # (H, W, 3) uint8 in RGB order
    # DefaultPredictor takes BGR input (it converts to cfg.INPUT.FORMAT
    # itself), whereas PIL gives RGB: flip the channel axis for the model.
    outputs = predictor(np.ascontiguousarray(img[:, :, ::-1]))
    # Visualizer, by contrast, expects the RGB image.
    v = Visualizer(img, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]))
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    return out.get_image()
# Text shown around the demo: the page title, an intro blurb above the
# interface, and a centered footer linking to the paper and the code.
_paper_link = (
    "<a target='_blank' href='https://arxiv.org/abs/2111.13673'>"
    "Mask Transfiner for High-Quality Instance Segmentation, CVPR 2022</a>"
)
title = "Mask Transfiner"
description = (
    "Demo for " + _paper_link + " based on R50-FPN. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "It runs in the cpu environment provided by Hugging Face. "
    "Read more at the links below."
)
article = "".join([
    "<p style='text-align: center'>",
    _paper_link,
    " | <a target='_blank' href='https://github.com/SysCV/transfiner'>Mask Transfiner Github</a>",
    "</p>",
])
# Bundled sample images offered as one-click examples (file stems under
# demo/sample_imgs/, listed in display order).
_example_ids = [
    "000000131444", "000000157365", "000000176037", "000000018737",
    "000000224200", "000000558073", "000000404922", "000000252776",
    "000000482477", "000000344909",
]

# Wire `inference` into a single-image-in / single-image-out Gradio UI
# and start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil", label="Input")],
    outputs=gr.outputs.Image(type="numpy", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[[f"demo/sample_imgs/{image_id}.jpg"] for image_id in _example_ids],
)
demo.launch()