Spaces:
Starting
on
T4
Starting
on
T4
import os | |
import re | |
import lz4 | |
import json | |
import time | |
import uuid | |
import torch | |
import spacy | |
import base64 | |
import asyncio | |
import msgpack | |
import sqlite3 | |
import outlines | |
import validators | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from vllm import LLM | |
from numpy import ndarray | |
from datetime import datetime | |
from typing import List, Dict | |
from pydantic import BaseModel | |
from dense_embed import embed_text | |
from ppt_chunker import ppt_chunker | |
from outlines import models, generate | |
from qdrant_client import QdrantClient | |
from unstructured.cleaners.core import clean | |
from streamlit_navigation_bar import st_navbar | |
from vllm.sampling_params import SamplingParams | |
from fastembed import SparseTextEmbedding, SparseEmbedding | |
from unstructured.nlp.tokenize import download_nltk_packages | |
from scipy.sparse import csr_matrix, save_npz, load_npz, vstack | |
from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine | |
from sqlalchemy import ( | |
create_engine, | |
MetaData, | |
Table, | |
Column, | |
String, | |
Integer, | |
select, | |
column, | |
) | |
from qdrant_client.models import ( | |
NamedSparseVector, | |
NamedVector, | |
SparseVector, | |
PointStruct, | |
ScoredPoint, | |
Prefetch, | |
FusionQuery, | |
Fusion, | |
SearchRequest, | |
Modifier, | |
OptimizersConfigDiff, | |
HnswConfigDiff, | |
Distance, | |
VectorParams, | |
SparseVectorParams, | |
SparseIndexParams, | |
Batch, | |
PointIdsList, | |
QueryRequest, | |
Filter, | |
HasIdCondition, | |
Datatype, | |
BinaryQuantization, | |
BinaryQuantizationConfig, | |
MultiVectorConfig | |
) | |
# Module-level "documents only" flag consulted by main(). Streamlit re-executes the script on
# every rerun, so this statement resets it to False each time; the per-session toggle value is
# kept in st.session_state.documents_only.
global_state_documents_only: bool = False
class Question(BaseModel):
    """Structured-output schema for the LLM: a single free-text ``answer`` field.

    main() builds an Outlines JSON generator from this model and reads ``.answer`` from each result.
    """

    # The generated answer text.
    answer: str
