# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import is_model_wrapper
from mmengine.runner import TestLoop, ValLoop, autocast

from mmpretrain.registry import LOOPS


@LOOPS.register_module()
class RetrievalValLoop(ValLoop):
    """Loop for multimodal retrieval validation.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to False.
    """
    def run(self) -> dict:
        """Launch val."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        feats_local = []
        data_samples_local = []

        for idx, data_batch in enumerate(self.dataloader):
            with torch.no_grad():
                self.runner.call_hook(
                    'before_val_iter', batch_idx=idx, data_batch=data_batch)
                # predictions should be sequence of BaseDataElement
                with autocast(enabled=self.fp16):
                    if is_model_wrapper(self.runner.model):
                        data_preprocessor = self.runner.model.module.data_preprocessor  # noqa: E501
                    else:
                        data_preprocessor = self.runner.model.data_preprocessor
                    # get features for retrieval instead of data samples
                    data_batch = data_preprocessor(data_batch, False)
                    feats = self.runner.model._run_forward(
                        data_batch, mode='tensor')
                    feats_local.append(feats)
                    data_samples_local.extend(data_batch['data_samples'])
                self.runner.call_hook(
                    'after_val_iter',
                    batch_idx=idx,
                    data_batch=data_batch,
                    outputs=feats)

        # merge the per-batch feature dicts into a single dict of
        # concatenated tensors, keyed by feature name
        feats_local = {
            k: torch.cat([dic[k] for dic in feats_local])
            for k in feats_local[0]
        }
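        # e.g. two batches [{'img_feat': a1, 'text_feat': t1},
        #                   {'img_feat': a2, 'text_feat': t2}]
        # become {'img_feat': cat([a1, a2]), 'text_feat': cat([t1, t2])}
        # (the key names here are illustrative; the actual keys depend on
        # what the model's tensor-mode forward returns)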

        # get predictions
        if is_model_wrapper(self.runner.model):
            predict_all_fn = self.runner.model.module.predict_all
        else:
            predict_all_fn = self.runner.model.predict_all

        img_size = self.dataloader.dataset.img_size
        text_size = self.dataloader.dataset.text_size
        with torch.no_grad():
            i2t_data_samples, t2i_data_samples = predict_all_fn(
                feats_local,
                data_samples_local,
                num_images=img_size,
                num_texts=text_size,
            )
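        # ``i2t_data_samples`` hold the image-to-text retrieval direction and
        # ``t2i_data_samples`` the text-to-image direction, matching the
        # ``i2t/`` and ``t2i/`` metric prefixes used below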

        # process in evaluator and compute metrics
        self.evaluator.process(i2t_data_samples, None)
        i2t_metrics = self.evaluator.evaluate(img_size)
        i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
        self.evaluator.process(t2i_data_samples, None)
        t2i_metrics = self.evaluator.evaluate(text_size)
        t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
        metrics = {**i2t_metrics, **t2i_metrics}

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics
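
# Usage sketch: a config would typically enable this loop by type, e.g.
# ``val_cfg = dict(type='RetrievalValLoop', fp16=False)``; the runner then
# builds it with the dataloader and evaluator from the same config. The
# dataset is assumed to expose ``img_size`` and ``text_size`` attributes,
# which the loop reads above.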

@LOOPS.register_module()
class RetrievalTestLoop(TestLoop):
    """Loop for multimodal retrieval test.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to False.
    """
    def run(self) -> dict:
        """Launch test."""
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()

        feats_local = []
        data_samples_local = []

        for idx, data_batch in enumerate(self.dataloader):
            with torch.no_grad():
                self.runner.call_hook(
                    'before_test_iter', batch_idx=idx, data_batch=data_batch)
                # predictions should be sequence of BaseDataElement
                with autocast(enabled=self.fp16):
                    if is_model_wrapper(self.runner.model):
                        data_preprocessor = self.runner.model.module.data_preprocessor  # noqa: E501
                    else:
                        data_preprocessor = self.runner.model.data_preprocessor
                    # get features for retrieval instead of data samples
                    data_batch = data_preprocessor(data_batch, False)
                    feats = self.runner.model._run_forward(
                        data_batch, mode='tensor')
                    feats_local.append(feats)
                    data_samples_local.extend(data_batch['data_samples'])
                self.runner.call_hook(
                    'after_test_iter',
                    batch_idx=idx,
                    data_batch=data_batch,
                    outputs=feats)

        # merge the per-batch feature dicts into a single dict of
        # concatenated tensors, keyed by feature name
        feats_local = {
            k: torch.cat([dic[k] for dic in feats_local])
            for k in feats_local[0]
        }

        # get predictions
        if is_model_wrapper(self.runner.model):
            predict_all_fn = self.runner.model.module.predict_all
        else:
            predict_all_fn = self.runner.model.predict_all

        img_size = self.dataloader.dataset.img_size
        text_size = self.dataloader.dataset.text_size
        with torch.no_grad():
            i2t_data_samples, t2i_data_samples = predict_all_fn(
                feats_local,
                data_samples_local,
                num_images=img_size,
                num_texts=text_size,
            )

        # process in evaluator and compute metrics
        self.evaluator.process(i2t_data_samples, None)
        i2t_metrics = self.evaluator.evaluate(img_size)
        i2t_metrics = {f'i2t/{k}': v for k, v in i2t_metrics.items()}
        self.evaluator.process(t2i_data_samples, None)
        t2i_metrics = self.evaluator.evaluate(text_size)
        t2i_metrics = {f't2i/{k}': v for k, v in t2i_metrics.items()}
        metrics = {**i2t_metrics, **t2i_metrics}

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics
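
# Usage sketch: analogously, ``test_cfg = dict(type='RetrievalTestLoop')``
# would typically select this loop for the test phase; pass ``fp16=True``
# there to run the forward passes under autocast, as documented above.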