import os
from functools import cached_property
import qdrant_client
import torch
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_community.retrievers import QdrantSparseVectorRetriever
from langchain_community.vectorstores import Qdrant
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from transformers import AutoModelForMaskedLM, AutoTokenizer


class ValidateQdrantClient:
    """Base class for retriever clients to ensure environment variables are set."""

    def __init__(self):
        self.validate_environment_variables()

    def validate_environment_variables(self):
"""Check if the Qdrant environment variables are set."""
required_vars = ["QDRANT_API_KEY", "QDRANT_URL"]
missing_vars = [var for var in required_vars if not os.getenv(var)]
if missing_vars:
raise EnvironmentError(
f"Missing environment variable(s): {', '.join(missing_vars)}"
)


class DenseRetrieverClient(ValidateQdrantClient):
    """Dense retriever client using OpenAI text embeddings and the Qdrant vector database."""

    TEXT_EMBEDDING_MODELS = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ]

    def __init__(
self,
embeddings_model="text-embedding-3-small",
collection_name="practitioners_db",
search_type="similarity",
k=4,
):
super().__init__()
if embeddings_model not in self.TEXT_EMBEDDING_MODELS:
raise ValueError(
f"Invalid embeddings model: {embeddings_model}. Valid options are {', '.join(self.TEXT_EMBEDDING_MODELS)}."
)
self.embeddings_model = embeddings_model
self.collection_name = collection_name
self.search_type = search_type
self.k = k
self.client = qdrant_client.QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
prefer_grpc=True,
)

    def set_qdrant_collection(self, embeddings):
"""Prepare the Qdrant collection for the embeddings model."""
return Qdrant(
client=self.client,
collection_name=self.collection_name,
embeddings=embeddings,
)

    @cached_property
    def qdrant_collection(self):
        """Lazily load the Qdrant collection for the configured embeddings model."""
        return self.set_qdrant_collection(
            OpenAIEmbeddings(model=self.embeddings_model)
        )

def get_dense_retriever(self):
"""Set up retrievers (Qdrant vectorstore as retriever)."""
return self.qdrant_collection.as_retriever(
search_type=self.search_type, search_kwargs={"k": self.k}
)
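
# A minimal usage sketch (commented out to keep this module import-safe; it
# assumes QDRANT_URL and QDRANT_API_KEY are exported and that the default
# "practitioners_db" collection already exists):
#
#   dense_client = DenseRetrieverClient(search_type="similarity", k=4)
#   dense_retriever = dense_client.get_dense_retriever()
#   docs = dense_retriever.invoke("Which clinics offer massage therapy?")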


class SparseRetrieverClient(ValidateQdrantClient):
    """Sparse retriever client using the SPLADE neural retrieval model and the Qdrant vector database."""

    def __init__(
self,
collection_name,
vector_name,
splade_model_id="naver/splade-cocondenser-ensembledistil",
k=15,
):
# Validate Qdrant client
super().__init__()
self.client = qdrant_client.QdrantClient(
url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
) # TODO: prefer_grpc=True is not working
self.model_id = splade_model_id
self.collection_name = collection_name
self.vector_name = vector_name
self.k = k

    @cached_property
    def tokenizer(self):
        """Lazily initialize the SPLADE tokenizer."""
        return AutoTokenizer.from_pretrained(self.model_id)

    @cached_property
    def model(self):
        """Lazily initialize the SPLADE masked-LM model."""
        return AutoModelForMaskedLM.from_pretrained(self.model_id)

def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
"""Encode the input text into a sparse vector."""
tokens = self.tokenizer(
text,
return_tensors="pt",
max_length=512,
padding="max_length",
truncation=True,
)
        with torch.no_grad():
            logits = self.model(**tokens).logits
        # SPLADE pooling: log-saturated ReLU of the MLM logits, masked by the
        # attention mask, then max-pooled over the sequence dimension.
        relu_log = torch.log1p(torch.relu(logits))
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)
        max_val = torch.max(weighted_log, dim=1).values.squeeze()
        # Keep only the non-zero vocabulary indices and their weights.
        indices = torch.nonzero(max_val, as_tuple=False).squeeze().cpu().numpy()
        values = max_val[indices].cpu().numpy()
        return indices.tolist(), values.tolist()

    def get_sparse_retriever(self) -> QdrantSparseVectorRetriever:
"""Return a Qdrant vector sparse retriever."""
return QdrantSparseVectorRetriever(
client=self.client,
collection_name=self.collection_name,
sparse_vector_name=self.vector_name,
sparse_encoder=self.sparse_encoder,
k=self.k,
)
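
# Usage sketch for the sparse path (commented out; the collection and vector
# names below are illustrative placeholders, not defined by this module):
#
#   sparse_client = SparseRetrieverClient(
#       collection_name="practitioners_db_sparse",
#       vector_name="text_sparse_vector",
#   )
#   indices, values = sparse_client.sparse_encoder("pediatric physiotherapy")
#   sparse_retriever = sparse_client.get_sparse_retriever()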


def compression_retriever_setup(
    base_retriever, embeddings_model="text-embedding-3-small", k=20
):
    """Wrap a base retriever in a ContextualCompressionRetriever with an EmbeddingsFilter."""
    embeddings_filter = EmbeddingsFilter(
        embeddings=OpenAIEmbeddings(model=embeddings_model), k=k
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter, base_retriever=base_retriever
    )
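
# Example: wrap a base retriever so that only the k documents most similar to
# the query (by embedding similarity) are passed downstream:
#
#   compressed_retriever = compression_retriever_setup(dense_retriever, k=20)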


def multi_query_retriever_setup(retriever):
    """Configure a multi-query retriever on top of a base retriever."""
    prompt = PromptTemplate(
        input_variables=["question"],
        template="""
        Your task is to generate 3 different grammatically correct versions of the provided text,
        incorporating the user's location preference in each version. Format these versions as
        paragraphs and present them as items in a Markdown-formatted numbered list ("1. "), with
        no additional new lines or spaces between versions. Do not enclose your response in
        quotation marks. Do not modify unfamiliar acronyms, and keep your responses clear and concise.

        **Notes**: The provided text is a user question to Tall Tree Health Centre's AI virtual
        assistant. `Location preference:` is the Tall Tree Health clinic location that the user prefers.

        Text to be modified:
        ```
        {question}
        ```""",
    )
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
return MultiQueryRetriever.from_llm(
retriever=retriever, llm=llm, prompt=prompt, include_original=True
)
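

if __name__ == "__main__":
    # End-to-end sketch combining the pieces above. The sparse collection and
    # vector names are illustrative and must match existing Qdrant collections;
    # the query is a placeholder.
    dense_retriever = DenseRetrieverClient().get_dense_retriever()
    sparse_retriever = SparseRetrieverClient(
        collection_name="practitioners_db_sparse",
        vector_name="text_sparse_vector",
    ).get_sparse_retriever()

    # Compress the dense results, then expand the query into multiple
    # variations before retrieval.
    compressed = compression_retriever_setup(dense_retriever, k=10)
    retriever = multi_query_retriever_setup(compressed)

    docs = retriever.invoke(
        "Which physiotherapists are available? Location preference: Victoria."
    )
    print(f"Retrieved {len(docs)} documents.")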