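"""Chainlit front end for a local RAG pipeline: answers questions by
retrieving context from a persisted Chroma vector store and generating
with a local llama.cpp model wrapped as a custom LangChain LLM."""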
import os
from typing import Any, List, Optional

import chainlit as cl
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from llama_cpp import Llama
from pydantic import Field


class LlamaCppLLM(LLM):
    """Custom LangChain wrapper for llama.cpp.

    LLM is already a pydantic model, so inheriting from BaseModel as well
    is unnecessary and can cause metaclass conflicts.
    """

    client: Any = Field(default=None, exclude=True)
    model_path: str = Field(..., description="Path to the model file")
    n_ctx: int = Field(default=2048, description="Context window size")
    n_threads: int = Field(default=4, description="Number of CPU threads")
    n_gpu_layers: int = Field(default=0, description="Number of GPU layers")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the GGUF model once at construction time
        self.client = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_gpu_layers=self.n_gpu_layers,
        )

    @property
    def _llm_type(self) -> str:
        return "llama.cpp"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if not self.client:
            raise RuntimeError("Model not initialized")
        response = self.client.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            stop=stop or [],  # forward stop sequences instead of silently dropping them
            **kwargs,
        )
        return response["choices"][0]["message"]["content"]


# Initialize the embedding model
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},
)

# Load the existing Chroma vector store
persist_directory = os.path.join(os.path.dirname(__file__), 'mydb')
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
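# Note: 'mydb' must already hold a Chroma collection persisted with this same
# embedding model; this app only reads the store, it does not build it.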

# Initialize the LLM
model_path = os.path.join(os.path.dirname(__file__), "models", "llama-model.gguf")
llm = LlamaCppLLM(
    model_path=model_path,
    n_ctx=2048,
    n_threads=4,
    n_gpu_layers=0,
)

# Create the RAG prompt template
template = """You are a helpful AI assistant. Using only the following context, answer the user's question.
If you cannot find the answer in the context, say "I don't have enough information to answer this question."
Context:
{context}
Question: {question}
Answer: Let me help you with that."""
prompt = ChatPromptTemplate.from_template(template)
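
# The retriever returns Document objects; without a formatting step the prompt
# would receive their raw reprs. A small helper (added here, not in the
# original file) joins the retrieved chunks into one plain-text context block:
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)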


@cl.on_chat_start
async def start():
    # Send initial message
    await cl.Message(
        content="Hi! I'm ready to answer your questions based on the stored documents. What would you like to know?"
    ).send()


@cl.on_message
async def main(message: cl.Message):
    # Send an empty placeholder message to fill in once the answer is ready
    msg = cl.Message(content="")
    await msg.send()

    # Show a step while retrieval and generation run
    async with cl.Step(name="Searching documents..."):
        try:
            # Search the vector store for the 3 most similar chunks
            retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

            # Create the RAG chain
            rag_chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            # The chain is synchronous; wrap .invoke (a Runnable is not directly
            # callable) so it doesn't block the event loop
            response = await cl.make_async(rag_chain.invoke)(message.content)

            # Update the placeholder with the response; Message.update() takes
            # no arguments, so mutate the message first, then call update
            msg.content = response
            await msg.update()

            # Show source documents (retrieved again here purely for display)
            docs = await cl.make_async(retriever.invoke)(message.content)
            elements = []
            for i, doc in enumerate(docs):
                source_name = f"Source {i + 1}"
                elements.append(
                    cl.Text(name=source_name, content=doc.page_content, display="inline")
                )
            if elements:
                msg.elements = elements
                await msg.update()
        except Exception as e:
            import traceback

            error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
            msg.content = error_msg
            await msg.update()


if __name__ == '__main__':
    # Chainlit has no cl.start() entry point; apps are launched from the CLI
    print("Run this app with: chainlit run app.py")
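
# Assumed (unpinned) dependencies implied by the imports above: chainlit,
# langchain, langchain-community, langchain-core, chromadb, llama-cpp-python,
# sentence-transformers, pydantic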