icon_to_types = { | |
'ppt':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAC4ElEQVR4nO2YS2gTQRzGF/VgPVuPpsXSQ3eshdbXqVDMEtlN1UN8HqTZYouCVJTsiIfgTcWbD0o0mp3GKq1oC55E2AbbKn3dakISCVjwUNtLlZ1a24xs0oqG2uxmJ2si+8F32NnL95v/fx67DGPLli1b/5UUkauJtDsvDHm5XkXkphSvc17xcksqgksYSXMYwSksS71Yhudx6MouphSk+Ju3RLzcGUV0jg6JHFnPGMH1LcMRVZZOkz7P5n8TXuRcisjF/xY8LwDKWpVhDIcgZ1nwiXPCNkXkHuYLrhcAr4EgKUD6LlUUNfxIp3OHIjon9YY3AoCzHl8IXq0sWvgh0RkzEr4AAKK1FHWIbNsYm/lCAfBqJchj/1ZqAEZ6nhIAURHsNhQy0OToCjQ5vj1ochAzDu6vJgOte00DYATJYsh32AiA6fC/Q9AAUJEU1X1O0Aq/ZhoAOAPhO1nWABhJw2UNoMowjdG16oIBcrWy/IPMx6Pk9eV2iyoACQ75OkwDxF4+JdEXT8jMaCTznF5ZJoNtR60BkKWwaYDgwZpfY/FXzzNj0/3IGgAEJ6gCjN29ma3KwDOLAKRZ0wBhrpGglnoy2HaMLHyeyYy98XVaAqAiuGgaIFcf+ns2XMRJAdD0ommA8Xu3yNidG+Td7etk4Gxr3m2ULgA7S3UN6DFlgIlyB+gpcwBQ+EFWqGmFTwggnTqyp6p8AXjwNm/4kgZwg+N6Ab7SCv9oXxUdAB5ME4/eD5rGnRdpQGjhIy21dADcdS0MDSUFEKC8qxAdi/c+Q0ufPAcqEjw7biHA+1Szg95vFU1xV0NlgmdjRZ95no3GhNrtVMP/CQHGijnzcVdDcX4t5rRTdzF6PkW7bTbSR373Ia3cpsPzYJrabmNU2h6ddNedSvBgWDvyDcx2WjthEwJ7gviZTUwpKOaur9YuXQmBDSd5djLJg7kkD76v+ot2Jc68E0CHrruNLVu2bDGlrJ8c/urSuEn7XgAAAABJRU5ErkJggg==', | |
'Powerpoint'), | |
'pptx':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAC4ElEQVR4nO2YS2gTQRzGF/VgPVuPpsXSQ3eshdbXqVDMEtlN1UN8HqTZYouCVJTsiIfgTcWbD0o0mp3GKq1oC55E2AbbKn3dakISCVjwUNtLlZ1a24xs0oqG2uxmJ2si+8F32NnL95v/fx67DGPLli1b/5UUkauJtDsvDHm5XkXkphSvc17xcksqgksYSXMYwSksS71Yhudx6MouphSk+Ju3RLzcGUV0jg6JHFnPGMH1LcMRVZZOkz7P5n8TXuRcisjF/xY8LwDKWpVhDIcgZ1nwiXPCNkXkHuYLrhcAr4EgKUD6LlUUNfxIp3OHIjon9YY3AoCzHl8IXq0sWvgh0RkzEr4AAKK1FHWIbNsYm/lCAfBqJchj/1ZqAEZ6nhIAURHsNhQy0OToCjQ5vj1ochAzDu6vJgOte00DYATJYsh32AiA6fC/Q9AAUJEU1X1O0Aq/ZhoAOAPhO1nWABhJw2UNoMowjdG16oIBcrWy/IPMx6Pk9eV2iyoACQ75OkwDxF4+JdEXT8jMaCTznF5ZJoNtR60BkKWwaYDgwZpfY/FXzzNj0/3IGgAEJ6gCjN29ma3KwDOLAKRZ0wBhrpGglnoy2HaMLHyeyYy98XVaAqAiuGgaIFcf+ns2XMRJAdD0ommA8Xu3yNidG+Td7etk4Gxr3m2ULgA7S3UN6DFlgIlyB+gpcwBQ+EFWqGmFTwggnTqyp6p8AXjwNm/4kgZwg+N6Ab7SCv9oXxUdAB5ME4/eD5rGnRdpQGjhIy21dADcdS0MDSUFEKC8qxAdi/c+Q0ufPAcqEjw7biHA+1Szg95vFU1xV0NlgmdjRZ95no3GhNrtVMP/CQHGijnzcVdDcX4t5rRTdzF6PkW7bTbSR373Ia3cpsPzYJrabmNU2h6ddNedSvBgWDvyDcx2WjthEwJ7gviZTUwpKOaur9YuXQmBDSd5djLJg7kkD76v+ot2Jc68E0CHrruNLVu2bDGlrJ8c/urSuEn7XgAAAABJRU5ErkJggg==', | |
'Powerpoint'), | |
'txt':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEA0lEQVR4nO2bW08TQRTH+8QDH8DPpF+EGLt+AENI44OC7u6L+mAxYQ0xLdZoDE8kCpr4oEYUGu4kEOgFRWm5FLbd3Zk9Ztpus132WmZ3WtqTnGQzLbP7/82ZM2dmaSw2sIENbGAdWHIkOSxw0pgQlzYFTlJFTgIGXhQ4aSKRyAzFojT+9vQNMT71i5HoS04gRCpe4KZWyY2TiUzrIdTj3cCOzooA1ZO64/ND23azm79DfD+73IqEyMSLcWmF3HQykYG5ldOOAVxVvHayB7pcbt2fifgPu0pHAGiJJ59FAoB3EN8JAJriIwHAu4gPCoC2+NAB8B7i/QDA1ePQxIcKgPch3gsAkv8BMQIhDPGhAeB9incDgOQStJlWDSyegHMTHwqA5Ehy2ChyXtx/A3OrZ47ivSIAGxCUi8DitcpB/U915dxRfCgABE4aM4ocL/Hvv+TrN39+L+Wc8FS5I/E6VhvwdB1AvYgQQFzaDFqKLryep5btUbUMtoaUyCJA9SucjPznzDxUj3boLnVKBUDHpgiQo4sAkWFt35bwyNwn5iK+qwCgMNZ5reYqvmsAoJCKHD/OHABiKJ4pgK3vi5AcTTE/DPHtcSnPcy9vUQMwOZpmLyqgC3elAhUA6KzY6vTBj1pPuO+pInoAMOZ2XwJApsTWdwCQJdvTBPD4Zw0mFrsYALJZ6vwAyFea5ayN5Sq4/p2HizWoqDocynobBHJ9cK7X+wjSH3UAyGGd9wMg5/OBiXhiC3mt1Uauif2V9cD9UQOAXIqcoFPAMLvPpjcaW18VAzzLKvAkq4DS1JreVgP3RwUA8qjwaAIgvlZqKF4vYVgvN6/LzqMaOgDwKG9pAzCPuhENT7MKWwDYpbanDcA87635gAkA7LGxCQPAp0IXAVA9dnW0AZBwJ2FvGJkOZFowBYBctrS0AWw0E9+aKQmSa2YAkMd+niaA1FZjGVRQY9TNCfHVJqNlEDwOM2gCIBUfsY+5y4VQwaHAiRwAtpzkBAVAKjWnau3rb1QP+3FLKbx9jOHbHxS4P+oAsM0x1rXfDYLHGV5fAMDmWkAugXaau/SfGdcWAG4Tf0Re0QBgDaB22h8A1Law3wcgLyvJK6vm1JgcnYn0QJPZoajWOrdvjLzhu0tbPQVB4KS8eEe6GQiAdsWXFl5uTbCduO/Q9mtis8NeEB8qAL0HxIcKAHpAfM8BwJTF9xQAHIL4SACk+VmYEWfbxARtM4tPPXoLaf5dm4irtIUOQLSJiKBtdg9Mu20AgJaJlpEkIWwNbT9tJOxJuFpDlnZbVyZBHFLCiyoJFkmHhfWdrhe/t7xE/yczAidNsN64BPb41Dg1AIlEZqgJoR4JXe5sfjY3sIENLHZd7D/x1k4dCUv1GwAAAABJRU5ErkJggg==', | |
'Txt'), | |
'doc':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAACOElEQVR4nO2ZQWgTQRSGh+DFq70VD2lvQuKpV714cGtLu3ioiiJeWr0oaSnSm0KhKnabJiKYYhHRQxORBmLqoYVKqaGGWmizAYsaCBhDpdCkqatpap68l3QO4kHJsBnN/PDDzts34f+YN3sJY0pKSkr/pFz31q64/Ym822+CPU7kXX7zsjAAe8ObZJfPzAkEsDe8u2oFsC91An41Qv/rCGmGh2nGDtMMqMWOznFoHpipC8BOreE5RMc4uIdNoWZ/AAAirQCGG/kEjg8GARVZ+shrR3ofwb5ODD3j9Wg8RTXcIw3AwS4ffN/dgy85i9f6fLMc4MaTGK9vbH2lXtwjDQDTDHidzFDYlksPaf14LgnFUgVqdiVNNXyHwl6pRohpBtwJxSncmZEXtE5lc7D07jNMx95DwdqFAx1eOHc7Sj3YKx1A180whTOeL0Pz+QA9e6ffwvXJBXpuu/qU1ijslQ6gqec+lMsAi8kMnQKqZyQCxwan6NnzYJ5GB3uw99f9Rz1vhJr9LQA6md6Eb8USBGZWKfThCxN0WfEuhGMfwCqWqOd3e6UACFSDb1tF+LRZ4PX4ehZKez/oHfZIC3Dx7kv+6QwtrPO6L7zC69gjLUBr9TOJ6g+84vWztyp3AoU90gKwGqwAPI13AqMFUeEdp7x1ADhpXBMB4WgfKx/qnLRaT4e2RJqJUosehHqYKYCq1AnoaoQafIScejBvd3hn95S4PziceqgPf9DO8E492CsMQElJSYmJ0E+635eFCoKREwAAAABJRU5ErkJggg==', | |
'Microsoft Word'), | |
'docx':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAACOElEQVR4nO2ZQWgTQRSGh+DFq70VD2lvQuKpV714cGtLu3ioiiJeWr0oaSnSm0KhKnabJiKYYhHRQxORBmLqoYVKqaGGWmizAYsaCBhDpdCkqatpap68l3QO4kHJsBnN/PDDzts34f+YN3sJY0pKSkr/pFz31q64/Ym822+CPU7kXX7zsjAAe8ObZJfPzAkEsDe8u2oFsC91An41Qv/rCGmGh2nGDtMMqMWOznFoHpipC8BOreE5RMc4uIdNoWZ/AAAirQCGG/kEjg8GARVZ+shrR3ofwb5ODD3j9Wg8RTXcIw3AwS4ffN/dgy85i9f6fLMc4MaTGK9vbH2lXtwjDQDTDHidzFDYlksPaf14LgnFUgVqdiVNNXyHwl6pRohpBtwJxSncmZEXtE5lc7D07jNMx95DwdqFAx1eOHc7Sj3YKx1A180whTOeL0Pz+QA9e6ffwvXJBXpuu/qU1ijslQ6gqec+lMsAi8kMnQKqZyQCxwan6NnzYJ5GB3uw99f9Rz1vhJr9LQA6md6Eb8USBGZWKfThCxN0WfEuhGMfwCqWqOd3e6UACFSDb1tF+LRZ4PX4ehZKez/oHfZIC3Dx7kv+6QwtrPO6L7zC69gjLUBr9TOJ6g+84vWztyp3AoU90gKwGqwAPI13AqMFUeEdp7x1ADhpXBMB4WgfKx/qnLRaT4e2RJqJUosehHqYKYCq1AnoaoQafIScejBvd3hn95S4PziceqgPf9DO8E492CsMQElJSYmJ0E+635eFCoKREwAAAABJRU5ErkJggg==', | |
'Microsoft Word'), | |
'xslx':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAADGUlEQVR4nO2Za0hTYRzG9z0vu59JpK3vedvmTHOmM1OnE0wqbZWJjjTNS2nz7jZvFVmJUTOCchF0+RaBiAThQkgQKSKKbjohIqMIKlG3J3aiVXicw3PaMToPPF/eT7/fe/7vew4cHo8LFy5c/qlssht1cnvpjNxuxGqqcJhX1dhRs1M5as6kLbBx0OhcLTwdAYVHwtE+TVuADjxdAYXDDE5Azj0BMzdCtLLhTCH8aXifAREXiyC/WhrYQ0z06HVET+4M0a0H0ZUDojMb0g4dpNYsSC1ZkJgzIWnPgKRtB8Rt6RC3pkPcsh2i5jSImrQQNqZC2JACUYMWMose4f2GPyT+vkC33kkXXmhKgeD4NkhMqVB1FyHJVoHkKzVk6V4CPD8EGIEX1CdDUK+BqusAEs+XB1Cgizl4fl3SD4H+sgAKdDIHzz/GgoCUAn69NRsvZ2fgSeuQzQsffaoQcwvz+DY/h5gTBUvg+Ue3siBgpd75/dfbSYHZL58QYckhd/7W5Ai5Zh26RAkfWpvIgoBl+bEZef6QBO4YvgxNXwlcbjdevHdC1phGCR9akxB4AYmPmU84V4z5xQV8/PoZY68fwe12Q2+rXhY+tHoLCwLtvg/sBcdt/MyNiWGf8CFV8SwItPm+bezjd70CNyeGfcKHHImHOCcS0l0KyAxqsvzmVFrlrRRf8DkDNeTYjE89wYNXk6RE3kDtsvAhlWoWBFqp4cOa0/Hs3RsSWm+rgrbPSMpMfXiLMJOWEj64Mo4FgRbql9Tpe9dI+OGnY96xufP4PrnWOzJICR9coQq8gIgCPqH3IHn7uNwuaM4We2de2VOABdci2cST+5bABx1WsiDQtPLnga+Z/x0+qJwFASGD8EFlChYEGpiDX3coNvACAlOKkyn4oJJoFgTqNZn8Oo2TCXjBzkiIc6NA7P4lQLc8uhHpNsPfivVRkObHgChQrR0BSV4M/KnU0/xYELuVkO2NWzsCxB4l/K5n5xmElzEhwCSMjBP4H58AYVA72YInDGr6PzhkhepMNiQIg3paZojPoC3AhQsXLjwm8x3YSSmFlSW/AQAAAABJRU5ErkJggg==', | |
'Excel') | |
} | |
def transform_query(query: str) -> str:
    """Prefix *query* with the retrieval instruction used for dense query embeddings.

    Only search queries get this prefix; documents are embedded without it.
    """
    template = 'Represent this sentence for searching relevant passages: {}'
    return template.format(query)
