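"""Custom LangChain sparse-vector retriever for Qdrant built on SPLADE models.

Documents and queries are encoded into sparse vectors by running a masked
language model and max-pooling log(1 + ReLU(logits)) over the sequence
dimension, keeping only the non-zero token weights.
"""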
from typing import (
    Any,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

import torch
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores.qdrant import Qdrant
from langchain_core.pydantic_v1 import Field
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain.schema import Document

try:
    from qdrant_client import QdrantClient, models
except ImportError:
    pass


def batchify(_list: List, batch_size: int) -> Generator[List, None, None]:
    """Yield successive ``batch_size``-sized slices of ``_list``."""
    for i in range(0, len(_list), batch_size):
        yield _list[i:i + batch_size]


class MyQdrantSparseVectorRetriever(QdrantSparseVectorRetriever):
    splade_doc_tokenizer: Any = Field(repr=False)
    splade_doc_model: Any = Field(repr=False)
    splade_query_tokenizer: Any = Field(repr=False)
    splade_query_model: Any = Field(repr=False)
    device: Any = Field(repr=False)
    batch_size: int = Field(repr=False)
    sparse_encoder: Optional[Any] = Field(default=None, repr=False)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def compute_document_vectors(
        self, texts: List[str], batch_size: int
    ) -> Tuple[List[List[int]], List[List[float]]]:
        indices = []
        values = []
        for text_batch in batchify(texts, batch_size):
            with torch.no_grad():
                tokens = self.splade_doc_tokenizer(
                    text_batch, truncation=True, padding=True, return_tensors="pt"
                ).to(self.device)
                output = self.splade_doc_model(**tokens)
                logits, attention_mask = output.logits, tokens.attention_mask
                # SPLADE weighting: max-pool log(1 + ReLU(logits)) over the
                # sequence dimension, masking out padding positions.
                relu_log = torch.log(1 + torch.relu(logits))
                weighted_log = relu_log * attention_mask.unsqueeze(-1)
                tvecs, _ = torch.max(weighted_log, dim=1)
            # Extract all non-zero values and their indices from the sparse vectors.
            for batch in tvecs.cpu():
                indices.append(batch.nonzero(as_tuple=True)[0].numpy())
                values.append(batch[indices[-1]].numpy())
        return indices, values

    def compute_query_vector(self, text: str):
        """
        Computes a vector from logits and attention mask using ReLU, log,
        and max operations.
        """
        with torch.no_grad():
            tokens = self.splade_query_tokenizer(text, return_tensors="pt").to(
                self.device
            )
            output = self.splade_query_model(**tokens)
            logits, attention_mask = output.logits, tokens.attention_mask
            relu_log = torch.log(1 + torch.relu(logits))
            weighted_log = relu_log * attention_mask.unsqueeze(-1)
            max_val, _ = torch.max(weighted_log, dim=1)
            query_vec = max_val.squeeze().cpu()
        query_indices = query_vec.nonzero().numpy().flatten()
        query_values = query_vec.detach().numpy()[query_indices]
        return query_indices, query_values

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ):
        client = cast(QdrantClient, self.client)
        # Materialize the iterable: it is indexed and len()-ed below.
        texts = list(texts)
        indices, values = self.compute_document_vectors(texts, self.batch_size)
        points = [
            models.PointStruct(
                # Note: ids restart at 1 on every call, so a second call to
                # add_texts will overwrite the points written by the first.
                id=i + 1,
                vector={
                    self.sparse_vector_name: models.SparseVector(
                        indices=indices[i],
                        values=values[i],
                    )
                },
                payload={
                    self.content_payload_key: texts[i],
                    self.metadata_payload_key: metadatas[i] if metadatas else None,
                },
            )
            for i in range(len(texts))
        ]
        client.upsert(self.collection_name, points=points, **kwargs)
        # Free cached GPU memory after the batch encode.
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        client = cast(QdrantClient, self.client)
        query_indices, query_values = self.compute_query_vector(query)
        results = client.search(
            self.collection_name,
            query_filter=self.filter,
            query_vector=models.NamedSparseVector(
                name=self.sparse_vector_name,
                vector=models.SparseVector(
                    indices=query_indices,
                    values=query_values,
                ),
            ),
            limit=self.k,
            with_vectors=False,
            **self.search_options,
        )
        return [
            Qdrant._document_from_scored_point(
                point,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for point in results
        ]
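

# --- Example usage -----------------------------------------------------------
# A minimal sketch, not part of the class above: the SPLADE checkpoint name,
# the "demo" collection, and the in-memory Qdrant client are illustrative
# assumptions, not values taken from this file.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "naver/splade-cocondenser-ensembledistil"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)

    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name="demo",
        vectors_config={},  # sparse-only collection, no dense vectors
        sparse_vectors_config={"text": models.SparseVectorParams()},
    )

    retriever = MyQdrantSparseVectorRetriever(
        client=client,
        collection_name="demo",
        sparse_vector_name="text",
        splade_doc_tokenizer=tokenizer,
        splade_doc_model=model,
        # Reusing the document encoder for queries here; swap in a dedicated
        # query encoder if you have one.
        splade_query_tokenizer=tokenizer,
        splade_query_model=model,
        device=device,
        batch_size=32,
    )
    retriever.add_texts(
        [
            "Sparse vectors keep exact token weights.",
            "Dense vectors capture semantic similarity.",
        ]
    )
    print(retriever.get_relevant_documents("token weights"))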