import os
from typing import Any, List, Optional

import chainlit as cl
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Chroma
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from llama_cpp import Llama
from pydantic import Field
class LlamaCppLLM(LLM):
    """Custom LangChain wrapper around a llama.cpp model."""

    client: Any = Field(default=None, exclude=True)
    model_path: str = Field(..., description="Path to the GGUF model file")
    n_ctx: int = Field(default=2048, description="Context window size")
    n_threads: int = Field(default=4, description="Number of CPU threads")
    n_gpu_layers: int = Field(default=0, description="Number of layers offloaded to the GPU")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the model once at construction time; it is reused for every call.
        self.client = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_gpu_layers=self.n_gpu_layers,
        )

    @property
    def _llm_type(self) -> str:
        return "llama.cpp"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if not self.client:
            raise RuntimeError("Model not initialized")
        # llama.cpp's chat completion API mirrors the OpenAI message format.
        response = self.client.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            **kwargs,
        )
        return response["choices"][0]["message"]["content"]
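
# Illustrative standalone check of the wrapper (not part of the app flow, so it
# stays commented out to avoid loading the model a second time at import).
# The model path below is a placeholder; any valid GGUF file would do.
#
# smoke_test_llm = LlamaCppLLM(model_path="models/llama-model.gguf", n_threads=4)
# print(smoke_test_llm.invoke("Reply with the single word: ready"))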
# Initialize the embedding model
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
# Load the existing Chroma vector store persisted in ./mydb
persist_directory = os.path.join(os.path.dirname(__file__), "mydb")
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
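
# The `mydb` directory is expected to hold a Chroma collection built ahead of
# time with the same embedding model. A minimal ingestion sketch is shown for
# reference; the loader path and splitter settings are assumptions, and the
# block is commented out so it never runs here.
#
# from langchain_community.document_loaders import TextLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
#
# raw_docs = TextLoader("docs/source.txt").load()
# chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(raw_docs)
# Chroma.from_documents(chunks, embeddings, persist_directory=persist_directory)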
# Initialize the LLM from the local GGUF file
model_path = os.path.join(os.path.dirname(__file__), "models", "llama-model.gguf")
llm = LlamaCppLLM(
    model_path=model_path,
    n_ctx=2048,
    n_threads=4,
    n_gpu_layers=0,
)
# Create the RAG prompt template
template = """You are a helpful AI assistant. Using only the following context, answer the user's question.
If you cannot find the answer in the context, say "I don't have enough information to answer this question."

Context:
{context}

Question: {question}

Answer: Let me help you with that."""

prompt = ChatPromptTemplate.from_template(template)
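
# For reference, the template renders into a single human message once the
# retrieved context and the question are substituted. Illustrative values only:
#
# rendered = prompt.format(
#     context="Chainlit is an open-source Python framework for building chat UIs.",
#     question="What is Chainlit?",
# )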
@cl.on_chat_start
async def start():
    # Send initial message when a new chat session begins
    await cl.Message(
        content="Hi! I'm ready to answer your questions based on the stored documents. What would you like to know?"
    ).send()
@cl.on_message
async def main(message: cl.Message):
    # Create a placeholder message that will be filled in with the answer
    msg = cl.Message(content="")
    await msg.send()

    # Show progress while retrieval and generation run
    async with cl.Step(name="Searching documents..."):
        try:
            # Search the vector store for the 3 most relevant chunks
            retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

            # Create the RAG chain: retrieve context, fill the prompt, generate, parse to text
            rag_chain = (
                {"context": retriever, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            # Execute the chain in a worker thread so the event loop is not blocked
            response = await cl.make_async(rag_chain.invoke)(message.content)

            # Update the placeholder message with the response
            msg.content = response
            await msg.update()

            # Show the retrieved source documents alongside the answer
            docs = await cl.make_async(retriever.get_relevant_documents)(message.content)
            elements = []
            for i, doc in enumerate(docs):
                source_name = f"Source {i + 1}"
                elements.append(
                    cl.Text(name=source_name, content=doc.page_content, display="inline")
                )
            if elements:
                msg.elements = elements
                await msg.update()
        except Exception as e:
            import traceback

            error_msg = f"An error occurred: {e}\n{traceback.format_exc()}"
            msg.content = error_msg
            await msg.update()
# Chainlit apps are not started from __main__; launch with the Chainlit CLI:
#   chainlit run <this_file>.py