import torch
from tqdm import tqdm
from pyiqa.archs.musiq_arch import MUSIQ
from vbench.utils import load_video, load_dimension_info


def transform(images):
    """Rescale pixel values by 1/255.

    NOTE(review): assumes ``images`` holds 0-255 pixel intensities, so the
    result lies in [0, 1]. Confirm against ``load_video``'s output.
    """
    return images / 255.


@torch.no_grad()  # inference only: don't build autograd graphs per frame
def technical_quality(model, video_list, device):
    """Score each video as the mean of ``model``'s per-frame scores.

    Args:
        model: image-quality model called on a batch of one frame; its output
            must be convertible with ``float()``.
        video_list: iterable of video paths accepted by ``load_video``.
        device: device the frames are moved to before scoring.

    Returns:
        A tuple ``(average_score, video_results)``:
        ``average_score`` is the mean of the per-video scores divided by 100.
        ``video_results`` is a list of
        ``{'video_path': ..., 'video_results': <mean frame score>}``.
        The per-video scores are NOT divided by 100.

    Raises:
        ValueError: if a video yields no frames or ``video_list`` is empty.
    """
    video_results = []
    for video_path in tqdm(video_list):
        # presumably a frame tensor indexed along dim 0 -- TODO confirm
        images = transform(load_video(video_path))
        num_frames = len(images)
        if num_frames == 0:
            raise ValueError(f"No frames could be loaded from video: {video_path}")
        acc_score_video = 0.
        for frame in images:
            # Score one frame at a time (batch of 1). Only one frame is on
            # `device` at a time.
            score = model(frame.unsqueeze(0).to(device))
            acc_score_video += float(score)
        video_results.append({'video_path': video_path, 'video_results': acc_score_video / num_frames})
    if not video_results:
        raise ValueError("video_list is empty; cannot compute an average imaging quality score")
    average_score = sum(o['video_results'] for o in video_results) / len(video_results)
    # Only the aggregate is rescaled by 1/100 (presumably mapping a ~0-100
    # model scale to ~0-1 -- verify). Per-video entries keep the raw scale.
    average_score = average_score / 100.
    return average_score, video_results


def compute_imaging_quality(json_dir, device, submodules_list):
    """Run the MUSIQ-based imaging-quality evaluation for the 'imaging_quality' dimension.

    Args:
        json_dir: passed to ``load_dimension_info`` to obtain the video list.
        device: device the model and frames are placed on.
        submodules_list: mapping whose ``'model_path'`` entry is the pretrained
            MUSIQ checkpoint path.

    Returns:
        ``(average_score, video_results)`` exactly as produced by
        ``technical_quality``.
    """
    model_path = submodules_list['model_path']
    model = MUSIQ(pretrained_model_path=model_path)
    model.to(device)
    # eval() sets training=False on every submodule, not just the top-level
    # one. Assigning `model.training = False` would leave dropout/norm layers
    # in training mode.
    model.eval()
    video_list, _ = load_dimension_info(json_dir, dimension='imaging_quality', lang='en')
    all_results, video_results = technical_quality(model, video_list, device)
    return all_results, video_results