yrobel-lima committed on
Commit
f2a1c22
1 Parent(s): d612275

Upload update_vector_database.py

Browse files
Files changed (1) hide show
  1. utils/update_vector_database.py +12 -8
utils/update_vector_database.py CHANGED
@@ -100,7 +100,7 @@ class SparseVectorStore:
100
  def __init__(self, documents: list[Document], collection_name: str, vector_name: str, k: int = 4, splade_model_id: str = "naver/splade-cocondenser-ensembledistil"):
101
  self.validate_environment_variables()
102
  self.client = QdrantClient(url=os.getenv(
103
- "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
104
  self.model_id = splade_model_id
105
  self.tokenizer, self.model = self.set_tokenizer_config()
106
  self.collection_name = collection_name
@@ -124,7 +124,6 @@ class SparseVectorStore:
124
  model = AutoModelForMaskedLM.from_pretrained(self.model_id)
125
  return tokenizer, model
126
 
127
- @cache
128
  def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
129
  """This function encodes the input text into a sparse vector. The sparse_encoder is required for the QdrantSparseVectorRetriever.
130
  Adapted from the Qdrant documentation: Computing the Sparse Vector code.
@@ -135,17 +134,22 @@ class SparseVectorStore:
135
  Returns:
136
  tuple[list[int], list[float]]: Indices and values of the sparse vector
137
  """
138
- tokens = self.tokenizer(
139
- text, return_tensors="pt", max_length=512, padding="max_length", truncation=True)
140
- output = self.model(**tokens)
 
 
 
141
  logits, attention_mask = output.logits, tokens.attention_mask
142
- relu_log = torch.log(1 + torch.relu(logits))
 
143
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
 
144
  max_val, _ = torch.max(weighted_log, dim=1)
145
  vec = max_val.squeeze()
146
 
147
- indices = vec.nonzero().numpy().flatten()
148
- values = vec.detach().numpy()[indices]
149
 
150
  return indices.tolist(), values.tolist()
151
 
 
100
  def __init__(self, documents: list[Document], collection_name: str, vector_name: str, k: int = 4, splade_model_id: str = "naver/splade-cocondenser-ensembledistil"):
101
  self.validate_environment_variables()
102
  self.client = QdrantClient(url=os.getenv(
103
+ "QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")) # TODO: prefer_grpc=True is not working
104
  self.model_id = splade_model_id
105
  self.tokenizer, self.model = self.set_tokenizer_config()
106
  self.collection_name = collection_name
 
124
  model = AutoModelForMaskedLM.from_pretrained(self.model_id)
125
  return tokenizer, model
126
 
 
127
  def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
128
  """This function encodes the input text into a sparse vector. The sparse_encoder is required for the QdrantSparseVectorRetriever.
129
  Adapted from the Qdrant documentation: Computing the Sparse Vector code.
 
134
  Returns:
135
  tuple[list[int], list[float]]: Indices and values of the sparse vector
136
  """
137
+ tokens = self.tokenizer(text, return_tensors="pt",
138
+ max_length=512, padding="max_length", truncation=True)
139
+
140
+ with torch.no_grad():
141
+ output = self.model(**tokens)
142
+
143
  logits, attention_mask = output.logits, tokens.attention_mask
144
+
145
+ relu_log = torch.log1p(torch.relu(logits))
146
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
147
+
148
  max_val, _ = torch.max(weighted_log, dim=1)
149
  vec = max_val.squeeze()
150
 
151
+ indices = torch.nonzero(vec, as_tuple=False).squeeze().numpy()
152
+ values = vec[indices].numpy()
153
 
154
  return indices.tolist(), values.tolist()
155