Spaces:
Runtime error
Runtime error
import torch
from pyiqa.archs.musiq_arch import MUSIQ
from tqdm import tqdm

from vbench.utils import load_dimension_info, load_video
def transform(images):
    """Rescale pixel intensities by dividing them by 255.

    Values originally in ``[0, 255]`` end up in ``[0, 1]``. Works on anything
    that supports true division by a float (Python numbers, NumPy arrays,
    torch tensors).

    Args:
        images: Pixel data to rescale.

    Returns:
        ``images / 255.0``.
    """
    normalized = images / 255.0
    return normalized
def technical_quality(model, video_list, device):
    """Score every video with an image-quality model and aggregate the results.

    Each frame of each video is scored independently by ``model``. A video's
    score is the mean of its per-frame scores. The returned aggregate is the
    mean over all videos, divided by 100.

    Args:
        model: Callable scorer, invoked as ``model(frame)`` on a single-frame
            batch; its output must be convertible with ``float()``.
        video_list: Iterable of video paths understood by ``load_video``.
        device: Device each frame is moved to before scoring.

    Returns:
        ``(average_score, video_results)``. ``video_results`` is a list of
        ``{'video_path': path, 'video_results': mean_frame_score}`` dicts.
        Those per-video scores are on the model's raw scale. ``average_score``
        is the mean of those scores divided by 100.

    Raises:
        ValueError: If ``video_list`` is empty or a video yields no frames
            (previously these surfaced as a bare ZeroDivisionError).
    """
    video_results = []
    for video_path in tqdm(video_list):
        video_results.append({
            'video_path': video_path,
            'video_results': _mean_frame_score(model, video_path, device),
        })
    if not video_results:
        raise ValueError("technical_quality: video_list is empty; nothing to score")
    average_score = sum(r['video_results'] for r in video_results) / len(video_results)
    # NOTE(review): only the aggregate is divided by 100 (presumably mapping the
    # model's 0-100 output onto [0, 1] -- confirm); per-video entries keep the
    # raw scale, as before.
    average_score = average_score / 100.
    return average_score, video_results


def _mean_frame_score(model, video_path, device):
    """Return the mean ``model`` score over all frames of one video."""
    # assumes load_video returns a frame-major tensor, e.g. (T, C, H, W) -- TODO confirm
    images = transform(load_video(video_path))
    if len(images) == 0:
        raise ValueError(f"technical_quality: no frames decoded from {video_path!r}")
    total = 0.
    # Inference only: disable autograd so no computation graph is built per frame.
    with torch.no_grad():
        for frame in images:
            score = model(frame.unsqueeze(0).to(device))
            total += float(score)
    return total / len(images)
def compute_imaging_quality(json_dir, device, submodules_list):
    """Evaluate the 'imaging_quality' dimension with a pretrained MUSIQ model.

    Args:
        json_dir: Location passed to ``load_dimension_info`` to obtain the
            videos for the 'imaging_quality' dimension (``lang='en'``).
        device: Device the model is moved to and frames are scored on.
        submodules_list: Mapping that must contain ``'model_path'``, the
            pretrained MUSIQ weights to load.

    Returns:
        ``(all_results, video_results)`` exactly as produced by
        ``technical_quality``.
    """
    model_path = submodules_list['model_path']
    model = MUSIQ(pretrained_model_path=model_path)
    model.to(device)
    # eval() recursively puts every submodule into inference mode. The previous
    # ``model.training = False`` only flipped the top-level flag, leaving any
    # mode-dependent child layers (dropout, batch-norm) in training mode.
    model.eval()
    video_list, _ = load_dimension_info(json_dir, dimension='imaging_quality', lang='en')
    all_results, video_results = technical_quality(model, video_list, device)
    return all_results, video_results