def query_hybrid_search(query: str, client: QdrantClient, collection_name: str, dense_model: AsyncEngineArray, sparse_model: SparseTextEmbedding, prefetch_limit: int = 25, limit: int = 10, score_threshold: float = 0.95, excluded_ids=None):
    """Run a dense + sparse hybrid search and fuse both rankings with Reciprocal Rank Fusion.

    Args:
        query: Raw user query. The dense branch embeds it with the retrieval instruction prefix
            (transform_query); the sparse branch embeds it unchanged.
        client: Qdrant client holding *collection_name*.
        collection_name: Collection with the named vectors 'text-dense' and 'text-sparse'.
        dense_model: Infinity engine array; its first engine embeds the query.
        sparse_model: fastembed sparse model, queried through ``query_embed``.
        prefetch_limit: Candidates fetched from each branch before fusion.
        limit: Maximum number of fused points returned.
        score_threshold: Minimum fused score for a point to be returned.
        excluded_ids: Optional point ids to leave out of the search (for example the points of
            disabled documents). ``None`` or empty means no filtering.

    Returns:
        The ``client.query_points`` response (payloads included, vectors omitted); callers read
        its ``.points``.
    """
    dense_embeddings, _tokens_count = asyncio.run(embed_text(dense_model[0], transform_query(query)))
    # query_embed yields one embedding per query; take it without materialising a list.
    sparse_embedding = next(iter(sparse_model.query_embed(query)))
    # Applied to both prefetches (where candidates are chosen) and to the fused query.
    id_filter = Filter(must_not=[HasIdCondition(has_id=list(excluded_ids))]) if excluded_ids else None
    return client.query_points(
        collection_name=collection_name,
        prefetch=[
            Prefetch(query=sparse_embedding.as_object(), using="text-sparse", limit=prefetch_limit, filter=id_filter),
            Prefetch(query=dense_embeddings[0], using="text-dense", limit=prefetch_limit, filter=id_filter),
        ],
        query=FusionQuery(fusion=Fusion.RRF),
        query_filter=id_filter,
        with_vectors=False,
        with_payload=True,
        limit=limit,
        score_threshold=score_threshold,
    )
def build_prompt_conv():
    """Return the chat messages used to ask the LLM for a conversation title.

    The system message sets an "innovator" persona. The user message asks for a one-sentence
    summary of the pending chat input, read from ``st.session_state.user_input``.
    """
    persona = """Assume the role of an innovator who thrives on creativity and resourcefulness. Your responses should encourage new approaches and challenge conventional thinking.
Behavior: Focus on brainstorming and ideation, offering unconventional solutions to problems.
Mannerisms: Use energetic, enthusiastic language that reflects your innovative spirit. Frequently propose ideas that are bold and forward-looking."""
    instruction = f"""Generate a short, single-sentence summary of the user's intent or topic based on their question, capturing the main focus of what they want to discuss.
Question : {st.session_state.user_input}
"""
    system_message = {'role': 'system', 'content': persona}
    user_message = {'role': 'user', 'content': instruction}
    return [system_message, user_message]
def build_initial_prompt(query: str):
    """Build the few-shot prompt asking the LLM whether *query* is domain-specific or general.

    main() constrains the answer to exactly 'Domain-Specific Question' or 'General Question'.

    Args:
        query: The user's question, inserted verbatim.

    Returns:
        The rendered prompt string.
    """
    # Build the string explicitly. A function whose body is only a template docstring, with no
    # decorator to render it, returns None.
    lines = (
        "Determine whether the following query is a 'Domain-Specific Question' or a 'General Question.'",
        "A 'Domain-Specific Question' requires knowledge or familiarity with a particular field, niche, or specialized area of interest, including specific video games, movies, books, academic disciplines, or professional fields.",
        "A 'General Question' is broad, open-ended, and can be answered by almost anyone without needing specific context or prior knowledge about any particular domain.",
        "A Domain-Specific Question can also just contain a word related to a particular field, niche, or specialized area of interet. For example: the word 'aggro' is related to specific video games.",
        "Examples :",
        '1. Query: "What are the symptoms of Type 2 diabetes?"',
        "Choose one: Domain-Specific Question",
        '2. Query: "What is your favorite color?"',
        "Choose one: General Question",
        '3. Query: "Who is the main character in Dark Souls?"',
        "Choose one: Domain-Specific Question",
        '4. Query: "How do you bake a cake?"',
        "Choose one: General Question",
        '5. Query: "Explain the difference between RAM and ROM."',
        "Choose one: Domain-Specific Question",
        '6. Query: "Tell me more about your weekend."',
        "Choose one: General Question",
        '7. Query: "Explain me more"',
        "Choose one: General Question",
        '8. Query: "What is god mode ?"',
        "Choose one: Domain-Specific Question",
        '9. Query: "Give me the meaning of aggro"',
        "Choose one: Domain-Specific Question",
        '10. Query: "Give me a description of an aimbot"',
        "Choose one: Domain-Specific Question",
        f"Now, determine the following query : {query}",
        "Choose one: 'Domain-Specific Question' or 'General Question'",
    )
    return "\n".join(lines)
