AI-RESEARCHER-2024 committed
Commit d7b6100
1 Parent(s): 992aa46

Update app.py

Files changed (1)
    app.py  +25 -67
app.py CHANGED
@@ -1,53 +1,11 @@
  import os
- from typing import Any, List, Mapping, Optional, Dict
  import chainlit as cl
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain.prompts import ChatPromptTemplate
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnablePassthrough
  from langchain_community.vectorstores import Chroma
- from langchain.callbacks.manager import CallbackManagerForLLMRun
- from langchain.llms.base import LLM
  from llama_cpp import Llama
- from pydantic import Field, BaseModel
-
- class LlamaCppLLM(LLM, BaseModel):
-     """Custom LangChain wrapper for llama.cpp"""
-
-     client: Any = Field(default=None, exclude=True)
-     model_path: str = Field(..., description="Path to the model file")
-     n_ctx: int = Field(default=2048, description="Context window size")
-     n_threads: int = Field(default=4, description="Number of CPU threads")
-     n_gpu_layers: int = Field(default=0, description="Number of GPU layers")
-
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-         self.client = Llama(
-             model_path=self.model_path,
-             n_ctx=self.n_ctx,
-             n_threads=self.n_threads,
-             n_gpu_layers=self.n_gpu_layers
-         )
-
-     @property
-     def _llm_type(self) -> str:
-         return "llama.cpp"
-
-     def _call(
-         self,
-         prompt: str,
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> str:
-         if not self.client:
-             raise RuntimeError("Model not initialized")
-
-         response = self.client.create_chat_completion(
-             messages=[{"role": "user", "content": prompt}],
-             **kwargs
-         )
-         return response["choices"][0]["message"]["content"]

  # Initialize the embedding model
  embeddings = HuggingFaceEmbeddings(
@@ -58,15 +16,15 @@ embeddings = HuggingFaceEmbeddings(

  # Load the existing Chroma vector store
  persist_directory = os.path.join(os.path.dirname(__file__), 'mydb')
- vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+ vectorstore = Chroma(
+     persist_directory=persist_directory,
+     embedding_function=embeddings
+ )

- # Initialize the LLM
- model_path = os.path.join(os.path.dirname(__file__), "models", "llama-model.gguf")
- llm = LlamaCppLLM(
-     model_path=model_path,
-     n_ctx=2048,
-     n_threads=4,
-     n_gpu_layers=0
+ # Initialize the Llama model using from_pretrained
+ llm = Llama.from_pretrained(
+     repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
+     filename="Llama-3.2-1B-Instruct-Q8_0.gguf",
  )

  # Create the RAG prompt template
@@ -84,39 +42,39 @@ prompt = ChatPromptTemplate.from_template(template)

  @cl.on_chat_start
  async def start():
-     # Send initial message
      await cl.Message(
          content="Hi! I'm ready to answer your questions based on the stored documents. What would you like to know?"
      ).send()

  @cl.on_message
  async def main(message: cl.Message):
-     # Create a loading message
      msg = cl.Message(content="")
      await msg.send()

-     # Start typing effect
      async with cl.Step(name="Searching documents..."):
          try:
-             # Search the vector store
              retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
+             docs = retriever.get_relevant_documents(message.content)
+             context = "\n\n".join([doc.page_content for doc in docs])
+
+             # Format the prompt
+             final_prompt = prompt.format(context=context, question=message.content)

-             # Create the RAG chain
-             rag_chain = (
-                 {"context": retriever, "question": RunnablePassthrough()}
-                 | prompt
-                 | llm
-                 | StrOutputParser()
+             # Generate response using the Llama model
+             response = llm.create_chat_completion(
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": final_prompt
+                     }
+                 ]
              )
+             assistant_reply = response['choices'][0]['message']['content']

-             # Execute the chain
-             response = await cl.make_async(rag_chain)(message.content)
-
              # Update loading message with response
-             await msg.update(content=response)
-
+             await msg.update(content=assistant_reply)
+
              # Show source documents
-             docs = retriever.get_relevant_documents(message.content)
              elements = []
              for i, doc in enumerate(docs):
                  source_name = f"Source {i+1}"
@@ -133,4 +91,4 @@ async def main(message: cl.Message):
              await msg.update(content=error_msg)

  if __name__ == '__main__':
-     cl.start()
+     cl.run()
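
Note: a minimal way to sanity-check the new model-loading path outside Chainlit is sketched below; it is not part of the commit. It assumes llama-cpp-python is installed together with huggingface_hub, since Llama.from_pretrained downloads the GGUF file from the Hugging Face Hub on first use, and the test prompt is purely illustrative.

    # sanity_check.py -- illustrative sketch, not included in this commit
    from llama_cpp import Llama

    # Same repo_id and filename as in the updated app.py
    llm = Llama.from_pretrained(
        repo_id="bartowski/Llama-3.2-1B-Instruct-GGUF",
        filename="Llama-3.2-1B-Instruct-Q8_0.gguf",
    )

    # OpenAI-style chat completion; the response dict is indexed the same
    # way as in app.py
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Reply with one short sentence."}]
    )
    print(response["choices"][0]["message"]["content"])

The Chainlit app itself is typically launched with: chainlit run app.py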