Spaces:
Sleeping
Sleeping
geekyrakshit
commited on
Commit
·
4ea2b30
1
Parent(s):
21537b7
add: BM25sRetriever
Browse files
.gitignore
CHANGED
@@ -18,3 +18,5 @@ wandb/
|
|
18 |
cursor_prompt.txt
|
19 |
test.py
|
20 |
uv.lock
|
|
|
|
|
|
18 |
cursor_prompt.txt
|
19 |
test.py
|
20 |
uv.lock
|
21 |
+
grays-anatomy-bm25s/
|
22 |
+
prompt**.txt
|
medrag_multi_modal/retrieval/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
-
from .
|
|
|
2 |
|
3 |
-
__all__ = ["
|
|
|
1 |
+
from .bm25s_retrieval import BM25sRetriever
|
2 |
+
from .colpali_retrieval import CalPaliRetriever
|
3 |
|
4 |
+
__all__ = ["CalPaliRetriever", "BM25sRetriever"]
|
medrag_multi_modal/retrieval/bm25s_retrieval.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import bm25s
|
4 |
+
import weave
|
5 |
+
from Stemmer import Stemmer
|
6 |
+
|
7 |
+
import wandb
|
8 |
+
|
9 |
+
LANGUAGE_DICT = {
|
10 |
+
"english": "en",
|
11 |
+
"french": "fr",
|
12 |
+
"german": "de",
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class BM25sRetriever(weave.Model):
|
17 |
+
language: str
|
18 |
+
use_stemmer: bool
|
19 |
+
_retriever: Optional[bm25s.BM25]
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
language: str = "english",
|
24 |
+
use_stemmer: bool = True,
|
25 |
+
retriever: Optional[bm25s.BM25] = None,
|
26 |
+
):
|
27 |
+
super().__init__(language=language, use_stemmer=use_stemmer)
|
28 |
+
self._retriever = retriever or bm25s.BM25()
|
29 |
+
|
30 |
+
def index(self, corpus_dataset_name: str, index_name: Optional[str] = None):
|
31 |
+
corpus_dataset = weave.ref(corpus_dataset_name).get().rows
|
32 |
+
corpus = [row["text"] for row in corpus_dataset]
|
33 |
+
corpus_tokens = bm25s.tokenize(
|
34 |
+
corpus,
|
35 |
+
stopwords=LANGUAGE_DICT[self.language],
|
36 |
+
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
37 |
+
)
|
38 |
+
self._retriever.index(corpus_tokens)
|
39 |
+
self._retriever.save(index_name, corpus=[dict(row) for row in corpus_dataset])
|
40 |
+
if index_name:
|
41 |
+
self._retriever.save(index_name)
|
42 |
+
if wandb.run:
|
43 |
+
artifact = wandb.Artifact(name=index_name, type="bm25s-index")
|
44 |
+
artifact.add_dir(index_name)
|
45 |
+
artifact.save()
|
medrag_multi_modal/retrieval/colpali_retrieval.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import os
|
2 |
from typing import Any, Optional
|
3 |
|
4 |
-
import wandb
|
5 |
import weave
|
6 |
from byaldi import RAGMultiModalModel
|
7 |
from PIL import Image
|
8 |
|
|
|
|
|
9 |
from ..utils import get_wandb_artifact
|
10 |
|
11 |
|
|
|
1 |
import os
|
2 |
from typing import Any, Optional
|
3 |
|
|
|
4 |
import weave
|
5 |
from byaldi import RAGMultiModalModel
|
6 |
from PIL import Image
|
7 |
|
8 |
+
import wandb
|
9 |
+
|
10 |
from ..utils import get_wandb_artifact
|
11 |
|
12 |
|
pyproject.toml
CHANGED
@@ -8,6 +8,7 @@ dependencies = [
|
|
8 |
"bm25s[full]>=0.2.2",
|
9 |
"Byaldi>=0.0.5",
|
10 |
"firerequests>=0.0.7",
|
|
|
11 |
"pdf2image>=1.17.0",
|
12 |
"python-dotenv>=1.0.1",
|
13 |
"pymupdf4llm>=0.0.17",
|
@@ -17,6 +18,7 @@ dependencies = [
|
|
17 |
"uv>=0.4.20",
|
18 |
"pytest>=8.3.3",
|
19 |
"PyPDF2>=3.0.1",
|
|
|
20 |
"isort>=5.13.2",
|
21 |
"black>=24.10.0",
|
22 |
"ruff>=0.6.9",
|
@@ -39,10 +41,12 @@ core = [
|
|
39 |
"bm25s[full]>=0.2.2",
|
40 |
"Byaldi>=0.0.5",
|
41 |
"firerequests>=0.0.7",
|
|
|
42 |
"marker-pdf>=0.2.17",
|
43 |
"pdf2image>=1.17.0",
|
44 |
"pdfplumber>=0.11.4",
|
45 |
"PyPDF2>=3.0.1",
|
|
|
46 |
"python-dotenv>=1.0.1",
|
47 |
"pymupdf4llm>=0.0.17",
|
48 |
"semchunk>=2.2.0",
|
|
|
8 |
"bm25s[full]>=0.2.2",
|
9 |
"Byaldi>=0.0.5",
|
10 |
"firerequests>=0.0.7",
|
11 |
+
"jax[cpu]>=0.4.34",
|
12 |
"pdf2image>=1.17.0",
|
13 |
"python-dotenv>=1.0.1",
|
14 |
"pymupdf4llm>=0.0.17",
|
|
|
18 |
"uv>=0.4.20",
|
19 |
"pytest>=8.3.3",
|
20 |
"PyPDF2>=3.0.1",
|
21 |
+
"PyStemmer>=2.2.0.3",
|
22 |
"isort>=5.13.2",
|
23 |
"black>=24.10.0",
|
24 |
"ruff>=0.6.9",
|
|
|
41 |
"bm25s[full]>=0.2.2",
|
42 |
"Byaldi>=0.0.5",
|
43 |
"firerequests>=0.0.7",
|
44 |
+
"jax[cpu]>=0.4.34",
|
45 |
"marker-pdf>=0.2.17",
|
46 |
"pdf2image>=1.17.0",
|
47 |
"pdfplumber>=0.11.4",
|
48 |
"PyPDF2>=3.0.1",
|
49 |
+
"PyStemmer>=2.2.0.3",
|
50 |
"python-dotenv>=1.0.1",
|
51 |
"pymupdf4llm>=0.0.17",
|
52 |
"semchunk>=2.2.0",
|