def open_query_prompt(past_messages: str, query: str):
    """Build a chat-transcript prompt: earlier turns, the new user turn, then an open assistant turn.

    Args:
        past_messages: Earlier turns as newline-separated 'role: content' lines ('' for none).
        query: The new user message.

    Returns:
        The prompt, ending with ``'assistant:'`` so the model continues as the assistant.
    """
    current_turn = f"user: {query}"
    history = past_messages or ""
    # The chat page appends the new user message to the transcript before calling main(), so the
    # history may already end with this exact turn; don't repeat it.
    if history == current_turn or history.endswith("\n" + current_turn):
        history = history[:len(history) - len(current_turn)]
    history = history.rstrip("\n")
    # No leading blank line when there is no history.
    if history:
        return f"{history}\n{current_turn}\nassistant:"
    return f"{current_turn}\nassistant:"
def route_llm(context: str, query: str):
    """Build the prompt asking whether *context* is relevant to *query*.

    main() constrains the reply to 'Yes' or 'No'.

    Returns:
        The rendered prompt string.
    """
    # Build the string explicitly. A docstring-only function with no rendering decorator
    # returns None.
    return (
        "Based on the following context, decide if it is relevant to the given query. "
        "Return 'Yes' for relevant and 'No' for irrelevant.\n"
        f"Context : {context}\n"
        f"Query: {query}"
    )
def answer_with_context(context: str, query: str):
    """Build the prompt asking the LLM to answer *query* using only *context*.

    Returns:
        The rendered prompt string, ending with ``'Answer:'``.
    """
    # Substitute the values explicitly. The former single-brace '{context}'/'{query}' placeholders
    # in a docstring were never filled in, so the model did not see the retrieved context.
    lines = (
        "Context information is below.",
        "---------------------",
        f"{context}",
        "---------------------",
        "Given the context information and not prior knowledge, answer the query.",
        f"Query: {query}",
        "Answer:",
    )
    return "\n".join(lines)
def idk(query: str):
    """Build the prompt used when the documents cannot answer and internal knowledge is disallowed.

    Args:
        query: Accepted for signature parity with the other prompt builders; not used in the prompt.

    Returns:
        The prompt string.
    """
    # Return the prompt explicitly; a docstring-only body returns None.
    return "Just express that you don't find the knowledge required in the vector database to answer the question. Be creative and original."
def self_knowledge(query: str):
    """Build the prompt asking the LLM to answer *query* from its own (internal) knowledge.

    Returns:
        The rendered prompt string.
    """
    # Build the string explicitly; a docstring-only body returns None.
    return (
        "Answer the following question by using your own knowledge about the topic.\n"
        f"Question: {query}"
    )
def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model: AsyncEngineArray, sparse_model: SparseTextEmbedding, past_messages: str):
    """Answer *query*, choosing between retrieved documents, plain chat and internal knowledge.

    1. Hybrid-search Qdrant for *query*.
    2. If anything was retrieved, ask the LLM (constrained to 'Yes'/'No') whether that context is
       relevant. If it is, answer from the context and append the source URLs.
    3. Otherwise ask the LLM to classify the query as a general or a domain-specific question.
       - General questions are answered conversationally from *past_messages*.
       - Domain-specific ones get an "I could not find it" answer in documents-only mode, or an
         answer from the model's own knowledge otherwise.

    Args:
        query: The user's question.
        client: Qdrant client holding the document collection.
        collection_name: Name of that collection.
        llm: Outlines model wrapping the vLLM engine (used for constrained generation).
        dense_model: Infinity engine array used for the dense query embedding.
        sparse_model: fastembed sparse model used for the sparse query embedding.
        past_messages: Conversation transcript as 'role: content' lines ('' for none).

    Returns:
        The answer text to display.
    """
    start = time.time()
    # Read the toggle from session state. The module-level flag is reset to False on every
    # script rerun, and the key may not exist yet (e.g. during the warm-up call made while
    # loading models).
    documents_only = st.session_state.get('documents_only', global_state_documents_only)

    scored_points = query_hybrid_search(query, client, collection_name, dense_model, sparse_model).points
    gen_text = outlines.generate.json(llm, Question, whitespace_pattern=r"[\n ]?")
    gen_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])

    context, metadatas = '', []
    if scored_points:
        print(f'Score : {scored_points[0]}')
        contents = [point.payload['text'] for point in scored_points]
        metadatas = [point.payload['metadata'] for point in scored_points]
        context = "\n\n".join(contents)
        print(f'Context : {context}')
        action = gen_choice(route_llm(context, query), max_tokens=2, sampling_params=SamplingParams(temperature=0))
    else:
        # Nothing passed the score threshold. Skip the relevance check instead of crashing on
        # scored_points[0] or on unpacking an empty zip().
        action = 'No'
    print(f'Choice: {action}')

    if action == 'Yes':
        # Unique source URLs, in retrieval order.
        sources = dict.fromkeys(metadata['url'] for metadata in metadatas if 'url' in metadata)
        result_metadatas = "\n\n".join(f'{url}' for url in sources)
        answer = gen_text(answer_with_context(context, query), max_tokens=300, sampling_params=SamplingParams(temperature=0)).answer
        answer = f"{answer}\n\n\nSource(s) :\n\n{result_metadatas}"
        if not documents_only:
            answer = f'Documents Based :\n\n{answer}'
    else:
        gen_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
        # 'Domain-Specific Question' may take more than 3 tokens. The choice constraint ends
        # generation once a full choice is produced, so a larger budget only avoids truncation.
        action = gen_choice(build_initial_prompt(query), max_tokens=10, sampling_params=SamplingParams(temperature=0))
        print(f'Choice 2: {action}')
        if action == 'General Question':
            answer = gen_text(open_query_prompt(past_messages, query), max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)).answer
        else:
            print(f'GLOBAL STATE : {documents_only}')
            if documents_only:
                answer = gen_text(idk(query), max_tokens=128, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)).answer
            else:
                answer = gen_text(self_knowledge(query), max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)).answer
                answer = f'Internal Knowledge :\n\n{answer}'
    torch.cuda.empty_cache()
    print(f'SEARCH TIME : {time.time() - start}')
    return answer
