import os

import torch
from timm.models import create_model
from tqdm import tqdm

from vbench.utils import load_video, load_dimension_info
from vbench.third_party.umt.datasets.video_transforms import (
    Compose,
    Resize,
    CenterCrop,
    Normalize,
)
from vbench.third_party.umt.datasets.volume_transforms import ClipToTensor
# Imported for its side effect: registers the UMT ViT-L/16 architecture with
# timm so that create_model("vit_large_patch16_224", ...) below resolves to it.
from vbench.third_party.umt.models.modeling_finetune import vit_large_patch16_224  # noqa: F401


def build_dict():
    """Map Kinetics-400 class indices (as strings) to lowercase category names.

    Each line of the category file is a tab-separated category/index pair.
    """
    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
    path = f'{CUR_DIR}/third_party/umt/kinetics_400_categories.txt'
    results = {}
    with open(path, 'r') as f:
        cat_list = [c.strip() for c in f.readlines()]
    for line in cat_list:
        cat, number = line.split('\t')
        results[number] = cat.lower()
    return results


def human_action(umt_path, video_list, device):
    """Classify each video with a UMT ViT-L/16 Kinetics-400 model.

    A video counts as correct when its ground-truth action (parsed from the
    file name) appears among the top-5 predicted categories whose sigmoid
    confidence is at least 0.85. Returns the overall accuracy and a
    per-video list of results.
    """
    state_dict = torch.load(umt_path, map_location='cpu')
    model = create_model(
        "vit_large_patch16_224",
        pretrained=False,
        num_classes=400,
        all_frames=16,
        tubelet_size=1,
        use_learnable_pos_emb=False,
        fc_drop_rate=0.,
        drop_rate=0.,
        drop_path_rate=0.2,
        attn_drop_rate=0.,
        drop_block_rate=None,
        use_checkpoint=False,
        checkpoint_num=16,
        use_mean_pooling=True,
        init_scale=0.001,
    )
    data_transform = Compose([
        Resize(256, interpolation='bilinear'),
        CenterCrop(size=(224, 224)),
        ClipToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = model.to(device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    cat_dict = build_dict()
    cnt = 0
    cor_num = 0
    video_results = []
    for video_path in tqdm(video_list):
        # Recover the ground-truth action from the file name, e.g.
        # "a person is marching-0.mp4" -> "marching".
        video_label = (
            video_path.split('/')[-1]
            .lower()
            .split('-')[0]
            .split("person is ")[-1]
            .split('_')[0]
        )
        cnt += 1
        images = load_video(video_path, data_transform, num_frames=16)
        images = images.unsqueeze(0).to(device)
        with torch.no_grad():
            # Sigmoid scores each category independently, so the 0.85
            # threshold below acts per class rather than over a softmax.
            logits = torch.sigmoid(model(images))
        results, indices = torch.topk(logits, 5, dim=1)
        indices = indices.squeeze().tolist()
        results = [round(f, 4) for f in results.squeeze().tolist()]
        # Keep only the confident top-5 categories.
        cat_ls = [cat_dict[str(indices[i])] for i in range(5) if results[i] >= 0.85]
        flag = video_label in cat_ls
        if flag:
            cor_num += 1
            # print(f"{cnt}: {video_path} correct, top-5: {cat_ls}, logits: {results}", flush=True)
        else:
            # print(f"{cnt}: {video_path} false, gt: {video_label}, top-5: {cat_ls}, logits: {results}", flush=True)
            pass
        video_results.append({'video_path': video_path, 'video_results': flag})
    # print(f"cor num: {cor_num}, total: {cnt}")
    acc = cor_num / cnt
    return acc, video_results


def compute_human_action(json_dir, device, submodules_list):
    """Entry point for the 'human_action' dimension: load the video list from
    the VBench info JSON and evaluate it against the UMT checkpoint."""
    umt_path = submodules_list[0]
    video_list, _ = load_dimension_info(json_dir, dimension='human_action', lang='en')
    all_results, video_results = human_action(umt_path, video_list, device)
    return all_results, video_results
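

# --- Minimal usage sketch (not part of VBench's public API) -----------------
# Shows how compute_human_action would be invoked standalone. Both paths
# below are hypothetical placeholders: point umt_ckpt at a UMT ViT-L/16
# Kinetics-400 checkpoint and json_dir at the VBench per-dimension info JSON.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    umt_ckpt = "path/to/umt_vit_l16_k400.pth"    # hypothetical checkpoint path
    json_dir = "path/to/vbench_full_info.json"   # hypothetical info-JSON path
    acc, per_video = compute_human_action(json_dir, device, [umt_ckpt])
    print(f"human_action accuracy: {acc:.4f} over {len(per_video)} videos")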