geekyrakshit commited on
Commit
4ea2b30
·
1 Parent(s): 21537b7

add: BM25sRetriever

Browse files
.gitignore CHANGED
@@ -18,3 +18,5 @@ wandb/
18
  cursor_prompt.txt
19
  test.py
20
  uv.lock
 
 
 
18
  cursor_prompt.txt
19
  test.py
20
  uv.lock
21
+ grays-anatomy-bm25s/
22
+ prompt**.txt
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
- from .multi_modal_retrieval import MultiModalRetriever
 
2
 
3
- __all__ = ["MultiModalRetriever"]
 
1
+ from .bm25s_retrieval import BM25sRetriever
2
+ from .colpali_retrieval import CalPaliRetriever
3
 
4
+ __all__ = ["CalPaliRetriever", "BM25sRetriever"]
medrag_multi_modal/retrieval/bm25s_retrieval.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import bm25s
4
+ import weave
5
+ from Stemmer import Stemmer
6
+
7
+ import wandb
8
+
9
+ LANGUAGE_DICT = {
10
+ "english": "en",
11
+ "french": "fr",
12
+ "german": "de",
13
+ }
14
+
15
+
16
+ class BM25sRetriever(weave.Model):
17
+ language: str
18
+ use_stemmer: bool
19
+ _retriever: Optional[bm25s.BM25]
20
+
21
+ def __init__(
22
+ self,
23
+ language: str = "english",
24
+ use_stemmer: bool = True,
25
+ retriever: Optional[bm25s.BM25] = None,
26
+ ):
27
+ super().__init__(language=language, use_stemmer=use_stemmer)
28
+ self._retriever = retriever or bm25s.BM25()
29
+
30
+ def index(self, corpus_dataset_name: str, index_name: Optional[str] = None):
31
+ corpus_dataset = weave.ref(corpus_dataset_name).get().rows
32
+ corpus = [row["text"] for row in corpus_dataset]
33
+ corpus_tokens = bm25s.tokenize(
34
+ corpus,
35
+ stopwords=LANGUAGE_DICT[self.language],
36
+ stemmer=Stemmer(self.language) if self.use_stemmer else None,
37
+ )
38
+ self._retriever.index(corpus_tokens)
39
+ self._retriever.save(index_name, corpus=[dict(row) for row in corpus_dataset])
40
+ if index_name:
41
+ self._retriever.save(index_name)
42
+ if wandb.run:
43
+ artifact = wandb.Artifact(name=index_name, type="bm25s-index")
44
+ artifact.add_dir(index_name)
45
+ artifact.save()
medrag_multi_modal/retrieval/colpali_retrieval.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  from typing import Any, Optional
3
 
4
- import wandb
5
  import weave
6
  from byaldi import RAGMultiModalModel
7
  from PIL import Image
8
 
 
 
9
  from ..utils import get_wandb_artifact
10
 
11
 
 
1
  import os
2
  from typing import Any, Optional
3
 
 
4
  import weave
5
  from byaldi import RAGMultiModalModel
6
  from PIL import Image
7
 
8
+ import wandb
9
+
10
  from ..utils import get_wandb_artifact
11
 
12
 
pyproject.toml CHANGED
@@ -8,6 +8,7 @@ dependencies = [
8
  "bm25s[full]>=0.2.2",
9
  "Byaldi>=0.0.5",
10
  "firerequests>=0.0.7",
 
11
  "pdf2image>=1.17.0",
12
  "python-dotenv>=1.0.1",
13
  "pymupdf4llm>=0.0.17",
@@ -17,6 +18,7 @@ dependencies = [
17
  "uv>=0.4.20",
18
  "pytest>=8.3.3",
19
  "PyPDF2>=3.0.1",
 
20
  "isort>=5.13.2",
21
  "black>=24.10.0",
22
  "ruff>=0.6.9",
@@ -39,10 +41,12 @@ core = [
39
  "bm25s[full]>=0.2.2",
40
  "Byaldi>=0.0.5",
41
  "firerequests>=0.0.7",
 
42
  "marker-pdf>=0.2.17",
43
  "pdf2image>=1.17.0",
44
  "pdfplumber>=0.11.4",
45
  "PyPDF2>=3.0.1",
 
46
  "python-dotenv>=1.0.1",
47
  "pymupdf4llm>=0.0.17",
48
  "semchunk>=2.2.0",
 
8
  "bm25s[full]>=0.2.2",
9
  "Byaldi>=0.0.5",
10
  "firerequests>=0.0.7",
11
+ "jax[cpu]>=0.4.34",
12
  "pdf2image>=1.17.0",
13
  "python-dotenv>=1.0.1",
14
  "pymupdf4llm>=0.0.17",
 
18
  "uv>=0.4.20",
19
  "pytest>=8.3.3",
20
  "PyPDF2>=3.0.1",
21
+ "PyStemmer>=2.2.0.3",
22
  "isort>=5.13.2",
23
  "black>=24.10.0",
24
  "ruff>=0.6.9",
 
41
  "bm25s[full]>=0.2.2",
42
  "Byaldi>=0.0.5",
43
  "firerequests>=0.0.7",
44
+ "jax[cpu]>=0.4.34",
45
  "marker-pdf>=0.2.17",
46
  "pdf2image>=1.17.0",
47
  "pdfplumber>=0.11.4",
48
  "PyPDF2>=3.0.1",
49
+ "PyStemmer>=2.2.0.3",
50
  "python-dotenv>=1.0.1",
51
  "pymupdf4llm>=0.0.17",
52
  "semchunk>=2.2.0",