def collect_files(conn, cursor, directory, pattern):
    """Load and concatenate every cached artefact of one kind found in *directory*.

    Files are named ``<document><pattern><ext>`` (e.g. ``action_rpg_dense.npz``).
    - Matching: a file is used only when its stem *ends with* *pattern*, so a document whose name
      merely contains '_ids' or '_dense' is not mistaken for another kind.
    - Ordering: files are processed in document-name order, so the lists built for '_ids',
      '_payload', '_dense' and '_sparse' line up row for row.

    Handled kinds:
      * ``.msgpack``: payload dicts, extended as-is.
      * ``.npz`` with pattern '_dense': one ndarray per stored array.
      * ``.npz`` with pattern '_sparse': one qdrant ``SparseVector`` per CSR row.
      * ``.npy``: point ids. These are also registered in SQLite under the document name
        via insert_data.

    Args:
        conn, cursor: SQLite connection/cursor forwarded to insert_data for '.npy' files.
        directory: Folder holding the cached artefacts.
        pattern: Artefact kind suffix: '_ids', '_payload', '_dense' or '_sparse'.

    Returns:
        A flat list with the loaded items of every matching document.
    """
    matches = []
    for filename in os.listdir(directory):
        stem = os.path.splitext(filename)[0]
        if stem.endswith(pattern):
            matches.append((stem[:len(stem) - len(pattern)], filename))
    # os.listdir order is arbitrary; sort by document name so every pattern walks the same order.
    matches.sort()

    array = []
    for doc_name, filename in matches:
        path = os.path.join(directory, filename)
        if filename.endswith('.msgpack'):
            with open(path, "rb") as payload_file:
                array.extend(msgpack.unpackb(payload_file.read(), raw=False))
        elif filename.endswith('.npz') and pattern == '_dense':
            # The context manager closes the NpzFile once its arrays are read.
            with np.load(path) as npz_file:
                array.extend(npz_file.values())
        elif filename.endswith('.npz') and pattern == '_sparse':
            matrix = load_npz(path).tocsr()
            # Slice each CSR row straight from indptr instead of materialising a row matrix.
            for row_start, row_end in zip(matrix.indptr[:-1], matrix.indptr[1:]):
                array.append(SparseVector(
                    indices=matrix.indices[row_start:row_end].tolist(),
                    values=matrix.data[row_start:row_end].tolist(),
                ))
        elif filename.endswith('.npy'):
            # Ids are 128-bit ints saved as an object array, hence allow_pickle. Only load
            # files written by this app.
            ids_list = np.load(path, allow_pickle=True).tolist()
            # Register under the bare document name, like the fresh-ingest path and the
            # deletion code expect.
            insert_data(conn, cursor, doc_name, ids_list)
            array.extend(ids_list)
    return array
def int_to_bytes(value):
    """Encode an id as base64 over its decimal text.

    Point ids are ``uuid4().int`` values (128 bits), beyond SQLite's 64-bit INTEGER range, so
    they are stored in a BLOB column in this text-encoded form instead.
    """
    digits = str(value).encode('utf-8')
    return base64.b64encode(digits)
def bytes_to_int(value):
    """Inverse of int_to_bytes: base64-decode *value* and parse the decimal text as an int."""
    digits = base64.b64decode(value)
    return int(digits.decode('utf-8'))
def insert_data(conn, cursor, name, ids_array):
    """Register document *name* and all of its Qdrant point ids in the SQLite registry.

    Each id is stored through int_to_bytes. Both inserts happen in one transaction that is
    committed on success and rolled back on error.

    Raises:
        sqlite3.IntegrityError: if *name* is already registered (``doc_name`` is the PRIMARY
            KEY). The transaction is rolled back in that case.
    """
    # Encode before touching the database so an encoding error writes nothing.
    rows = [(name, int_to_bytes(point_id)) for point_id in ids_array]
    # `with conn` commits on success and rolls back on any exception, so a failure no longer
    # leaves a half-written, still-open transaction behind.
    with conn:
        cursor.execute('INSERT INTO table_names (doc_name) VALUES (?)', (name,))
        # One executemany instead of a Python-level loop of execute() calls.
        cursor.executemany('INSERT INTO table_ids (name, ids_value) VALUES (?, ?)', rows)
def retrieve_ids_value(conn, cursor, name):
    """Return the integer Qdrant point ids registered for document *name*.

    Args:
        conn: Unused; accepted for signature parity with the other registry helpers.
        cursor: SQLite cursor on the registry database.
        name: Document name as stored in ``table_ids.name``.

    Returns:
        The decoded ids, in the order SQLite returns the rows; empty if none are registered.
    """
    cursor.execute('SELECT ids_value FROM table_ids WHERE name = ?', (name,))
    encoded_rows = cursor.fetchall()
    return [bytes_to_int(encoded) for (encoded,) in encoded_rows]
def delete_document(conn, cursor, name):
    """Remove document *name* and all of its registered point ids from SQLite, atomically.

    SQLite errors are reported and swallowed; the deletion is rolled back in that case.
    """
    try:
        # `with conn` commits on success and rolls back on error. It replaces an explicit
        # BEGIN, which raised "cannot start a transaction within a transaction" (outside the
        # try) whenever the connection already had uncommitted changes.
        with conn:
            # Child rows first so the FOREIGN KEY on table_ids.name is never violated.
            cursor.execute('DELETE FROM table_ids WHERE name = ?', (name,))
            cursor.execute('DELETE FROM table_names WHERE doc_name = ?', (name,))
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
    else:
        print(f"Deleted document '{name}' and its associated IDs.")
