import os
from operator import itemgetter
from typing import TypedDict

from dotenv import load_dotenv
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from app.config import PG_COLLECTION_NAME

load_dotenv()
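
# Vector store backed by Postgres/pgvector: documents in the PG_COLLECTION_NAME
# collection are embedded and queried with OpenAIEmbeddings. POSTGRES_URL is
# read from the environment loaded by load_dotenv() above.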
vector_store = PGVector(
    collection_name=PG_COLLECTION_NAME,
    connection_string=os.getenv("POSTGRES_URL"),
    embedding_function=OpenAIEmbeddings()
)
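
# Prompt that stuffs the retrieved documents into {context} and appends the
# user's {question}.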
template = """
Answer the question given the following context:
{context}

Question: {question}
"""

ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
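
# Chat model for answer generation; streaming=True enables token-by-token
# output when the chain is streamed.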
llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview", streaming=True)
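

# Input schema for the chain; with_types() below attaches it so callers know
# the chain expects {"question": "..."}.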
class RagInput(TypedDict):
    question: str
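

# Two-stage LCEL pipeline:
#   1. Retrieve documents for the question (kept under "context") while passing
#      the question through unchanged.
#   2. Run the prompt + LLM over both to produce "answer", and return the
#      retrieved documents alongside it as "docs".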
final_chain = (
    RunnableParallel(
        context=(itemgetter("question") | vector_store.as_retriever()),
        question=itemgetter("question")
    ) |
    RunnableParallel(
        answer=(ANSWER_PROMPT | llm),
        docs=itemgetter("context")
    )
).with_types(input_type=RagInput)
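

# Usage sketch (assumption: the module is run directly with POSTGRES_URL and
# OPENAI_API_KEY set; the question below is only a placeholder).
if __name__ == "__main__":
    result = final_chain.invoke({"question": "What does the ingested corpus cover?"})
    print(result["answer"].content)  # generated answer (AIMessage)
    for doc in result["docs"]:       # retrieved source documents
        print(doc.metadata)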