Spaces:
Sleeping
Sleeping
import os | |
from typing import Any | |
from huggingface_hub import InferenceClient | |
from rag_demo.rag.base.query import Query | |
from rag_demo.rag.base.template_factory import RAGStep | |
from rag_demo.rag.prompt_templates import QueryExpansionTemplate | |
class QueryExpansion(RAGStep):
    """RAG step that asks a hosted chat model for alternative phrasings of a query."""

    def generate(self, query: Query, expand_to_n: int) -> Any:
        """Expand ``query`` into a list of at most ``expand_to_n`` queries.

        The original ``query`` is always the first element of the result. The
        remaining elements are built with ``query.replace_content`` from the
        non-empty pieces of the model's answer, split on the template's
        separator.

        Args:
            query: The query to expand; its ``content`` is sent to the model.
            expand_to_n: Total number of queries wanted, the original
                included. Values of 1 or less yield ``[query]`` without
                contacting the model.

        Returns:
            A list starting with ``query`` and holding no more than
            ``max(expand_to_n, 1)`` entries.
        """
        # The original query counts towards the total, so only n - 1 new
        # variants are needed. With nothing to generate, skip the remote call.
        extra_count = expand_to_n - 1
        if extra_count <= 0:
            return [query]

        # The access token is read from the HF_API_TOKEN environment variable.
        client = InferenceClient(
            model="Qwen/Qwen2.5-72B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )

        template = QueryExpansionTemplate()
        separator = template.separator
        prompt = template.create_template(extra_count)

        # Format with the same count the template was created for, so the
        # model is not asked for one variant too many.
        user_message = prompt.template.format(
            question=query.content,
            expand_to_n=extra_count,
            separator=separator,
        )
        response = client.chat_completion(
            messages=[{"role": "user", "content": user_message}],
        )

        # The message content may be None; treat that as "no expansions".
        answer = response.choices[0].message.content or ""

        expanded = [query]
        for piece in answer.split(separator):
            text = piece.strip()
            if text:
                expanded.append(query.replace_content(text))

        # The model may return more variants than requested; keep the total
        # within the size the caller asked for.
        return expanded[:expand_to_n]