import ruamel_yaml as yaml  # legacy top-level alias of ruamel.yaml; newer installs may need `from ruamel import yaml`
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.tag2text import tag2text_caption
import gradio as gr
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 384

# Standard ImageNet normalization, matching the model's pretraining.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])
####### Swin Version
pretrained = '/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth'

config_file = 'configs/tag2text_caption.yaml'
config = yaml.load(open(config_file, 'r'), Loader=yaml.Loader)

model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'],
                         vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                         prompt=config['prompt'], config=config, threshold=0.75)
model.eval()
model = model.to(device)
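# For reference, the keys the code reads from configs/tag2text_caption.yaml above.
# The values shown here are illustrative placeholders, not the actual config:
#   vit: 'swin_b'         # vision backbone variant
#   vit_grad_ckpt: False  # gradient checkpointing for the ViT
#   vit_ckpt_layer: 0     # first layer to checkpoint, if enabled
#   prompt: 'a picture of '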
def inference(raw_image, model_n, input_tag, strategy):
    if model_n == 'Image Captioning':
        raw_image = raw_image.resize((image_size, image_size))
        image = transform(raw_image).unsqueeze(0).to(device)

        # Slightly lower the tagging threshold for inference (overrides the 0.75 set at construction).
        model.threshold = 0.7

        # Parse optional user-supplied tags; the model expects ' | ' as the tag separator.
        if input_tag in ('', 'none', 'None'):
            input_tag_list = None
        else:
            input_tag_list = [input_tag.replace(',', ' | ')]

        with torch.no_grad():
            if strategy == "Beam search":
                caption, tag_predict = model.generate(image, tag_input=input_tag_list,
                                                      return_tag_predict=True)
            else:  # Nucleus sampling
                caption, tag_predict = model.generate(image, tag_input=input_tag_list, sample=True,
                                                      top_p=0.9, max_length=20, min_length=5,
                                                      return_tag_predict=True)
            if input_tag_list is None:
                # No user tags: the returned tags are the model's own predictions.
                tag_1 = tag_predict
                tag_2 = ['none']
            else:
                # With user tags supplied, run once more without them to recover the model's own tags.
                _, tag_1 = model.generate(image, tag_input=None, return_tag_predict=True)
                tag_2 = tag_predict
        return tag_1[0], tag_2[0], caption[0]
    else:
        # Only 'Image Captioning' is exposed in this demo; the original VQA branch
        # referenced undefined names (transform_vq, model_vq, question) and was unreachable.
        return 'none', 'none', 'unsupported task'
# Gradio input/output components (legacy gr.inputs / gr.outputs namespace, Gradio < 3.0).
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(choices=['Image Captioning'], type="value",
                    default="Image Captioning", label="Task"),
    gr.inputs.Textbox(lines=2, label="User Identified Tags (Optional, separated by commas)"),
    gr.inputs.Radio(choices=['Beam search', 'Nucleus sampling'], type="value",
                    default="Beam search", label="Caption Decoding Strategy"),
]
outputs = [
    gr.outputs.Textbox(label="Model Identified Tags"),
    gr.outputs.Textbox(label="User Identified Tags"),
    gr.outputs.Textbox(label="Image Caption"),
]
title = "Tag2Text"
description = "Gradio demo for Tag2Text: Guiding Language-Image Model via Image Tagging (Fudan University, OPPO Research Institute, International Digital Economy Academy)."
article = "<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>"
demo = gr.Interface(inference,
                    inputs,
                    outputs,
                    title=title,
                    description=description,
                    article=article,
                    examples=[
                        ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "none", "Beam search"],
                        ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "fence, sky", "Beam search"],
                        # ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "grass", "Beam search"],
                        ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "none", "Beam search"],
                        ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "electric cable", "Beam search"],
                        # ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "sky, train", "Beam search"],
                        ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "track, train", "Beam search"],
                        ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "grass", "Beam search"],
                    ])
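
# The script builds `demo` but never starts it. A minimal, assumed entry point so the
# script runs standalone (Gradio Spaces app scripts conventionally end with launch();
# queueing and sharing options are deployment-specific):
if __name__ == '__main__':
    demo.launch()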