def load_models_and_documents():
    """Load all models, create the in-memory stores and ingest the video-games dataset.

    Progress is reported inside a Streamlit status widget that is cleared when done.

    Steps:
      1. Load the models:
         - the dense embedder (Infinity, CUDA);
         - the BM42 sparse embedder (ONNX, CPU);
         - the GPTQ Mistral-Nemo LLM (vLLM), wrapped for Outlines;
         - spaCy and the NLTK data.
      2. Create the stores:
         - the in-memory Qdrant collection ('text-dense' + 'text-sparse' vectors);
         - the in-memory SQLite registry mapping document names to Qdrant point ids.
      3. Load the data. On the first run (``$HF_HOME/embeddings`` missing), read the Wikipedia
         parquet file, embed it and cache payloads, vectors and ids in that folder. Later runs
         reload the cache through collect_files.
      4. Upsert everything into Qdrant, raise the indexing threshold, and run one warm-up query
         through main().

    Returns:
        ``(client, collection_name, llm, model, dense_model, sparse_model, nlp, conn, cursor)``
    """
    container = st.empty()
    with container.status("Load AI Models and Prepare Documents...", expanded=True) as status:
        hf_home = os.getenv('HF_HOME')

        st.write('Downloading and Loading MixedBread Mxbai Dense Embedding Model with vLLM as backend...')
        dense_model = AsyncEngineArray.from_args([
            EngineArgs(
                model_name_or_path='mixedbread-ai/mxbai-embed-large-v1',
                engine='torch',
                device='cuda',
                embedding_dtype='float32',
                dtype='float16',
                pooling_method='cls',
                lengths_via_tokenize=True,
            )
        ])
        st.write('Downloading and Loading Qdrant BM42 Sparse Embedding Model under ONNX using the CPU...')
        sparse_model = SparseTextEmbedding(
            'Qdrant/bm42-all-minilm-l6-v2-attentions',
            cache_dir=hf_home,
            providers=['CPUExecutionProvider'],
        )
        st.write('Downloading and Loading Mistral Nemo quantized with GPTQ and using Outlines + vLLM Engine as backend...')
        llm = LLM(
            model="shuyuej/Mistral-Nemo-Instruct-2407-GPTQ",
            tensor_parallel_size=1,
            enforce_eager=True,
            gpu_memory_utilization=1,
            max_model_len=11264,
            dtype=torch.float16,
            max_num_seqs=128,
            quantization="gptq",
        )
        model = models.VLLM(llm)
        st.write('Loading Spacy Natural Language Processing Model for English...')
        nlp = spacy.load("en_core_web_sm")
        download_nltk_packages()

        st.write('Creating Collection for our Qdrant Vector Database in RAM memory...')
        client = QdrantClient(':memory:')
        collection_name = 'collection_demo'
        client.create_collection(
            collection_name,
            vectors_config={
                'text-dense': VectorParams(
                    size=1024,
                    distance=Distance.COSINE,
                    datatype=Datatype.FLOAT16,
                    on_disk=False,
                )
            },
            sparse_vectors_config={
                'text-sparse': SparseVectorParams(
                    index=SparseIndexParams(on_disk=False),
                    modifier=Modifier.IDF,
                )
            },
            shard_number=2,
            # indexing_threshold=0 keeps Qdrant from indexing during the bulk upsert;
            # it is raised after ingestion below.
            optimizers_config=OptimizersConfigDiff(indexing_threshold=0, default_segment_number=4),
            hnsw_config=HnswConfigDiff(on_disk=False, m=32, ef_construct=200),
        )

        # Registry of ingested documents and of the Qdrant point ids that belong to each one.
        conn = sqlite3.connect(':memory:', check_same_thread=False)
        conn.execute('PRAGMA foreign_keys = ON;')
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE table_names (
                doc_name TEXT PRIMARY KEY
            )
        ''')
        cursor.execute('''
            CREATE TABLE table_ids (
                name TEXT,
                ids_value BLOB,
                FOREIGN KEY(name) REFERENCES table_names(doc_name)
            )
        ''')

        name = 'action_rpg'
        embeddings_path = os.path.join(hf_home, 'embeddings')
        payload_path = os.path.join(embeddings_path, name + '_payload.msgpack')
        dense_path = os.path.join(embeddings_path, name + '_dense.npz')
        sparse_path = os.path.join(embeddings_path, name + '_sparse.npz')
        ids_path = os.path.join(embeddings_path, name + '_ids.npy')
        if not os.path.exists(embeddings_path):
            st.write('Downloading and Loading Video Games Dataset coming from Wikipedia...')
            dataset = pd.read_parquet(os.path.join(os.getenv('HOME'), 'data', 'train_pages.parquet.zst'), engine='pyarrow')
            # Select the needed columns directly; DataFrame.iteritems() was removed in pandas 2.0.
            documents = dataset['text'].tolist()
            metadatas_titles = dataset['section_title'].tolist()
            metadatas_url = dataset['url'].tolist()

            st.write('Transforming the Wikipedia Video Games Dataset into ingestable format for our Qdrant Vector Database...')
            payload_docs = [{'text': text, 'metadata': {'url': url}} for text, url in zip(documents, metadatas_url)]
            # NOTE(review): both embeddings below are computed from the section titles, not from the
            # page text stored in the payload. Confirm this is intended.
            start_dense = time.time()
            dense_embeddings, _tokens_count = asyncio.run(embed_text(dense_model[0], metadatas_titles))
            print(f'DENSE EMBED : {len(dense_embeddings)} vectors')
            print(f'DENSE TIME: {time.time() - start_dense}')
            start_sparse = time.time()
            sparse_embeddings = [
                SparseVector(indices=s.indices.tolist(), values=s.values.tolist())
                for s in sparse_model.embed(metadatas_titles, 32)
            ]
            print(f'SPARSE TIME: {time.time() - start_sparse}')

            st.write('Saving on disk the Wikipedia Video Games Dataset into quickly ingestable format...')
            # Create the cache folder only now. A run that fails earlier then does not leave an
            # empty folder that would make the next start skip ingestion.
            os.makedirs(embeddings_path, exist_ok=True)
            with open(payload_path, "wb") as outfile_texts:
                outfile_texts.write(msgpack.packb(payload_docs, use_bin_type=True))
            np.savez_compressed(dense_path, *dense_embeddings)
            # Build one CSR matrix (a row per document) directly from the concatenated
            # values/indices, instead of stacking one single-row matrix per document.
            max_index = max((max(e.indices) for e in sparse_embeddings if e.indices), default=0)
            indptr = np.concatenate(([0], np.cumsum([len(e.values) for e in sparse_embeddings], dtype=np.int64)))
            data = np.array([v for e in sparse_embeddings for v in e.values], dtype=np.float64)
            indices = np.array([i for e in sparse_embeddings for i in e.indices], dtype=np.int64)
            save_npz(sparse_path, csr_matrix((data, indices, indptr), shape=(len(sparse_embeddings), max_index + 1)))

            unique_ids = []
            while len(unique_ids) < len(payload_docs):
                new_id = uuid.uuid4()
                # Ids whose hex form starts with '0' are redrawn (reason not documented; kept as-is).
                while new_id.hex[0] == '0':
                    new_id = uuid.uuid4()
                unique_ids.append(new_id.int)
            insert_data(conn, cursor, name, unique_ids)
            # 128-bit ints become an object array, hence allow_pickle.
            np.save(ids_path, np.array(unique_ids), allow_pickle=True)
        else:
            st.write('Loading the saved documents on disk')
            patterns = ['_ids', '_payload', '_dense', '_sparse']
            unique_ids, payload_docs, dense_embeddings, sparse_embeddings = [
                collect_files(conn, cursor, embeddings_path, pattern) for pattern in patterns
            ]

        st.write('Ingesting saved documents on disk into our Qdrant Vector Database...')
        print(f'LEN FIRST : {len(unique_ids)}, {len(payload_docs)}, {len(dense_embeddings)}, {len(sparse_embeddings)}')
        client.upsert(
            collection_name,
            points=Batch(
                ids=unique_ids,
                payloads=payload_docs,
                vectors={
                    'text-dense': dense_embeddings,
                    'text-sparse': sparse_embeddings,
                },
            ),
        )
        # Raise the threshold from 0 so Qdrant indexes the collection now that the bulk load is done.
        client.update_collection(
            collection_name=collection_name,
            optimizers_config=OptimizersConfigDiff(indexing_threshold=20000),
        )
        st.write('Building FSM Index for Agentic Behaviour of our AI...')
        # Warm-up query; the answer itself is not used.
        main('Tell who is David Beckham', client, collection_name, model, dense_model, sparse_model, '')
        status.update(label="Processing Complete!", state="complete", expanded=False)
        time.sleep(5)
    container.empty()
    return client, collection_name, llm, model, dense_model, sparse_model, nlp, conn, cursor
def on_change_documents_only():
    """on_change callback of the 'documents only' toggle.

    Mirrors ``st.session_state.documents_only`` into the module-level flag and stores the
    matching tooltip/display info in ``st.session_state.toggle_docs``.
    """
    # Without this declaration the assignments below only bound a local name and the
    # module-level flag never changed.
    global global_state_documents_only
    # NOTE: the module-level `global_state_documents_only = False` statement runs again on every
    # Streamlit rerun, so the flag does not persist across reruns. st.session_state.documents_only
    # is the value that does.
    if st.session_state.documents_only:
        global_state_documents_only = True
        st.session_state.toggle_docs = {
            'tooltip': 'The AI answer your questions only considering the documents provided',
            'display': True
        }
    else:
        global_state_documents_only = False
        st.session_state.toggle_docs = {
            'tooltip': """The AI answer your questions considering the documents provided, and if it doesn't found the answer in them, try to find in its own internal knowledge""",
            'display': False
        }
if __name__ == '__main__': | |
st.set_page_config(page_title="Multipurpose AI Agent",layout="wide", initial_sidebar_state='auto') | |
client, collection_name, llm, model, dense_model, sparse_model, nlp, conn, cursor = load_models_and_documents() | |
styles = { | |
"nav": { | |
"background-color": "rgb(204, 200, 194)", | |
}, | |
"div": { | |
"max-width": "32rem", | |
}, | |
"span": { | |
"border-radius": "0.5rem", | |
"color": "rgb(125, 102, 84)", | |
"margin": "0 0.125rem", | |
"padding": "0.4375rem 0.625rem", | |
}, | |
"active": { | |
"background-color": "rgba(255, 255, 255, 0.25)", | |
}, | |
"hover": { | |
"background-color": "rgba(255, 255, 255, 0.35)", | |
}, | |
} | |
if 'menu_id' not in st.session_state: | |
st.session_state.menu_id = 'ChatBot' | |
st.session_state.menu_id = st_navbar( | |
['ChatBot', 'Documents'], | |
st.session_state.menu_id, | |
options={ | |
'hide_nav': False, | |
'fix_shadow': False, | |
'use_padding': False | |
}, | |
styles=styles | |
) | |
st.title('Multipurpose AI Agent') | |
#st.markdown("<h1 style='position: fixed; top: 0; left: 0; width: 100%; padding: 10px; text-align: left; color: black;'>Multipurpose AI Agent</h1>", unsafe_allow_html=True) | |
data_editor_path = os.path.join(os.getenv('HF_HOME'), 'documents') | |
if 'df' not in st.session_state: | |
if os.path.exists(data_editor_path): | |
st.session_state.df = pd.read_parquet(os.path.join(data_editor_path, 'data_editor.parquet.lz4'), engine='pyarrow') | |
else: | |
st.session_state.df = pd.DataFrame() | |
os.mkdir(data_editor_path) | |
st.session_state.df.to_parquet( | |
os.path.join( | |
data_editor_path, | |
'data_editor.parquet.lz4' | |
), | |
compression='lz4', | |
engine='pyarrow' | |
) | |
if 'filter_ids' not in st.session_state: | |
st.session_state.filter_ids = [] | |
def on_change_data_editor(conn, cursor, client, collection_name):
    """on_change callback of the Documents data editor.

    - Deleted rows: for each deleted document,
        * remove its cached artefacts from ``$HF_HOME/embeddings``;
        * remove its points from Qdrant and its id rows from SQLite;
      then drop the rows from ``st.session_state.df``.
    - Edited rows: mirror the 'toggle' checkbox into ``st.session_state.df``, and add/remove the
      document's point ids to/from ``st.session_state.filter_ids``.

    Args:
        conn, cursor: SQLite connection/cursor of the document -> point-id registry.
        client: Qdrant client.
        collection_name: Qdrant collection holding the document points.
    """
    editor_state = st.session_state.key_data_editor
    print(f'Check : {editor_state}')
    df = st.session_state.df
    if editor_state['deleted_rows']:
        st.toast('Wait for deletion to complete...')
        embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')
        deleted_rows = editor_state['deleted_rows']
        for position in deleted_rows:
            # The editor reports rows by integer position, so resolve them with iloc.
            name = df['document'].iloc[position]
            print(f'WHAT IS THAT : {name}')
            # Remove every cached artefact of the document, not just its ids file. Otherwise the
            # next start reloads payloads/vectors that no longer have matching ids.
            for suffix in ('_ids.npy', '_payload.msgpack', '_dense.npz', '_sparse.npz'):
                try:
                    os.remove(os.path.join(embeddings_path, name + suffix))
                except FileNotFoundError:
                    pass
            ids_values = retrieve_ids_value(conn, cursor, name)
            client.delete(
                collection_name=collection_name,
                points_selector=PointIdsList(points=ids_values)
            )
            delete_document(conn, cursor, name)
        # DataFrame.drop returns a new frame, so assign it back. Labels are resolved from
        # positions before dropping.
        st.session_state.df = df.drop(index=df.index[deleted_rows])
        st.toast('Deletion Completed !', icon='🎉')
    elif editor_state['edited_rows']:
        for position, changes in editor_state['edited_rows'].items():
            if 'toggle' not in changes:
                continue  # only the Enable/Disable column affects retrieval
            toggle = changes['toggle']
            label = df.index[int(position)]
            df.loc[label, 'toggle'] = toggle
            retrieved_ids = retrieve_ids_value(conn, cursor, df.loc[label, 'document'])
            if not toggle:
                # edited_rows accumulates past edits, so skip ids that are already filtered.
                already_filtered = set(st.session_state.filter_ids)
                st.session_state.filter_ids.extend(i for i in retrieved_ids if i not in already_filtered)
            else:
                to_enable = set(retrieved_ids)
                st.session_state.filter_ids = [i for i in st.session_state.filter_ids if i not in to_enable]
if st.session_state.menu_id == 'Documents': | |
st.session_state.df = st.data_editor( | |
st.session_state.df, | |
num_rows="dynamic", | |
use_container_width=True, | |
hide_index=True, | |
on_change=on_change_data_editor, | |
args=(conn, cursor, client, collection_name), | |
key='key_data_editor', | |
column_config={ | |
'icon': st.column_config.ImageColumn( | |
'Document' | |
), | |
"document": st.column_config.TextColumn( | |
"Name", | |
help="Name of the document", | |
required=True | |
), | |
"type": st.column_config.SelectboxColumn( | |
'File type', | |
help='The file format extension of this document', | |
required=True, | |
options=[ | |
'Powerpoint', | |
'Microsoft Word', | |
'Excel' | |
] | |
), | |
"path": st.column_config.TextColumn( | |
'Path', | |
help='Path to the document', | |
required=False | |
), | |
"time": st.column_config.DatetimeColumn( | |
'Date and hour', | |
help='When this document has been ingested here for the last time', | |
format="D MMM YYYY, h:mm a", | |
required=True | |
), | |
"toggle": st.column_config.CheckboxColumn( | |
'Enable/Disable', | |
help='Either to enable or disable the ability for the ai to find this document', | |
required=True, | |
default=True | |
) | |
} | |
) | |
conversations_path = os.path.join(os.getenv('HF_HOME'), 'conversations') | |
try: | |
with open(conversations_path, 'rb') as fp: | |
packed_bytes = fp.read() | |
conversations: Dict[str, list] = msgpack.unpackb(packed_bytes, raw=False) | |
except: | |
conversations = {} | |
if st.session_state.menu_id == 'ChatBot': | |
if 'id_chat' not in st.session_state: | |
st.session_state.id_chat = 'New Conversation' | |
def options_list(conversations: Dict[str, list]):
    """Return the conversation selector options: 'New Conversation' first, then saved titles."""
    current = st.session_state.id_chat
    first = current if current == 'New Conversation' else 'New Conversation'
    return [first, *conversations.keys()]
with st.sidebar: | |
st.session_state.id_chat = st.selectbox( | |
label='Choose a conversation', | |
options=options_list(conversations), | |
index=0, | |
placeholder='_', | |
key='chat_id' | |
) | |
st.session_state.messages = conversations[st.session_state.id_chat] if st.session_state.id_chat != 'New Conversation' else [] | |
def update_selectbox_remove(conversations_path, conversations):
    """on_click callback: delete the selected conversation, persist the rest, reset the selector."""
    removed_title = st.session_state.chat_id
    conversations.pop(removed_title)
    with open(conversations_path, 'wb') as handle:
        handle.write(msgpack.packb(conversations, use_bin_type=True))
    st.session_state.chat_id = 'New Conversation'
st.button( | |
'Delete Conversation', | |
use_container_width=True, | |
disabled=False if st.session_state.id_chat != 'New Conversation' else True, | |
on_click=update_selectbox_remove, | |
args=(conversations_path, conversations) | |
) | |
def generate_conv_title(llm): | |
if st.session_state.chat_id == 'New Conversation': | |
output = llm.chat( | |
build_prompt_conv(), | |
SamplingParams(temperature=0.6,top_p=0.9, max_tokens=10, top_k=10) | |
) | |
print(f'OUTPUT : {output[0].outputs[0].text}') | |
st.session_state.chat_id = output[0].outputs[0].text.replace('"', '') | |
st.session_state.messages = [] | |
torch.cuda.empty_cache() | |
conversations.update({st.session_state.chat_id: st.session_state.messages}) | |
with open(conversations_path, 'wb') as fp: | |
packed_bytes = msgpack.packb(conversations, use_bin_type=True) | |
fp.write(packed_bytes) | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
if prompt := st.chat_input( | |
"Message Video Game Assistant", | |
on_submit=generate_conv_title, | |
key='user_input', | |
args=(llm,) | |
): | |
st.chat_message("user").markdown(prompt) | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
ai_response = main(prompt, client, collection_name, model, dense_model, sparse_model, "\n".join([f'{msg["role"]}: {msg["content"]}' for msg in st.session_state.messages])) | |
with st.chat_message("assistant"): | |
message_placeholder = st.empty() | |
full_response = "" | |
for chunk in re.split(r'(\s+)', ai_response): | |
full_response += chunk + " " | |
time.sleep(0.05) | |
message_placeholder.write(full_response + '▌') | |
message_placeholder.write(re.sub('▌', '', full_response)) | |
st.session_state.messages.append({"role": "assistant", "content": full_response}) | |
conversations.update({st.session_state.id_chat: st.session_state.messages}) | |
with open(conversations_path, 'wb') as fp: | |
packed_bytes = msgpack.packb(conversations, use_bin_type=True) | |
fp.write(packed_bytes) | |
if "cached_files" not in st.session_state: | |
st.session_state.cached_files = [] | |
with st.sidebar: | |
st.divider() | |
if 'toggle_docs' not in st.session_state: | |
st.session_state.toggle_docs = { | |
'tooltip': 'The AI answer your questions only considering the documents provided', | |
'display': True | |
} | |
st.toggle( | |
label="""Enable 'Documents-Only' Mode""", | |
value=st.session_state.toggle_docs['display'], | |
on_change=on_change_documents_only, | |
key="documents_only", | |
help=st.session_state.toggle_docs['tooltip'] | |
) | |
st.divider() | |
uploaded_files = st.file_uploader("Upload a file :", accept_multiple_files=True, type=['pptx', 'ppt']) | |
for uploaded_file in uploaded_files: | |
if uploaded_file not in st.session_state.cached_files: | |
st.session_state.cached_files.append(uploaded_file) | |
file_name = os.path.basename(uploaded_file.name) | |
base_name, ext = os.path.splitext(file_name) | |
processing_time = datetime.now().strftime('%d %b %Y, %I:%M %p') | |
full_path = os.path.realpath(uploaded_file.name) | |
file_type = ext.lstrip('.') | |
d = { | |
'icon': icon_to_types[file_type][0], | |
'document': base_name, | |
'type': icon_to_types[file_type][1], | |
'path': full_path, | |
'time': [datetime.strptime(processing_time, '%d %b %Y, %I:%M %p')], | |
'toggle': True | |
} | |
if (st.session_state.df.empty) or (base_name not in st.session_state.df['document'].tolist()): | |
st.session_state.df = pd.concat( | |
[ | |
st.session_state.df, | |
pd.DataFrame(data={ | |
'icon': icon_to_types[file_type][0], | |
'document': base_name, | |
'type': icon_to_types[file_type][1], | |
'path': full_path, | |
'time': [datetime.strptime(processing_time, '%d %b %Y, %I:%M %p')], | |
'toggle': True | |
}) | |
], | |
ignore_index=True | |
) | |
else: | |
idx = st.session_state.df.index[st.session_state.df['document']==base_name].tolist()[0] | |
st.session_state.df.loc[idx] = { | |
'icon': icon_to_types[file_type][0], | |
'document': base_name, | |
'type': icon_to_types[file_type][1], | |
'path': full_path, | |
'time': datetime.strptime(processing_time, '%d %b %Y, %I:%M %p'), | |
'toggle': True | |
} | |
st.session_state.df.to_parquet( | |
os.path.join( | |
data_editor_path, | |
'data_editor.parquet.lz4' | |
), | |
compression='lz4', | |
engine='pyarrow' | |
) | |
documents, ids = ppt_chunker(uploaded_file, llm) | |
dense, tokens_count = asyncio.run(embed_text(dense_model[0], documents)) | |
sparse = [s for s in sparse_model.embed(documents, 32)] | |
embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings') | |
def generate_unique_id(existing_ids): | |
while True: | |
new_id = uuid.uuid4() | |
while new_id.hex[0] == '0': | |
new_id = uuid.uuid4() | |
new_id = new_id.int | |
if new_id not in existing_ids: | |
return new_id | |
for filename in os.listdir(embeddings_path): | |
if '_ids' in filename: | |
list_ids = np.load(os.path.join(embeddings_path, filename), allow_pickle=True).tolist() | |
for i, ids_ in enumerate(ids): | |
if ids_ in list_ids: | |
ids[i] = generate_unique_id(list_ids) | |
metadatas_list = [{'url': full_path}] * len(documents) | |
payload_docs = [{ 'text': documents[i], 'metadata': metadata } for i, metadata in enumerate(metadatas_list)] | |
print(f'LEN : {len(ids)}, {len(payload_docs)}, {len(dense)}, {len([SparseVector(indices=s.indices.tolist(), values=s.values.tolist()) for s in sparse])}') | |
client.upsert( | |
collection_name=collection_name, | |
points=Batch( | |
ids=ids, | |
payloads=payload_docs, | |
vectors={ | |
'text-dense': dense, | |
'text-sparse': [SparseVector(indices=s.indices.tolist(), values=s.values.tolist()) for s in sparse] | |
} | |
) | |
) | |
payload_path = os.path.join(embeddings_path, base_name + '_payload.msgpack') | |
dense_path = os.path.join(embeddings_path, base_name + '_dense.npz') | |
sparse_path = os.path.join(embeddings_path, base_name + '_sparse.npz') | |
ids_path = os.path.join(embeddings_path, base_name + '_ids.npy') | |
with open(payload_path, "wb") as outfile_texts: | |
packed_payload = msgpack.packb(payload_docs, use_bin_type=True) | |
outfile_texts.write(packed_payload) | |
np.savez_compressed(dense_path, *dense) | |
max_index = 0 | |
for embedding in sparse: | |
if len(embedding.indices) > 0: | |
max_index = max(max_index, max(embedding.indices)) | |
sparse_matrices = [] | |
for embedding in sparse: | |
data = np.array(embedding.values) | |
indices = np.array(embedding.indices) | |
indptr = np.array([0, len(data)]) | |
matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1)) | |
sparse_matrices.append(matrix) | |
combined_sparse_matrix = vstack(sparse_matrices) | |
save_npz(sparse_path, combined_sparse_matrix) | |
insert_data(conn, cursor, base_name, ids) | |
np.save(ids_path, np.array(ids), allow_pickle=True) | |
st.toast('Document(s) Ingested !', icon='🎉') |