# -*- coding: utf-8 -*-
#!pip install gradio
#!pip install -U sentence-transformers
#!pip install langchain
#!pip install openai
#!pip install -U chromadb

import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain import LLMMathChain, SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.agents import initialize_agent, Tool
import sqlite3
import pandas as pd
import json
import chromadb
import os
cxn = sqlite3.connect('./data/mbr.db')

"""# import models"""

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # truncate long passages to 256 tokens

# The bi-encoder retrieves the top_k documents; a cross-encoder then re-ranks
# that result list to improve quality.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
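
# A minimal sketch of the retrieve-then-rerank idea: the bi-encoder compares a
# query to passages via embedding similarity, while the cross-encoder scores
# each (query, passage) pair jointly. The passages here are made-up examples.
# qry = 'When will I get my card?'
# passages = ['Your card arrives 7-10 days after enrollment.',
#             'The Healthy Options allowance resets monthly.']
# emb_scores = util.cos_sim(bi_encoder.encode(qry), bi_encoder.encode(passages))
# ce_scores = cross_encoder.predict([[qry, p] for p in passages])
# print(emb_scores, ce_scores)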
"""# setup vector db | |
- chromadb | |
- https://docs.trychroma.com/getting-started | |
""" | |
from chromadb.config import Settings | |
chroma_client = chromadb.Client(settings=Settings( | |
chroma_db_impl="duckdb+parquet", | |
persist_directory="./data/mychromadb/" # Optional, defaults to .chromadb/ in the current directory | |
)) | |
#!ls ./data/mychromadb/ | |
#collection = chroma_client.create_collection(name="benefit_collection") | |
collection = chroma_client.get_collection(name="benefit_collection", embedding_function=bi_encoder) | |
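
# A hedged sketch of how the collection could have been populated with
# precomputed bi-encoder embeddings; the document text, metadata fields, and
# ids below are illustrative, not the actual ingestion code.
# docs = ['You can use the Healthy Options card to buy approved groceries.']
# collection.add(
#     documents=docs,
#     embeddings=[bi_encoder.encode(d).tolist() for d in docs],
#     metadatas=[{"source": "benefit_guide", "page": 1, "url": "https://www.humana.com/"}],
#     ids=["doc-0"],
# )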
"""### vector db search examples""" | |
def rtrv(qry,top_k=20): | |
results = collection.query( | |
query_embeddings=[ bi_encoder.encode(qry) ], | |
n_results=top_k, | |
) | |
return results | |
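
# Note: collection.query returns lists nested per query embedding, e.g.
# {'ids': [[...]], 'documents': [[...]], ...}, which is why callers below
# index into element [0].
# rtrv('when will I get my card?', top_k=3)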
def vdb_src(qry, src, top_k=20):
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
        where={"source": src},
    )
    return results
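
# Example of filtering on metadata; the 'source' value here is hypothetical
# and depends on what was stored at ingestion time.
# vdb_src('what groceries can I buy?', src='benefit_guide', top_k=5)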
def vdb_pretty(qry, top_k=10):
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
        include=["metadatas", "documents", "distances", "embeddings"]
    )
    rslt_pd = pd.DataFrame(results).explode(['ids', 'documents', 'metadatas', 'distances', 'embeddings'])
    rslt_fmt = pd.concat([rslt_pd.drop(['metadatas'], axis=1), rslt_pd['metadatas'].apply(pd.Series)], axis=1)
    return rslt_fmt
# qry = 'Why should I choose Medicare Advantage over traditional Medicare?'
# rslt_fmt = vdb_pretty(qry, top_k=10)
# rslt_fmt
# doc_lst = rslt_fmt[['documents']].values.tolist()
# len(doc_lst)
"""# Introduction | |
- example of the kind of question answering that is possible with this tool | |
- assumes we are answering for a member with a Healthy Options Card | |
*When will I get my card?* | |
# semantic search functions | |
""" | |
# choosing to use rerank for this use case as a baseline | |
def rernk(query, collection=collection, top_k=20, top_n=5):
    rtrv_rslts = rtrv(query, top_k=top_k)
    rtrv_ids = rtrv_rslts.get('ids')[0]
    rtrv_docs = rtrv_rslts.get('documents')[0]
    ##### Re-Ranking #####
    cross_inp = [[query, doc] for doc in rtrv_docs]
    cross_scores = cross_encoder.predict(cross_inp)
    # sort results by the cross-encoder scores
    combined = list(zip(rtrv_ids, list(cross_scores)))
    sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
    sorted_ids = [t[0] for t in sorted_tuples[:top_n]]
    predictions = collection.get(ids=sorted_ids, include=["documents", "metadatas"])
    return predictions
    # return cross_scores
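
# Note: unlike collection.query, collection.get returns flat lists, so the
# result has the shape {'ids': [...], 'documents': [...], 'metadatas': [...]};
# that is why get_text_fmt below does not index into [0].
# rernk('when will I get my card?', top_k=20, top_n=5)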
## version w/o re-rank
# def get_text_fmt(qry):
#     prediction_text = []
#     predictions = rtrv(qry, top_k=5)
#     docs = predictions['documents'][0]
#     meta = predictions['metadatas'][0]
#     for i in range(len(docs)):
#         result = Document(page_content=docs[i], metadata=meta[i])
#         prediction_text.append(result)
#     return prediction_text
def get_text_fmt(qry):
    prediction_text = []
    predictions = rernk(qry, collection=collection, top_k=20, top_n=5)
    docs = predictions['documents']
    meta = predictions['metadatas']
    for i in range(len(docs)):
        result = Document(page_content=docs[i], metadata=meta[i])
        prediction_text.append(result)
    return prediction_text

# get_text_fmt('why should I choose a medicare advantage plan over traditional medicare?')
"""# LLM based qa functions""" | |
llm = OpenAI(temperature=0) | |
# default model | |
# model_name: str = "text-davinci-003" | |
# instruction fine-tuned, sometimes referred to as GPT-3.5 | |
template = """You are a friendly AI assistant for the insurance company Humana. | |
Given the following extracted parts of a long document and a question, create a succinct final answer. | |
If you don't know the answer, just say that you don't know. Don't try to make up an answer. | |
If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana. | |
QUESTION: {question} | |
========= | |
{summaries} | |
========= | |
FINAL ANSWER:""" | |
PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"]) | |
chain_qa = load_qa_with_sources_chain(llm=llm, chain_type="stuff", prompt=PROMPT, verbose=False) | |
def get_llm_response(message):
    mydocs = get_text_fmt(message)
    responses = chain_qa({"input_documents": mydocs, "question": message})
    return responses

# rslt = get_llm_response('can I buy shrimp?')
# rslt['output_text']
# for d in rslt['input_documents']:
#     print(d.page_content)
#     print(d.metadata['url'])
"""# Database query""" | |
## setup member database | |
## only do this once | |
# d = {'mbr_fname':['bruce'], | |
# 'mbr_lname':['broussard'], | |
# 'mbr_id':[456] , | |
# 'policy_id':['H1036-236'], | |
# 'accumulated_out_of_pocket':[3800], | |
# 'accumulated_routine_footcare_visits':[6], | |
# 'accumulated_trasportation_trips':[22], | |
# 'accumulated_drug_cost':[7500], | |
# } | |
# df = pd.DataFrame(data=d, columns=['mbr_fname', 'mbr_lname', 'mbr_id', 'policy_id', 'accumulated_out_of_pocket', 'accumulated_routine_footcare_visits', 'accumulated_trasportation_trips','accumulated_drug_cost']) | |
# df.to_sql(name='mbr_details', con=cxn, if_exists='replace') | |
# # sample db query | |
# qry = '''select accumulated_routine_footcare_visits | |
# from mbr_details''' | |
# foot_det = pd.read_sql(qry, cxn) | |
# foot_det.values[0][0] | |
db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True)

def db_qry(qry):
    # NOTE: member id 456 is hardcoded for the demo
    responses = db_chain('my mbr_id is 456; ' + str(qry))
    return responses
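
# e.g. db_qry('how many footcare visits have I used?')
# With return_intermediate_steps=True, the returned dict exposes the generated
# SQL and its raw result under responses['intermediate_steps'], which get_cot
# below surfaces in the chain-of-thought panel.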
"""# Math | |
- default version | |
""" | |
llm_math_chain = LLMMathChain(llm=llm, verbose=True) | |
# llm_math_chain.run('what is the square root of 49?') | |
"""# Greeting""" | |
template = """You are an AI assistant for the insurance company Humana. | |
Your name is Jarvis and you were created on February 13, 2023. | |
Offer polite, friendly greetings and brief small talk. | |
Respond to thanks with, 'Glad to help.' | |
If the question is not about Humana, politely guide the user to ask questions about Humana insurance benefits | |
QUESTION: {question} | |
========= | |
FINAL ANSWER:""" | |
greet_prompt = PromptTemplate(template=template, input_variables=["question"]) | |
greet_llm = LLMChain(prompt=greet_prompt, llm=llm, verbose=True) | |
# greet_llm.run('will it snow in Lousiville tomorrow') | |
# greet_llm.run('Thanks, that was great') | |
"""# MRKL Chain""" | |
tools = [ | |
Tool( | |
name = "Benefit", | |
func=get_llm_response, | |
description='''Useful for confirming what items can be bought with the healthy options card. | |
Useful for when you need to answer questions about healthy options allowance. | |
You should ask targeted questions''' | |
), | |
Tool( | |
name="Calculator", | |
func=llm_math_chain.run, | |
description="useful for when you need to answer questions about math" | |
), | |
Tool( | |
name="Member DB", | |
func=db_qry, | |
description='''useful for when you need to answer questions about member details such their name, id and accumulated use of services. | |
This tool shows how much a benfit has already been consumed. | |
Input should be in the form of a question containing full context''' | |
), | |
Tool( | |
name="Greeting", | |
func=greet_llm.run, | |
description="useful for when you need to respond to greetings, thanks, make small talk or answer questions about yourself" | |
), | |
] | |
mrkl = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=False,
                        return_intermediate_steps=True, max_iterations=5, early_stopping_method="generate")

def mrkl_rspnd(qry):
    response = mrkl({"input": str(qry)})
    return response

# r = mrkl_rspnd("can I buy fish with the card?")
# print(r['output'])
# print(json.dumps(r['intermediate_steps'], indent=2))
# r['intermediate_steps']
# from IPython.core.display import display, HTML

def get_cot(r):
    cot = '<p>'
    try:
        intermedObj = r['intermediate_steps']
        cot += '<b>Input:</b> ' + r['input'] + '<br>'
        for agnt_action, obs in intermedObj:
            al = '<br> '.join(agnt_action.log.split('\n'))
            cot += '<b>AI chain of thought:</b> ' + al + '<br>'
            if type(obs) is dict:
                if obs.get('input_documents') is not None:  # NOTE: this check doesn't work reliably
                    for d in obs['input_documents']:
                        cot += ' <i>- ' + str(d.page_content) + '</i> <a href="' + str(d.metadata['url']) + '">' + str(d.metadata['page']) + '</a> <br>'
                    cot += '<b>Observation:</b> ' + str(obs['output_text']) + '<br><br>'
                elif obs.get('intermediate_steps') is not None:
                    cot += '<b>Query:</b> ' + str(obs.get('intermediate_steps')) + '<br><br>'
                else:
                    pass
            else:
                cot += '<b>Observation:</b> ' + str(obs) + '<br><br>'
    except Exception:
        pass
    cot += '</p>'
    return cot

# cot = get_cot(r)
# display(HTML(cot))
"""# chat example""" | |
def chat(message, history): | |
history = history or [] | |
message = message.lower() | |
response = mrkl_rspnd(message) | |
cot = get_cot(response) | |
history.append((message, response['output'])) | |
return history, history, cot | |
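
# A quick smoke test outside the UI (assumes OPENAI_API_KEY is set in the
# environment; the question is illustrative):
# hist, _, cot_html = chat('when will I get my card?', None)
# print(hist[-1][1])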
css=".gradio-container {background-color: lightgray}" | |
xmpl_list = ["Why should I choose a Medicare Advantage plan over Traditional Medicare?", | |
"What is the difference between a Medicare Advantage HMO plan and a PPO plan?", | |
"What is a low income subsidy plan and do I qualify for one of these plans?", | |
"Are my medications covered on a low income subsidy plan?"] | |
with gr.Blocks(css=css) as demo: | |
history_state = gr.State() | |
response_state = gr.State() | |
gr.Markdown('# Sales QA Bot') | |
with gr.Row(): | |
chatbot = gr.Chatbot() | |
with gr.Accordion(label='Show AI chain of thought: ', open=False,): | |
ai_cot = gr.HTML(show_label=False) | |
with gr.Row(): | |
message = gr.Textbox(label='Input your question here:', | |
placeholder='Why should I choose Medicare Advantage?', | |
lines=1) | |
submit = gr.Button(value='Send', | |
variant='secondary').style(full_width=False) | |
submit.click(chat, | |
inputs=[message, history_state], | |
outputs=[chatbot, history_state, ai_cot]) | |
gr.Examples( | |
examples=xmpl_list, | |
inputs=message | |
) | |
demo.launch() | |