# ------------------------------------------------------------------------------------------
# Copyright (c) 2024 Baifeng Shi.
# All rights reserved.
#
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import math

import torch
import torch.nn.functional as F
from einops import rearrange

from .utils import batched_forward, merge_chessboard, split_chessboard


def _tokens_to_grid(out):
    """Reshape a ``B x N x C`` token tensor into a square ``B x C x H x W`` map.

    The side length is computed with ``math.isqrt`` (exact integer square
    root) rather than ``int(N ** 0.5)``, which goes through a float square
    root and truncation. A non-square token count is reported explicitly.
    """
    num_tokens = out.shape[1]
    side = math.isqrt(num_tokens)
    assert side * side == num_tokens, (
        f"Expected a square number of spatial tokens, got {num_tokens}. "
        "Check num_prefix_token and the model's output shape."
    )
    return rearrange(out, "b (h w) c -> b c h w", h=side, w=side)


def forward(
    model,
    input,
    scales=None,
    img_sizes=None,
    max_split_size=None,
    resize_output_to_idx=0,
    num_prefix_token=0,
    output_shape="bnc",
    split_forward=False,
):
    """Run ``model`` on several rescaled copies of ``input`` and concatenate the features.

    For each scale the image is resized (bicubic) to its target side length and
    split with ``split_chessboard`` into ``num_split`` pieces per side, where
    ``num_split = ceil(size / max_split_size)``. The pieces are run through
    ``model``, and the per-piece outputs are stitched back with
    ``merge_chessboard``. Every scale's feature map is then resized (area
    interpolation) to the spatial size of scale ``resize_output_to_idx``. The
    results are concatenated along the channel dimension.

    Args:
        model: Callable applied to image batches. It returns ``B x N x C``
            tokens when ``output_shape == "bnc"`` (e.g. ViT) or
            ``B x C x H x W`` maps when ``output_shape == "bchw"`` (e.g. ConvNet).
        input: Image batch of shape ``B x C x H x W`` with ``H == W``.
        scales: Multipliers applied to the input side length, one per scale.
            Only used when ``img_sizes`` is not given (or is empty).
        img_sizes: Explicit side lengths, one per scale. Takes precedence over
            ``scales``.
        max_split_size: Maximum side length of each split. Defaults to the
            input side length.
        resize_output_to_idx: Index of the scale whose output spatial size all
            scales are resized to before concatenation.
        num_prefix_token: Number of leading non-spatial tokens (e.g. a class
            token) in a ``"bnc"`` output. They are averaged over the splits of
            each scale, concatenated across scales along channels, and
            prepended to the spatial tokens. Must be 0 for ``"bchw"``.
        output_shape: ``"bnc"`` or ``"bchw"``. Layout of the model output and
            of the returned tensor.
        split_forward: If True, each scale's split batch is passed to
            ``batched_forward(model, x, B)`` (from ``.utils``) instead of
            calling ``model`` directly. Presumably this runs the model in
            chunks of ``B`` to bound memory; see that helper to confirm.

    Returns:
        The per-scale features concatenated along channels:
        ``B x (num_prefix_token + H' * W') x C_total`` for ``"bnc"``, or
        ``B x C_total x H' x W'`` for ``"bchw"``.
    """
    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in [
        "bnc",
        "bchw",
    ], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert (
        output_shape == "bnc" or num_prefix_token == 0
    ), "For ConvNet there shouldn't be any prefix token."

    b, _, input_size, _ = input.shape

    # image size for each scale
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # prepare multiscale inputs
    # The maximum size of each split of image. Set as the input size by default.
    max_split_size = max_split_size or input_size
    # number of splits (per side) for each scale
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        # Interpolate in float32, then cast back to the input dtype.
        # Presumably this is because bicubic interpolation is not available for
        # every low-precision dtype/device combination.
        x = F.interpolate(input.to(torch.float32), size=size, mode="bicubic").to(input.dtype)
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)

    # run feedforward on each scale
    outs_multiscale = [
        batched_forward(model, x, b) if split_forward else model(x) for x in input_multiscale
    ]

    if num_prefix_token > 0:
        # Peel off the non-spatial tokens so only the spatial tokens are reshaped into a grid.
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]

    if output_shape == "bnc":
        outs_multiscale = [_tokens_to_grid(out) for out in outs_multiscale]

    # merge outputs of different splits for each scale separately
    outs_multiscale = [
        merge_chessboard(out, num_split=num_split)
        for num_split, out in zip(num_splits, outs_multiscale)
    ]

    # interpolate outputs from different scales to a common size and concat along channels
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    out = torch.cat(
        [
            F.interpolate(feat.to(torch.float32), size=output_size, mode="area").to(feat.dtype)
            for feat in outs_multiscale
        ],
        dim=1,
    )

    if output_shape == "bnc":
        out = rearrange(out, "b c h w -> b (h w) c")

    if num_prefix_token > 0:
        # The splits of each scale are stacked along the batch dim in chunks of
        # size b. Average the prefix tokens over those chunks.
        outs_prefix_multiscale = [
            torch.stack(prefix.split(b, dim=0), dim=0).mean(dim=0)
            for prefix in outs_prefix_multiscale
        ]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out