geekyrakshit commited on
Commit
b82a487
·
1 Parent(s): c063304

add: ContrieverRetriever

Browse files
.gitignore CHANGED
@@ -19,4 +19,5 @@ cursor_prompt.txt
19
  test.py
20
  uv.lock
21
  grays-anatomy-bm25s/
22
- prompt**.txt
 
 
19
  test.py
20
  uv.lock
21
  grays-anatomy-bm25s/
22
+ prompt**.txt
23
+ **.safetensors
docs/retreival/contriever.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Contriever Retrieval
2
+
3
+ ::: medrag_multi_modal.retrieval.contriever_retrieval
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,4 +1,10 @@
1
  from .bm25s_retrieval import BM25sRetriever
2
  from .colpali_retrieval import CalPaliRetriever
 
3
 
4
- __all__ = ["CalPaliRetriever", "BM25sRetriever"]
 
 
 
 
 
 
1
  from .bm25s_retrieval import BM25sRetriever
2
  from .colpali_retrieval import CalPaliRetriever
3
+ from .contriever_retrieval import ContrieverRetriever, SimilarityMetric
4
 
5
+ __all__ = [
6
+ "CalPaliRetriever",
7
+ "BM25sRetriever",
8
+ "ContrieverRetriever",
9
+ "SimilarityMetric",
10
+ ]
medrag_multi_modal/retrieval/common.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import wandb
4
+
5
+
6
+ class SimilarityMetric(Enum):
7
+ COSINE = "cosine"
8
+ EUCLIDEAN = "euclidean"
9
+
10
+
11
+ def mean_pooling(token_embeddings, mask):
12
+ token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
13
+ sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
14
+ return sentence_embeddings
15
+
16
+
17
+ def get_wandb_artifact(artifact_address: str, artifact_type: str):
18
+ if wandb.run:
19
+ artifact = wandb.run.use_artifact(artifact_address, type=artifact_type)
20
+ artifact_dir = artifact.download()
21
+ else:
22
+ api = wandb.Api()
23
+ artifact = api.artifact(artifact_address)
24
+ artifact_dir = artifact.download()
25
+ metadata = artifact.metadata
26
+ return artifact_dir, metadata
27
+
28
+
29
+ def argsort_scores(scores: list[float], descending: bool = False):
30
+ return [
31
+ {"item": item, "original_index": idx}
32
+ for idx, item in sorted(
33
+ list(enumerate(scores)), key=lambda x: x[1], reverse=descending
34
+ )
35
+ ]
medrag_multi_modal/retrieval/contriever_retrieval.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+
4
+ import safetensors
5
+ import safetensors.torch
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import weave
9
+ from transformers import (
10
+ AutoModel,
11
+ AutoTokenizer,
12
+ BertPreTrainedModel,
13
+ PreTrainedTokenizerFast,
14
+ )
15
+
16
+ import wandb
17
+
18
+ from .common import SimilarityMetric, argsort_scores, get_wandb_artifact, mean_pooling
19
+
20
+
21
+ class ContrieverRetriever(weave.Model):
22
+ """
23
+ `ContrieverRetriever` is a class to perform retrieval tasks using the Contriever model.
24
+
25
+ It provides methods to encode text data into embeddings, index a dataset of text chunks,
26
+ and retrieve the most relevant chunks for a given query based on similarity metrics.
27
+
28
+ Args:
29
+ model_name (str): The name of the pre-trained model to use for encoding.
30
+ vector_index (Optional[torch.Tensor]): The tensor containing the vector representations
31
+ of the indexed chunks.
32
+ chunk_dataset (Optional[list[dict]]): The weave dataset of text chunks to be indexed.
33
+ """
34
+
35
+ model_name: str
36
+ _chunk_dataset: Optional[list[dict]]
37
+ _tokenizer: PreTrainedTokenizerFast
38
+ _model: BertPreTrainedModel
39
+ _vector_index: Optional[torch.Tensor]
40
+
41
+ def __init__(
42
+ self,
43
+ model_name: str = "facebook/contriever",
44
+ vector_index: Optional[torch.Tensor] = None,
45
+ chunk_dataset: Optional[list[dict]] = None,
46
+ ):
47
+ super().__init__(model_name=model_name)
48
+ self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
49
+ self._model = AutoModel.from_pretrained(self.model_name)
50
+ self._vector_index = vector_index
51
+ self._chunk_dataset = chunk_dataset
52
+
53
+ def encode(self, corpus: list[str]) -> torch.Tensor:
54
+ inputs = self._tokenizer(
55
+ corpus, padding=True, truncation=True, return_tensors="pt"
56
+ )
57
+ outputs = self._model(**inputs)
58
+ return mean_pooling(outputs[0], inputs["attention_mask"])
59
+
60
+ def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
61
+ """
62
+ Indexes a dataset of text chunks and optionally saves the vector index to a file.
63
+
64
+ This method retrieves a dataset of text chunks from a Weave reference, encodes the
65
+ text chunks into vector representations using the Contriever model, and stores the
66
+ resulting vector index. If an index name is provided, the vector index is saved to
67
+ a file in the safetensors format. Additionally, if a Weave run is active, the vector
68
+ index file is logged as an artifact to Weave.
69
+
70
+ !!! example "Example Usage"
71
+ ```python
72
+ import weave
73
+ from dotenv import load_dotenv
74
+
75
+ import wandb
76
+ from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
77
+
78
+ load_dotenv()
79
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
80
+ wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="contriever-index")
81
+ retriever = ContrieverRetriever(model_name="facebook/contriever")
82
+ retriever.index(chunk_dataset_name="grays-anatomy-chunks:v0", index_name="grays-anatomy-contriever")
83
+ ```
84
+
85
+ Args:
86
+ chunk_dataset_name (str): The name of the Weave dataset containing the text chunks
87
+ to be indexed.
88
+ index_name (Optional[str]): The name of the index artifact to be saved. If provided,
89
+ the vector index is saved to a file and logged as an artifact to Weave.
90
+ """
91
+ self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
92
+ corpus = [row["text"] for row in self._chunk_dataset]
93
+ with torch.no_grad():
94
+ vector_index = self.encode(corpus)
95
+ self._vector_index = vector_index
96
+ if index_name:
97
+ safetensors.torch.save_file(
98
+ {"vector_index": vector_index.cpu()}, "vector_index.safetensors"
99
+ )
100
+ if wandb.run:
101
+ artifact = wandb.Artifact(
102
+ name=index_name,
103
+ type="contriever-index",
104
+ metadata={"model_name": self.model_name},
105
+ )
106
+ artifact.add_file("vector_index.safetensors")
107
+ artifact.save()
108
+
109
+ @classmethod
110
+ def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
111
+ """
112
+ Creates an instance of the class from a Weave artifact.
113
+
114
+ This method retrieves a vector index and metadata from a Weave artifact stored in
115
+ Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave
116
+ reference. The vector index is loaded from a safetensors file and moved to the
117
+ appropriate device (CPU or GPU). The text chunks are converted into a list of
118
+ dictionaries. The method then returns an instance of the class initialized with
119
+ the retrieved model name, vector index, and chunk dataset.
120
+
121
+ !!! example "Example Usage"
122
+ ```python
123
+ import weave
124
+ from dotenv import load_dotenv
125
+
126
+ from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
127
+
128
+ load_dotenv()
129
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
130
+ retriever = ContrieverRetriever.from_wandb_artifact(
131
+ chunk_dataset_name="grays-anatomy-chunks:v0",
132
+ index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1",
133
+ )
134
+ ```
135
+
136
+ Args:
137
+ chunk_dataset_name (str): The name of the Weave dataset containing the text chunks.
138
+ index_artifact_address (str): The address of the Weave artifact containing the
139
+ vector index.
140
+
141
+ Returns:
142
+ An instance of the class initialized with the retrieved model name, vector index,
143
+ and chunk dataset.
144
+ """
145
+ artifact_dir, metadata = get_wandb_artifact(
146
+ index_artifact_address, "contriever-index"
147
+ )
148
+ with safetensors.torch.safe_open(
149
+ os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
150
+ ) as f:
151
+ vector_index = f.get_tensor("vector_index")
152
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153
+ vector_index = vector_index.to(device)
154
+ chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
155
+ return cls(
156
+ model_name=metadata["model_name"],
157
+ vector_index=vector_index,
158
+ chunk_dataset=chunk_dataset,
159
+ )
160
+
161
+ @weave.op()
162
+ def retrieve(
163
+ self,
164
+ query: str,
165
+ top_k: int = 2,
166
+ metric: SimilarityMetric = SimilarityMetric.COSINE,
167
+ ):
168
+ """
169
+ Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
170
+
171
+ This method encodes the input query into an embedding and computes similarity scores between
172
+ the query embedding and the precomputed vector index. The similarity metric can be either
173
+ cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
174
+ are returned as a list of dictionaries, each containing a chunk and its corresponding score.
175
+
176
+ !!! example "Example Usage"
177
+ ```python
178
+ import weave
179
+ from dotenv import load_dotenv
180
+
181
+ from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
182
+
183
+ load_dotenv()
184
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
185
+ retriever = ContrieverRetriever.from_wandb_artifact(
186
+ chunk_dataset_name="grays-anatomy-chunks:v0",
187
+ index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1",
188
+ )
189
+ scores = retriever.retrieve(query="What are Ribosomes?", metric=SimilarityMetric.COSINE)
190
+ ```
191
+
192
+ Args:
193
+ query (str): The input query string to search for relevant chunks.
194
+ top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
195
+ metric (SimilarityMetric, optional): The similarity metric to use for scoring.
196
+
197
+ Returns:
198
+ list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
199
+ """
200
+ query = [query]
201
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
202
+ with torch.no_grad():
203
+ query_embedding = self.encode(query).to(device)
204
+ if metric == SimilarityMetric.EUCLIDEAN:
205
+ scores = torch.squeeze(query_embedding @ self._vector_index.T)
206
+ else:
207
+ scores = F.cosine_similarity(query_embedding, self._vector_index)
208
+ scores = scores.cpu().numpy().tolist()
209
+ scores = argsort_scores(scores, descending=True)[:top_k]
210
+ retrieved_chunks = []
211
+ for score in scores:
212
+ retrieved_chunks.append(
213
+ {
214
+ "chunk": self._chunk_dataset[score["original_index"]],
215
+ "score": score["item"],
216
+ }
217
+ )
218
+ return retrieved_chunks
mkdocs.yml CHANGED
@@ -74,5 +74,6 @@ nav:
74
  - Retrieval:
75
  - BM25-Sparse: 'retreival/bm25s.md'
76
  - ColPali: 'retreival/colpali.md'
 
77
 
78
  repo_url: https://github.com/soumik12345/medrag-multi-modal
 
74
  - Retrieval:
75
  - BM25-Sparse: 'retreival/bm25s.md'
76
  - ColPali: 'retreival/colpali.md'
77
+ - Contriever: 'retreival/contriever.md'
78
 
79
  repo_url: https://github.com/soumik12345/medrag-multi-modal