Spaces:
Paused
Paused
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
from utils.visualizer import Visualizer | |
from detectron2.data import MetadataCatalog | |
# Resize applied to the input image before the captioning model sees it.
# With an int size, torchvision's Resize scales the smaller edge to that
# length and keeps the aspect ratio.
transform_ret = transforms.Compose([
    transforms.Resize(224, interpolation=Image.BICUBIC),
])

# Resize applied to the input image before the grounding model sees it
# (smaller edge scaled to 512).
transform_grd = transforms.Compose([
    transforms.Resize(512, interpolation=Image.BICUBIC),
])

# Dataset metadata handed to the Visualizer when drawing the grounding mask.
# NOTE: the name is misspelled ("metedata") but is kept as-is because
# referring_captioning() looks it up under this exact name.
metedata = MetadataCatalog.get('coco_2017_train_panoptic')
def referring_captioning(model, image, texts, inpainting_text, *args, **kwargs):
    """Ground a referring expression in an image, then generate a caption.

    The function runs two models in sequence:

    1. ``model_last`` grounds ``texts`` in the image (resized with
       ``transform_grd``) and the resulting binary mask is drawn on top of
       that resized image.
    2. ``model_cap`` captions the image (resized with ``transform_ret``),
       with the tokenized prompt passed as ``extra={'token': ...}``.

    Args:
        model: Pair ``(model_last, model_cap)``. ``model_last.model`` must
            provide ``evaluate_grounding`` and ``model_cap.model`` must
            provide ``evaluate_captioning`` plus the tokenizer at
            ``sem_seg_head.predictor.lang_encoder.tokenizer``.
        image: Input image. It is passed through the torchvision resize
            transforms and read via ``.size`` / ``np.asarray``, so a PIL
            image with a channel-last H x W x C layout is expected.
        texts: Referring expression typed by the user. Surrounding
            whitespace is ignored; a trailing period is optional.
        inpainting_text: Unused; accepted so the signature stays compatible
            with existing callers.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A 3-tuple ``(visualization, caption, None)`` where ``visualization``
        is a PIL image with the grounded mask drawn on it and ``caption`` is
        the generated caption, or ``'n/a'`` when fewer than 5 pixels of the
        224 x 224 mask are grounded.
    """
    model_last, model_cap = model

    # Normalize the prompt once so both stages see a consistent string:
    # the grounding prompt always ends with exactly the user's text plus a
    # period, and the caption prefix drops only the trailing period
    # (interior periods such as "3.5" are kept).
    prompt = texts.strip()
    ends_with_period = prompt.endswith('.')
    grounding_prompt = prompt if ends_with_period else prompt + '.'
    caption_prefix = prompt[:-1] if ends_with_period else prompt

    # Pure inference: no autograd graph is needed for either model.
    with torch.no_grad():
        # ---- Stage 1: grounding on the 512-resized image ----
        image_grd = transform_grd(image)
        width, height = image_grd.size
        pixels_grd = np.asarray(image_grd)
        # HWC -> CHW; .copy() gives torch a writable, contiguous buffer.
        images = torch.from_numpy(pixels_grd.copy()).permute(2, 0, 1).cuda()

        batch_inputs = [{
            'image': images,
            'groundings': {'texts': [[grounding_prompt]]},
            'height': height,
            'width': width,
        }]
        outputs = model_last.model.evaluate_grounding(batch_inputs, None)

        # Binarize the predicted mask, then build its inverse at 224 x 224
        # (True where the expression is NOT grounded).
        grd_mask = (outputs[-1]['grounding_mask'] > 0).float()
        grd_mask_ = (1 - F.interpolate(grd_mask[None,], (224, 224), mode='nearest')[0]).bool()

        # Draw the grounded region, labelled with the user's original text.
        color = [252/255, 91/255, 129/255]
        visual = Visualizer(pixels_grd, metadata=metedata)
        demo = visual.draw_binary_mask(grd_mask.cpu().numpy()[0], color=color, text=texts)
        res = demo.get_image()

        # Fewer than 5 grounded pixels at 224 x 224: treat as "nothing
        # found" and skip captioning.
        if (1 - grd_mask_.float()).sum() < 5:
            torch.cuda.empty_cache()
            return Image.fromarray(res), 'n/a', None

        # ---- Stage 2: captioning on the 224-resized image ----
        # Multiplying by 0 zeroes every entry, so the mask computed above no
        # longer restricts anything.
        # NOTE(review): this looks deliberate (caption from the whole image)
        # -- confirm before changing.
        grd_mask_ = grd_mask_ * 0

        pixels_cap = np.asarray(transform_ret(image))
        images = torch.from_numpy(pixels_cap.copy()).permute(2, 0, 1).cuda()
        batch_inputs = [{'image': images, 'image_id': 0, 'captioning_mask': grd_mask_}]

        tokenizer = model_cap.model.sem_seg_head.predictor.lang_encoder.tokenizer
        token = tokenizer.encode(caption_prefix)
        # Add a batch dimension and drop the last token id.
        # NOTE(review): presumably the end-of-text token, so the decoder
        # continues the prompt -- confirm against the tokenizer.
        token = torch.tensor(token)[None, :-1]

        outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={'token': token})
        text = outputs[-1]['captioning_text']

    torch.cuda.empty_cache()
    return Image.fromarray(res), text, None