import argparse
import os
import sys
from typing import Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, BitsAndBytesConfig

from model.segment_anything.utils.transforms import ResizeLongestSide

# RGB colour blended (50/50) into pixels covered by the predicted mask.
MASK_COLOR_RGB = np.array([50, 120, 220])


def parse_args(args):
    """Parse command-line options for EVF-SAM inference.

    Args:
        args: list of argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="EVF infer")
    parser.add_argument("--version", required=True)
    parser.add_argument("--vis_save_path", default="./infer", type=str)
    parser.add_argument(
        "--precision",
        default="fp16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=224, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--model_type", default="ori", choices=["ori", "effi"])
    parser.add_argument("--image_path", type=str, default="assets/zebra.jpg")
    parser.add_argument("--prompt", type=str, default="zebra top left")
    return parser.parse_args(args)


def sam_preprocess(
    x: np.ndarray,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
    model_type="ori",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Preprocess an image for the Segment Anything (SAM / Effi-SAM) branch.

    Steps: resize with ``ResizeLongestSide(img_size)`` (presumably scales the
    longest side to ``img_size`` — confirm against that helper), normalize each
    channel with ``pixel_mean`` / ``pixel_std``, then bring the result to an
    ``img_size x img_size`` square. Preprocessing differs between the models:

    * ``model_type == "effi"`` (Effi-SAM): bilinear interpolation, no padding;
    * otherwise (original SAM): zero padding on the bottom and right edges.

    Args:
        x: H x W x 3 image array (channels last; in this file it is an RGB
            image obtained from OpenCV).
        pixel_mean: per-channel mean, shaped (3, 1, 1) for broadcasting.
        pixel_std: per-channel std, shaped (3, 1, 1) for broadcasting.
        img_size: target square side; both SAM variants expect 1024.
        model_type: ``"effi"`` for Effi-SAM, anything else for SAM.

    Returns:
        ``(tensor, resize_shape)`` where ``tensor`` is the 3 x img_size x
        img_size normalized image and ``resize_shape`` is the (h, w) of the
        image after the longest-side resize, before padding/interpolation.

    Raises:
        ValueError: if ``img_size`` is not 1024.
    """
    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    if img_size != 1024:
        raise ValueError(
            "both SAM and Effi-SAM receive images of size 1024^2, don't change this setting unless you're sure that your employed model works well with another size."
        )

    x = ResizeLongestSide(img_size).apply_image(x)
    resize_shape = x.shape[:2]
    # HWC -> CHW so the (3, 1, 1) mean/std broadcast over channels.
    x = torch.from_numpy(x).permute(2, 0, 1).contiguous()

    # Normalize colors
    x = (x - pixel_mean) / pixel_std

    if model_type == "effi":
        # Effi-SAM: stretch to a square rather than padding.
        x = F.interpolate(x.unsqueeze(0), (img_size, img_size), mode="bilinear").squeeze(0)
    else:
        # SAM: zero-pad bottom/right up to img_size x img_size.
        h, w = x.shape[-2:]
        padh = img_size - h
        padw = img_size - w
        x = F.pad(x, (0, padw, 0, padh))

    return x, resize_shape


def beit3_preprocess(x: np.ndarray, img_size=224) -> torch.Tensor:
    """Preprocess an image for the BEiT-3 branch.

    Converts to a tensor, resizes to ``img_size x img_size`` with bicubic
    interpolation, and normalizes every channel with mean 0.5 / std 0.5.

    Args:
        x: image array accepted by ``transforms.ToTensor``.
        img_size: square side of the output.

    Returns:
        The preprocessed C x img_size x img_size tensor.
    """
    beit_preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    return beit_preprocess(x)


def init_models(args):
    """Load the tokenizer and EVF-SAM model described by ``args``.

    Uses ``args.version``, ``args.precision``, ``args.load_in_4bit``,
    ``args.load_in_8bit`` and ``args.model_type``. Quantized (4/8-bit) loads
    force ``torch.half`` and skip ``visual_model`` from quantization; otherwise
    the model is moved to CUDA explicitly.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode.

    Raises:
        ValueError: if ``args.model_type`` is neither ``"ori"`` nor ``"effi"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        padding_side="right",
        use_fast=False,
    )

    # fp32 unless bf16/fp16 is requested.
    torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(args.precision, torch.float32)

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    # Model classes are imported lazily so only the selected variant is loaded.
    if args.model_type == "ori":
        from model.evf_sam import EvfSamModel
        model = EvfSamModel.from_pretrained(
            args.version, low_cpu_mem_usage=True, **kwargs
        )
    elif args.model_type == "effi":
        from model.evf_effisam import EvfEffiSamModel
        model = EvfEffiSamModel.from_pretrained(
            args.version, low_cpu_mem_usage=True, **kwargs
        )
    else:
        raise ValueError("unknown model_type: {!r}".format(args.model_type))

    # Quantized loads are not moved here; only full-precision models go to CUDA.
    if not (args.load_in_4bit or args.load_in_8bit):
        model = model.cuda()
    model.eval()

    return tokenizer, model


def _load_rgb_image(image_path):
    """Read ``image_path`` with OpenCV; return an RGB array, or None if it cannot be decoded."""
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        return None
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


def _overlay_mask(image_rgb, mask, color=MASK_COLOR_RGB):
    """Return a copy of ``image_rgb`` whose ``mask`` pixels are blended 50/50 with ``color``."""
    blended = image_rgb.copy()
    # Blend only the masked pixels; the float result is cast back to the
    # image dtype (truncating) on assignment.
    blended[mask] = image_rgb[mask] * 0.5 + color * 0.5
    return blended


def main(args):
    """Run EVF-SAM on one image/prompt pair and save a mask overlay.

    The visualization is written to ``<vis_save_path>/<image stem>_vis.png``.
    Exits with a non-zero status if the input image is missing or
    undecodable, or if the output cannot be written.
    """
    args = parse_args(args)

    # clarify IO
    image_path = args.image_path
    if not os.path.isfile(image_path):
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit("File not found in {}".format(image_path))
    prompt = args.prompt

    os.makedirs(args.vis_save_path, exist_ok=True)
    # splitext keeps multi-dot stems intact ("a.b.jpg" -> "a.b").
    stem = os.path.splitext(os.path.basename(image_path))[0]
    save_path = os.path.join(args.vis_save_path, "{}_vis.png".format(stem))

    # Decode the image before loading the (expensive) model.
    image_np = _load_rgb_image(image_path)
    if image_np is None:
        sys.exit("Failed to read image {}".format(image_path))

    # initialize model and tokenizer
    tokenizer, model = init_models(args)

    # preprocess
    original_size_list = [image_np.shape[:2]]

    image_beit = beit3_preprocess(image_np, args.image_size).to(dtype=model.dtype, device=model.device)

    image_sam, resize_shape = sam_preprocess(image_np, model_type=args.model_type)
    image_sam = image_sam.to(dtype=model.dtype, device=model.device)

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device=model.device)

    # infer (no autograd graph is needed for inference)
    with torch.no_grad():
        pred_mask = model.inference(
            image_sam.unsqueeze(0),
            image_beit.unsqueeze(0),
            input_ids,
            resize_list=[resize_shape],
            original_size_list=original_size_list,
        )
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    pred_mask = pred_mask > 0

    # save visualization
    save_img = _overlay_mask(image_np, pred_mask)
    save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(save_path, save_img):
        sys.exit("Failed to write visualization to {}".format(save_path))


if __name__ == "__main__":
    main(sys.argv[1:])