devve1 commited on
Commit
7827d1e
1 Parent(s): a1f5b8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -20,14 +20,7 @@ from fastembed_ext import FastEmbedEmbeddingsLc
20
  from fastembed.sparse.splade_pp import supported_splade_models
21
  from fastembed import SparseTextEmbedding, SparseEmbedding
22
  from unstructured.partition.auto import partition
23
- from qdrant_client.models import (
24
- NamedSparseVector,
25
- NamedVector,
26
- SparseVector,
27
- PointStruct,
28
- SearchRequest,
29
- ScoredPoint,
30
- )
31
 
32
  def make_points(texts: List[str], metadatas: List[dict], dense: List[ndarray], sparse: List[SparseEmbedding])-> List[PointStruct]:
33
  points = []
@@ -186,25 +179,38 @@ def load_models_and_documents():
186
  n_gpu_layers=32
187
  )
188
 
189
- provider = ['CUDAExecutionProvider']
190
-
191
- dense_model = FastEmbedEncoder(
192
  name='mixedbread-ai/mxbai-embed-large-v1',
193
- providers=provider,
194
- cache_dir=os.getenv('HF_HOME')
195
  )
196
 
197
  sparse_model = SparseTextEmbedding(
198
  'Qdrant/all_miniLM_L6_v2_with_attentions',
199
- cache_dir=os.getenv('HF_HOME'),
200
- providers=provider
201
  )
202
 
203
  nltk.download('punkt')
204
  nltk.download('averaged_perceptron_tagger')
205
 
206
- client = QdrantClient(':memory:')
207
  collection_name = 'collection_demo'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  client.create_collection(
210
  collection_name,
 
20
  from fastembed.sparse.splade_pp import supported_splade_models
21
  from fastembed import SparseTextEmbedding, SparseEmbedding
22
  from unstructured.partition.auto import partition
23
+ from pymilvus import MilvusClient
 
 
 
 
 
 
 
24
 
25
  def make_points(texts: List[str], metadatas: List[dict], dense: List[ndarray], sparse: List[SparseEmbedding])-> List[PointStruct]:
26
  points = []
 
179
  n_gpu_layers=32
180
  )
181
 
182
+ dense_model = HuggingfaceEncoder(
 
 
183
  name='mixedbread-ai/mxbai-embed-large-v1',
184
+ device='cuda'
 
185
  )
186
 
187
  sparse_model = SparseTextEmbedding(
188
  'Qdrant/all_miniLM_L6_v2_with_attentions',
189
+ cache_dir=os.getenv('HF_HOME')
 
190
  )
191
 
192
  nltk.download('punkt')
193
  nltk.download('averaged_perceptron_tagger')
194
 
195
+ client = MilvusClient('https://' + os.getenv('SPACE_HOST') + ':' + str(19530))
196
  collection_name = 'collection_demo'
197
+
198
+ fields = [
199
+ FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
200
+ FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
201
+ FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
202
+ FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024)
203
+ ]
204
+
205
+ schema = CollectionSchema(fields, "")
206
+ col = Collection(collection_name, schema)
207
+
208
+ sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
209
+ dense_index = {"index_type": "FLAT", "metric_type": "COSINE"}
210
+ col.create_index("sparse_vector", sparse_index)
211
+ col.create_index("dense_vector", dense_index)
212
+
213
+ client.insert(collection_name)
214
 
215
  client.create_collection(
216
  collection_name,