# Source: ColFlor-Demo / processing_utils.py
# Commit 9f9c2cc ("Create processing_utils.py", verified), author: ahmed-masry, size: 3.75 kB
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature
def get_torch_device(device: str = "auto") -> str:
    """
    Return the device string to be used by PyTorch.

    Args:
        device: Either an explicit device string, which is returned unchanged,
            or "auto" (the default), which resolves to:
            - "cuda:0" if CUDA is available
            - else "mps" if the MPS backend is available
            - else "cpu".

    Returns:
        The resolved device string. When the device is auto-detected, the
        choice is also reported at INFO level on this module's logger.
    """
    # Imported locally so the function does not rely on a module-level `logger`
    # name being defined elsewhere in the file.
    import logging

    if device == "auto":
        # `torch.backends.mps` does not exist on older torch builds, so look it
        # up defensively instead of assuming the attribute is there.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = "cuda:0"
        elif mps_backend is not None and mps_backend.is_available():  # for Apple Silicon
            device = "mps"
        else:
            device = "cpu"
        logging.getLogger(__name__).info("Using device: %s", device)

    return device
class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.

    Subclasses must implement `process_images`, `process_queries` and `score`.
    Two ready-made scoring helpers are provided as static methods:
    `score_single_vector` (dot product) and `score_multi_vector` (MaxSim).
    """

    @abstractmethod
    def process_images(
        self,
        images: List["Image.Image"],
    ) -> Union["BatchFeature", "BatchEncoding"]:
        """
        Convert a list of PIL images into a batch of model inputs.

        Must be implemented by subclasses.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union["BatchFeature", "BatchEncoding"]:
        """
        Convert a list of text queries into a batch of model inputs.

        Must be implemented by subclasses. How `max_length` and `suffix` are
        applied is up to the concrete processor.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Score query embeddings against passage embeddings.

        Must be implemented by subclasses.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: Query embeddings, one 1-D tensor per query (all the same size).
            ps: Passage embeddings, one 1-D tensor per passage (same size as the queries).
            device: Device on which to compute; auto-detected when omitted.

        Returns:
            A float32 tensor of shape (len(qs), len(ps)), left on `device`.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # (b, d) x (c, d) -> (b, c): every query against every passage.
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        return scores.to(torch.float32)

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        For each (query, passage) pair, every query token takes its maximum dot
        product over the passage tokens, and these maxima are summed.

        Args:
            qs: Query embeddings, one 2-D (tokens, dim) tensor per query.
            ps: Passage embeddings, one 2-D (tokens, dim) tensor per passage.
            batch_size: Number of queries / passages scored together per step.
            device: Device on which to compute; auto-detected when omitted.

        Returns:
            A float32 CPU tensor of shape (len(qs), len(ps)).

        Raises:
            ValueError: If `qs` or `ps` is empty, or `batch_size` is smaller than 1.
        """
        device = device or get_torch_device("auto")

        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        pad_sequence = torch.nn.utils.rnn.pad_sequence

        scores_list: List[torch.Tensor] = []
        for i in range(0, len(qs), batch_size):
            # Zero-padded query tokens are harmless: their similarity to every
            # real passage token is 0, so they add 0 to the final sum.
            qs_batch = pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(device)

            row_blocks: List[torch.Tensor] = []
            for j in range(0, len(ps), batch_size):
                ps_chunk = ps[j : j + batch_size]
                ps_batch = pad_sequence(ps_chunk, batch_first=True, padding_value=0).to(device)

                # Padded passage rows would otherwise join the max with a
                # similarity of 0, clamping negative maxima up to 0 and making
                # the score depend on which passages share a batch. Mask them out.
                # Zero-length passages are left unmasked so they keep scoring 0
                # instead of -inf.
                lengths = torch.tensor([p.shape[0] for p in ps_chunk], device=device)
                positions = torch.arange(ps_batch.shape[1], device=device)
                pad_mask = (positions[None, :] >= lengths[:, None]) & (lengths[:, None] > 0)

                # (b, n, d) x (c, s, d) -> (b, c, n, s) token-level similarities.
                sims = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
                sims = sims.masked_fill(pad_mask[None, :, None, :], float("-inf"))

                # Max over passage tokens, then sum over query tokens.
                row_blocks.append(sims.max(dim=3)[0].sum(dim=2))

            scores_list.append(torch.cat(row_blocks, dim=1).cpu())

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        return scores.to(torch.float32)