# from typing import List
# from chainlit.types import AskFileResponse
from aimakerspace.text_utils import CharacterTextSplitter, PDFFileLoader
from aimakerspace.openai_utils.prompts import (
UserRolePrompt,
SystemRolePrompt,
)
# from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
# from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl
# import asyncio
from operator import itemgetter
import nest_asyncio
nest_asyncio.apply()
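# nest_asyncio patches asyncio to allow nested event loops, so the synchronous setup code
# below can run even when an event loop (e.g. Chainlit's) is already running.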
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
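# Load both source PDFs; PyMuPDFLoader returns one Document per page, with page-level metadata.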
filepath_NIST = "data/NIST.AI.600-1.pdf"
filepath_Blueprint = "data/Blueprint-for-an-AI-Bill-of-Rights.pdf"
documents_NIST = PyMuPDFLoader(filepath_NIST).load()
documents_Blueprint = PyMuPDFLoader(filepath_Blueprint).load()
documents = documents_NIST + documents_Blueprint
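# Split the combined documents into 500-character chunks with 50 characters of overlap for retrieval.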
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50
)
rag_documents = text_splitter.split_documents(documents)
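# Embed each chunk with OpenAI's text-embedding-3-small model and index the chunks in an
# in-memory Qdrant collection (rebuilt each time the app starts).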
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Qdrant.from_documents(
documents=rag_documents,
embedding=embeddings,
location=":memory:",
collection_name="Implications of AI"
)
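# Expose the vector store as a retriever; by default this performs similarity search over the chunks.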
retriever = vectorstore.as_retriever()
RAG_PROMPT = """\
Given a provided context and question, you must answer the question based only on the provided context.
If you cannot answer the question based on the context, you must say "I don't know".
Context: {context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
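# For reference (hypothetical placeholder values), the formatted messages can be previewed with:
# prompt.format_messages(context="<retrieved chunks>", question="<user question>")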
# RAG_PROMPT_TEMPLATE = """ \
# Use the provided context to answer the user's query.
# You may not answer the user's query unless there is specific context in the following text.
# If you do not know the answer, or cannot answer, please respond with "I don't know".
# """
# rag_prompt = SystemRolePrompt(RAG_PROMPT_TEMPLATE)
# USER_PROMPT_TEMPLATE = """ \
# Context:
# {context}
# User Query:
# {user_query}
# """
# user_prompt = UserRolePrompt(USER_PROMPT_TEMPLATE)
# class RetrievalAugmentedQAPipeline:
# def __init__(self, llm: ChatOpenAI(), vector_db_retriever: VectorDatabase) -> None:
# self.llm = llm
# self.vector_db_retriever = vector_db_retriever
# async def arun_pipeline(self, user_query: str):
# context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
# context_prompt = ""
# for context in context_list:
# context_prompt += context[0] + "\n"
# formatted_system_prompt = rag_prompt.create_message()
# formatted_user_prompt = user_prompt.create_message(user_query=user_query, context=context_prompt)
# async def generate_response():
# async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
# yield chunk
# return {"response": generate_response(), "context": context_list}
# ------------------------------------------------------------
@cl.on_chat_start # marks a function that will be executed at the start of a user session
async def start_chat():
# settings = {
# "model": "gpt-3.5-turbo",
# "temperature": 0,
# "max_tokens": 500,
# "top_p": 1,
# "frequency_penalty": 0,
# "presence_penalty": 0,
# }
# # Create a dict vector store
# vector_db = VectorDatabase()
# vector_db = await vector_db.abuild_from_list(rag_documents)
# vector_db = await vector_db.abuild_from_list(split_documents_NIST)
# vector_db = await vector_db.abuild_from_list(split_documents_Blueprint)
# # chat_openai = ChatOpenAI()
# llm = ChatOpenAI(model="gpt-4o-mini", tags=["base_llm"])
# # Create a chain
# retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
# vector_db_retriever=vector_db,
# llm=llm
# )
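    # Generation model; temperature=0 keeps responses deterministic and grounded in the retrieved context.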
    primary_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
rag_chain = (
        # INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}
        # "question" : populated by getting the value of the "question" key
        # "context" : populated by piping the "question" value into the retriever
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        # RunnablePassthrough.assign carries the retrieved "context" forward unchanged
        # so it can be returned alongside the generated response
        | RunnablePassthrough.assign(context=itemgetter("context"))
        # "response" : the "context" and "question" values format the prompt, which is then piped
        # into the LLM and stored under the "response" key
        # "context" : populated by getting the value of the "context" key from the previous step
        | {"response": prompt | primary_llm, "context": itemgetter("context")}
)
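    # Ad-hoc sanity check (hypothetical question; requires OPENAI_API_KEY to be set):
    # rag_chain.invoke({"question": "What does the Blueprint say about algorithmic discrimination?"})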
# cl.user_session.set("settings", settings)
cl.user_session.set("chain", rag_chain)
@cl.on_message # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")

    msg = cl.Message(content="")

    # The chain is an LCEL runnable, so it is streamed with astream(); token deltas arrive
    # in partial dicts under the "response" key as AIMessageChunk objects.
    async for chunk in chain.astream({"question": message.content}):
        if "response" in chunk:
            await msg.stream_token(chunk["response"].content)

    await msg.send()