devve1 committed on
Commit
2c4e7a2
1 Parent(s): 9e9db46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -27
app.py CHANGED
@@ -97,37 +97,21 @@ def transform_query(query: str) -> str:
97
  def query_hybrid_search(query: str, client: QdrantClient, collection_name: str, dense_model: OptimumEncoder, sparse_model: SparseTextEmbedding):
98
  dense_embeddings = dense_model.embed_query(transform_query(query))[0]
99
  sparse_embeddings = list(sparse_model.query_embed(query))[0]
100
-
101
- return client.query_batch_points(
102
  collection_name=collection_name,
103
- requests=[
104
- QueryRequest(
105
- prefetch=Prefetch(query=sparse_embeddings.as_object(), using="text-sparse", limit=10),
106
- with_vector=False,
107
- with_payload=True,
108
- query=FusionQuery(fusion=Fusion.DBSF),
109
- limit=3,
110
- offset=0,
111
- filter=Filter(must_not=[
112
- HasIdCondition(has_id=st.session_state.filter_ids)
113
- ])
114
- ),
115
- QueryRequest(
116
- prefetch=Prefetch(query=dense_embeddings, using="text-dense", limit=10),
117
- with_vector=False,
118
- with_payload=True,
119
- query=FusionQuery(fusion=Fusion.DBSF),
120
- limit=3,
121
- offset=0,
122
- filter=Filter(must_not=[
123
- HasIdCondition(has_id=st.session_state.filter_ids)
124
- ])
125
- )
126
- ]
127
  )
128
 
129
  def main(query: str, client: QdrantClient, collection_name: str, tokenizer: AutoTokenizer, llm: vllm.LLM, dense_model: OptimumEncoder, sparse_model: SparseTextEmbedding):
130
- scored_points = query_hybrid_search(query, client, collection_name, dense_model, sparse_model)[0].points
131
 
132
  docs = [(scored_point.payload['text'], scored_point.payload['metadata']) for scored_point in scored_points]
133
  contents, metadatas = [list(t) for t in zip(*docs)]
 
97
  def query_hybrid_search(query: str, client: QdrantClient, collection_name: str, dense_model: OptimumEncoder, sparse_model: SparseTextEmbedding):
98
  dense_embeddings = dense_model.embed_query(transform_query(query))[0]
99
  sparse_embeddings = list(sparse_model.query_embed(query))[0]
100
+
101
+ return client.query_points(
102
  collection_name=collection_name,
103
+ prefetch=[
104
+ Prefetch(query=sparse_embeddings.as_object(), using="text-sparse", limit=10),
105
+ Prefetch(query=dense_embeddings, using="text-dense", limit=10)
106
+ ],
107
+ query=FusionQuery(fusion=Fusion.DBSF),
108
+ with_vector=False,
109
+ with_payload=True,
110
+ limit=3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
 
113
  def main(query: str, client: QdrantClient, collection_name: str, tokenizer: AutoTokenizer, llm: vllm.LLM, dense_model: OptimumEncoder, sparse_model: SparseTextEmbedding):
114
+ scored_points = query_hybrid_search(query, client, collection_name, dense_model, sparse_model).points
115
 
116
  docs = [(scored_point.payload['text'], scored_point.payload['metadata']) for scored_point in scored_points]
117
  contents, metadatas = [list(t) for t in zip(*docs)]