# Hugging Face Spaces app (page residue captured during export: "Spaces: Sleeping").
import os

import chainlit as cl
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts.chat import (AIMessagePromptTemplate,
                                    ChatPromptTemplate,
                                    HumanMessagePromptTemplate)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# Splitter for chunking documents before embedding; 100-char overlap keeps
# context continuity across chunk boundaries. (Unused in this file as shown —
# presumably used by an offline indexing step; verify against the repo.)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# Embedding model shared by indexing and retrieval. Reads the OpenAI API key
# from the environment (OPENAI_API_KEY) at call time.
embeddings = OpenAIEmbeddings()
@cl.on_chat_start
async def init():
    """Chainlit chat-start hook: build the RAG chain for this session.

    Loads a pre-built FAISS index from the local ``docs.faiss`` directory,
    wraps it in a ``RetrievalQAWithSourcesChain`` backed by a streaming
    GPT-4 Turbo model, and stashes the chain in the per-user session so
    the message handler can reuse it.
    """
    # The index must already exist on disk; embeddings must match the model
    # used when the index was built.
    vector_store = FAISS.load_local("docs.faiss", embeddings)

    chain = RetrievalQAWithSourcesChain.from_chain_type(
        # temperature=0 for deterministic answers; streaming=True so tokens
        # can be forwarded to the UI as they arrive.
        ChatOpenAI(temperature=0, streaming=True, model="gpt-4-1106-preview"),
        chain_type="stuff",
        # Retrieve the 7 nearest chunks per query.
        retriever=vector_store.as_retriever(search_kwargs={"k": 7}),
    )

    # Make the chain available to the per-session message handler.
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    """Chainlit message hook: answer a user question via the session's chain.

    Runs the RAG chain built in ``init`` with a streaming callback handler.
    If the handler managed to stream the final answer token-by-token, the
    streamed message is finalized; otherwise the full answer is sent as a
    single message.
    """
    chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # Force streaming from the first token instead of waiting for the
    # "FINAL ANSWER" prefix to be detected in the LLM output.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])

    if cb.has_streamed_final_answer:
        # The answer was already streamed to the UI; just finalize it.
        await cb.final_stream.update()
    else:
        # Fallback: nothing was streamed, send the complete answer at once.
        await cl.Message(
            content=res["answer"],
        ).send()