|
import os |
|
import re |
|
import lz4 |
|
import time |
|
import uuid |
|
import torch |
|
import spacy |
|
import base64 |
|
import asyncio |
|
import msgpack |
|
import sqlite3 |
|
import outlines |
|
import validators |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
from vllm import LLM |
|
from numpy import ndarray |
|
from datetime import datetime |
|
from typing import List, Dict |
|
from ppt_chunker import ppt_chunk |
|
from dense_embed import embed_text |
|
from outlines import models, generate |
|
from qdrant_client import QdrantClient |
|
from unstructured.cleaners.core import clean |
|
from streamlit_navigation_bar import st_navbar |
|
from vllm.sampling_params import SamplingParams |
|
from fastembed import SparseTextEmbedding, SparseEmbedding |
|
from unstructured.nlp.tokenize import download_nltk_packages |
|
from huggingface_hub import snapshot_download, hf_hub_download |
|
from scipy.sparse import csr_matrix, save_npz, load_npz, vstack |
|
from langchain_experimental.text_splitter import SemanticChunker |
|
from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine |
|
from langchain_community.document_loaders import WikipediaLoader, WebBaseLoader |
|
from sqlalchemy import ( |
|
create_engine, |
|
MetaData, |
|
Table, |
|
Column, |
|
String, |
|
Integer, |
|
select, |
|
column, |
|
) |
|
from infinity_emb.primitives import ( |
|
Device, |
|
Dtype, |
|
EmbeddingDtype, |
|
InferenceEngine, |
|
PoolingMethod, |
|
) |
|
from qdrant_client.models import ( |
|
NamedSparseVector, |
|
NamedVector, |
|
SparseVector, |
|
PointStruct, |
|
ScoredPoint, |
|
Prefetch, |
|
FusionQuery, |
|
Fusion, |
|
SearchRequest, |
|
Modifier, |
|
OptimizersConfigDiff, |
|
HnswConfigDiff, |
|
Distance, |
|
VectorParams, |
|
SparseVectorParams, |
|
SparseIndexParams, |
|
Batch, |
|
PointIdsList, |
|
QueryRequest, |
|
Filter, |
|
HasIdCondition, |
|
Datatype, |
|
BinaryQuantization, |
|
BinaryQuantizationConfig, |
|
MultiVectorConfig |
|
) |
|
|
|
# Maps a lowercase file extension to a pair:
#   (inline data-URI base64 PNG icon, human-readable file-type label)
# consumed by the Documents data-editor view.
icon_to_types = {
    'ppt': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAC4ElEQVR4nO2YS2gTQRzGF/VgPVuPpsXSQ3eshdbXqVDMEtlN1UN8HqTZYouCVJTsiIfgTcWbD0o0mp3GKq1oC55E2AbbKn3dakISCVjwUNtLlZ1a24xs0oqG2uxmJ2si+8F32NnL95v/fx67DGPLli1b/5UUkauJtDsvDHm5XkXkphSvc17xcksqgksYSXMYwSksS71Yhudx6MouphSk+Ju3RLzcGUV0jg6JHFnPGMH1LcMRVZZOkz7P5n8TXuRcisjF/xY8LwDKWpVhDIcgZ1nwiXPCNkXkHuYLrhcAr4EgKUD6LlUUNfxIp3OHIjon9YY3AoCzHl8IXq0sWvgh0RkzEr4AAKK1FHWIbNsYm/lCAfBqJchj/1ZqAEZ6nhIAURHsNhQy0OToCjQ5vj1ochAzDu6vJgOte00DYATJYsh32AiA6fC/Q9AAUJEU1X1O0Aq/ZhoAOAPhO1nWABhJw2UNoMowjdG16oIBcrWy/IPMx6Pk9eV2iyoACQ75OkwDxF4+JdEXT8jMaCTznF5ZJoNtR60BkKWwaYDgwZpfY/FXzzNj0/3IGgAEJ6gCjN29ma3KwDOLAKRZ0wBhrpGglnoy2HaMLHyeyYy98XVaAqAiuGgaIFcf+ns2XMRJAdD0ommA8Xu3yNidG+Td7etk4Gxr3m2ULgA7S3UN6DFlgIlyB+gpcwBQ+EFWqGmFTwggnTqyp6p8AXjwNm/4kgZwg+N6Ab7SCv9oXxUdAB5ME4/eD5rGnRdpQGjhIy21dADcdS0MDSUFEKC8qxAdi/c+Q0ufPAcqEjw7biHA+1Szg95vFU1xV0NlgmdjRZ95no3GhNrtVMP/CQHGijnzcVdDcX4t5rRTdzF6PkW7bTbSR373Ia3cpsPzYJrabmNU2h6ddNedSvBgWDvyDcx2WjthEwJ7gviZTUwpKOaur9YuXQmBDSd5djLJg7kkD76v+ot2Jc68E0CHrruNLVu2bDGlrJ8c/urSuEn7XgAAAABJRU5ErkJggg==',
            'Powerpoint'),
    'pptx': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAC4ElEQVR4nO2YS2gTQRzGF/VgPVuPpsXSQ3eshdbXqVDMEtlN1UN8HqTZYouCVJTsiIfgTcWbD0o0mp3GKq1oC55E2AbbKn3dakISCVjwUNtLlZ1a24xs0oqG2uxmJ2si+8F32NnL95v/fx67DGPLli1b/5UUkauJtDsvDHm5XkXkphSvc17xcksqgksYSXMYwSksS71Yhudx6MouphSk+Ju3RLzcGUV0jg6JHFnPGMH1LcMRVZZOkz7P5n8TXuRcisjF/xY8LwDKWpVhDIcgZ1nwiXPCNkXkHuYLrhcAr4EgKUD6LlUUNfxIp3OHIjon9YY3AoCzHl8IXq0sWvgh0RkzEr4AAKK1FHWIbNsYm/lCAfBqJchj/1ZqAEZ6nhIAURHsNhQy0OToCjQ5vj1ochAzDu6vJgOte00DYATJYsh32AiA6fC/Q9AAUJEU1X1O0Aq/ZhoAOAPhO1nWABhJw2UNoMowjdG16oIBcrWy/IPMx6Pk9eV2iyoACQ75OkwDxF4+JdEXT8jMaCTznF5ZJoNtR60BkKWwaYDgwZpfY/FXzzNj0/3IGgAEJ6gCjN29ma3KwDOLAKRZ0wBhrpGglnoy2HaMLHyeyYy98XVaAqAiuGgaIFcf+ns2XMRJAdD0ommA8Xu3yNidG+Td7etk4Gxr3m2ULgA7S3UN6DFlgIlyB+gpcwBQ+EFWqGmFTwggnTqyp6p8AXjwNm/4kgZwg+N6Ab7SCv9oXxUdAB5ME4/eD5rGnRdpQGjhIy21dADcdS0MDSUFEKC8qxAdi/c+Q0ufPAcqEjw7biHA+1Szg95vFU1xV0NlgmdjRZ95no3GhNrtVMP/CQHGijnzcVdDcX4t5rRTdzF6PkW7bTbSR373Ia3cpsPzYJrabmNU2h6ddNedSvBgWDvyDcx2WjthEwJ7gviZTUwpKOaur9YuXQmBDSd5djLJg7kkD76v+ot2Jc68E0CHrruNLVu2bDGlrJ8c/urSuEn7XgAAAABJRU5ErkJggg==',
             'Powerpoint'),
    'txt': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAYAAACqaXHeAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEA0lEQVR4nO2bW08TQRTH+8QDH8DPpF+EGLt+AENI44OC7u6L+mAxYQ0xLdZoDE8kCpr4oEYUGu4kEOgFRWm5FLbd3Zk9Ztpus132WmZ3WtqTnGQzLbP7/82ZM2dmaSw2sIENbGAdWHIkOSxw0pgQlzYFTlJFTgIGXhQ4aSKRyAzFojT+9vQNMT71i5HoS04gRCpe4KZWyY2TiUzrIdTj3cCOzooA1ZO64/ND23azm79DfD+73IqEyMSLcWmF3HQykYG5ldOOAVxVvHayB7pcbt2fifgPu0pHAGiJJ59FAoB3EN8JAJriIwHAu4gPCoC2+NAB8B7i/QDA1ePQxIcKgPch3gsAkv8BMQIhDPGhAeB9incDgOQStJlWDSyegHMTHwqA5Ehy2ChyXtx/A3OrZ47ivSIAGxCUi8DitcpB/U915dxRfCgABE4aM4ocL/Hvv+TrN39+L+Wc8FS5I/E6VhvwdB1AvYgQQFzaDFqKLryep5btUbUMtoaUyCJA9SucjPznzDxUj3boLnVKBUDHpgiQo4sAkWFt35bwyNwn5iK+qwCgMNZ5reYqvmsAoJCKHD/OHABiKJ4pgK3vi5AcTTE/DPHtcSnPcy9vUQMwOZpmLyqgC3elAhUA6KzY6vTBj1pPuO+pInoAMOZ2XwJApsTWdwCQJdvTBPD4Zw0mFrsYALJZ6vwAyFea5ayN5Sq4/p2HizWoqDocynobBHJ9cK7X+wjSH3UAyGGd9wMg5/OBiXhiC3mt1Uauif2V9cD9UQOAXIqcoFPAMLvPpjcaW18VAzzLKvAkq4DS1JreVgP3RwUA8qjwaAIgvlZqKF4vYVgvN6/LzqMaOgDwKG9pAzCPuhENT7MKWwDYpbanDcA87635gAkA7LGxCQPAp0IXAVA9dnW0AZBwJ2FvGJkOZFowBYBctrS0AWw0E9+aKQmSa2YAkMd+niaA1FZjGVRQY9TNCfHVJqNlEDwOM2gCIBUfsY+5y4VQwaHAiRwAtpzkBAVAKjWnau3rb1QP+3FLKbx9jOHbHxS4P+oAsM0x1rXfDYLHGV5fAMDmWkAugXaau/SfGdcWAG4Tf0Re0QBgDaB22h8A1Law3wcgLyvJK6vm1JgcnYn0QJPZoajWOrdvjLzhu0tbPQVB4KS8eEe6GQiAdsWXFl5uTbCduO/Q9mtis8NeEB8qAL0HxIcKAHpAfM8BwJTF9xQAHIL4SACk+VmYEWfbxARtM4tPPXoLaf5dm4irtIUOQLSJiKBtdg9Mu20AgJaJlpEkIWwNbT9tJOxJuFpDlnZbVyZBHFLCiyoJFkmHhfWdrhe/t7xE/yczAidNsN64BPb41Dg1AIlEZqgJoR4JXe5sfjY3sIENLHZd7D/x1k4dCUv1GwAAAABJRU5ErkJggg==',
            'Txt'),
    'doc': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAACOElEQVR4nO2ZQWgTQRSGh+DFq70VD2lvQuKpV714cGtLu3ioiiJeWr0oaSnSm0KhKnabJiKYYhHRQxORBmLqoYVKqaGGWmizAYsaCBhDpdCkqatpap68l3QO4kHJsBnN/PDDzts34f+YN3sJY0pKSkr/pFz31q64/Ym822+CPU7kXX7zsjAAe8ObZJfPzAkEsDe8u2oFsC91An41Qv/rCGmGh2nGDtMMqMWOznFoHpipC8BOreE5RMc4uIdNoWZ/AAAirQCGG/kEjg8GARVZ+shrR3ofwb5ODD3j9Wg8RTXcIw3AwS4ffN/dgy85i9f6fLMc4MaTGK9vbH2lXtwjDQDTDHidzFDYlksPaf14LgnFUgVqdiVNNXyHwl6pRohpBtwJxSncmZEXtE5lc7D07jNMx95DwdqFAx1eOHc7Sj3YKx1A180whTOeL0Pz+QA9e6ffwvXJBXpuu/qU1ijslQ6gqec+lMsAi8kMnQKqZyQCxwan6NnzYJ5GB3uw99f9Rz1vhJr9LQA6md6Eb8USBGZWKfThCxN0WfEuhGMfwCqWqOd3e6UACFSDb1tF+LRZ4PX4ehZKez/oHfZIC3Dx7kv+6QwtrPO6L7zC69gjLUBr9TOJ6g+84vWztyp3AoU90gKwGqwAPI13AqMFUeEdp7x1ADhpXBMB4WgfKx/qnLRaT4e2RJqJUosehHqYKYCq1AnoaoQafIScejBvd3hn95S4PziceqgPf9DO8E492CsMQElJSYmJ0E+635eFCoKREwAAAABJRU5ErkJggg==',
            'Microsoft Word'),
    'docx': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAACOElEQVR4nO2ZQWgTQRSGh+DFq70VD2lvQuKpV714cGtLu3ioiiJeWr0oaSnSm0KhKnabJiKYYhHRQxORBmLqoYVKqaGGWmizAYsaCBhDpdCkqatpap68l3QO4kHJsBnN/PDDzts34f+YN3sJY0pKSkr/pFz31q64/Ym822+CPU7kXX7zsjAAe8ObZJfPzAkEsDe8u2oFsC91An41Qv/rCGmGh2nGDtMMqMWOznFoHpipC8BOreE5RMc4uIdNoWZ/AAAirQCGG/kEjg8GARVZ+shrR3ofwb5ODD3j9Wg8RTXcIw3AwS4ffN/dgy85i9f6fLMc4MaTGK9vbH2lXtwjDQDTDHidzFDYlksPaf14LgnFUgVqdiVNNXyHwl6pRohpBtwJxSncmZEXtE5lc7D07jNMx95DwdqFAx1eOHc7Sj3YKx1A180whTOeL0Pz+QA9e6ffwvXJBXpuu/qU1ijslQ6gqec+lMsAi8kMnQKqZyQCxwan6NnzYJ5GB3uw99f9Rz1vhJr9LQA6md6Eb8USBGZWKfThCxN0WfEuhGMfwCqWqOd3e6UACFSDb1tF+LRZ4PX4ehZKez/oHfZIC3Dx7kv+6QwtrPO6L7zC69gjLUBr9TOJ6g+84vWztyp3AoU90gKwGqwAPI13AqMFUeEdp7x1ADhpXBMB4WgfKx/qnLRaT4e2RJqJUosehHqYKYCq1AnoaoQafIScejBvd3hn95S4PziceqgPf9DO8E492CsMQElJSYmJ0E+635eFCoKREwAAAABJRU5ErkJggg==',
             'Microsoft Word'),
    'xslx': ('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAADGUlEQVR4nO2Za0hTYRzG9z0vu59JpK3vedvmTHOmM1OnE0wqbZWJjjTNS2nz7jZvFVmJUTOCchF0+RaBiAThQkgQKSKKbjohIqMIKlG3J3aiVXicw3PaMToPPF/eT7/fe/7vew4cHo8LFy5c/qlssht1cnvpjNxuxGqqcJhX1dhRs1M5as6kLbBx0OhcLTwdAYVHwtE+TVuADjxdAYXDDE5Azj0BMzdCtLLhTCH8aXifAREXiyC/WhrYQ0z06HVET+4M0a0H0ZUDojMb0g4dpNYsSC1ZkJgzIWnPgKRtB8Rt6RC3pkPcsh2i5jSImrQQNqZC2JACUYMWMose4f2GPyT+vkC33kkXXmhKgeD4NkhMqVB1FyHJVoHkKzVk6V4CPD8EGIEX1CdDUK+BqusAEs+XB1Cgizl4fl3SD4H+sgAKdDIHzz/GgoCUAn69NRsvZ2fgSeuQzQsffaoQcwvz+DY/h5gTBUvg+Ue3siBgpd75/dfbSYHZL58QYckhd/7W5Ai5Zh26RAkfWpvIgoBl+bEZef6QBO4YvgxNXwlcbjdevHdC1phGCR9akxB4AYmPmU84V4z5xQV8/PoZY68fwe12Q2+rXhY+tHoLCwLtvg/sBcdt/MyNiWGf8CFV8SwItPm+bezjd70CNyeGfcKHHImHOCcS0l0KyAxqsvzmVFrlrRRf8DkDNeTYjE89wYNXk6RE3kDtsvAhlWoWBFqp4cOa0/Hs3RsSWm+rgrbPSMpMfXiLMJOWEj64Mo4FgRbql9Tpe9dI+OGnY96xufP4PrnWOzJICR9coQq8gIgCPqH3IHn7uNwuaM4We2de2VOABdci2cST+5bABx1WsiDQtPLnga+Z/x0+qJwFASGD8EFlChYEGpiDX3coNvACAlOKkyn4oJJoFgTqNZn8Oo2TCXjBzkiIc6NA7P4lQLc8uhHpNsPfivVRkObHgChQrR0BSV4M/KnU0/xYELuVkO2NWzsCxB4l/K5n5xmElzEhwCSMjBP4H58AYVA72YInDGr6PzhkhepMNiQIg3paZojPoC3AhQsXLjwm8x3YSSmFlSW/AQAAAABJRU5ErkJggg==',
             'Excel')
}

# BUG FIX (backward-compatible): real Excel files carry the 'xlsx' extension;
# the original 'xslx' key looks like a typo and would never match a genuine
# upload. Keep the old key so any existing caller still works, and alias the
# correct spelling to the same icon/label pair.
icon_to_types['xlsx'] = icon_to_types['xslx']
|
|
|
|
|
def transform_query(query: str) -> str:
    """Prefix *query* with the retrieval instruction the dense embedder expects.

    Only search queries get this prefix; documents are embedded as-is.
    """
    instruction = 'Represent this sentence for searching relevant passages: '
    return instruction + query
|
|
|
def query_hybrid_search(query: str, client: QdrantClient, collection_name: str, dense_model: AsyncEmbeddingEngine, sparse_model: SparseTextEmbedding):
    """Run a hybrid (dense + sparse) search against *collection_name*.

    Both vector types are prefetched (25 candidates each) and fused with
    Reciprocal Rank Fusion; at most 10 points above the 0.95 score threshold
    are returned with their payloads.
    """
    # Dense embedding goes through the async infinity-emb engine; the query
    # gets the retrieval instruction prefix first.
    dense_vectors, _ = asyncio.run(embed_text(dense_model[0], transform_query(query)))
    # query_embed yields one sparse embedding for a single query string.
    sparse_vector = next(iter(sparse_model.query_embed(query)))

    candidate_branches = [
        Prefetch(query=sparse_vector.as_object(), using="text-sparse", limit=25),
        Prefetch(query=dense_vectors[0], using="text-dense", limit=25),
    ]

    return client.query_points(
        collection_name=collection_name,
        prefetch=candidate_branches,
        query=FusionQuery(fusion=Fusion.RRF),
        with_vectors=False,
        with_payload=True,
        limit=10,
        score_threshold=0.95
    )
|
|
|
def build_prompt_conv():
    """Build the single-turn chat prompt used to title a new conversation.

    Reads the latest user message from ``st.session_state.user_input`` and
    asks the LLM for a one-sentence summary of its intent; the result becomes
    the conversation's name in the sidebar (see generate_conv_title).
    """
    # Returned in the [{'role': ..., 'content': ...}] shape expected by
    # vllm's LLM.chat().
    return [
        {
            'role': 'user',
            'content': f"""Generate a short, single-sentence summary of the user's intent or topic based on their question, capturing the main focus of what they want to discuss.

Question : {st.session_state.user_input}
"""
        }
    ]
|
|
|
# NOTE: the docstring below is NOT documentation — @outlines.prompt renders it
# as a Jinja2 template at call time ({{ query }} is substituted), so its exact
# text is runtime behavior and must not be edited casually.
@outlines.prompt
def build_initial_prompt(query: str):
    """Determine whether the following query is a 'Domain-Specific Question' or a 'General Question.'

A 'Domain-Specific Question' requires knowledge or familiarity with a particular field, niche, or specialized area of interest, including specific video games, movies, books, academic disciplines, or professional fields.
A 'General Question' is broad, open-ended, and can be answered by almost anyone without needing specific context or prior knowledge about any particular domain.

Examples :
1. Query: "What are the symptoms of Type 2 diabetes?"
Choose one: Domain-Specific Question

2. Query: "What is your favorite color?"
Choose one: General Question

3. Query: "Who is the main character in Dark Souls?"
Choose one: Domain-Specific Question

4. Query: "How do you bake a cake?"
Choose one: General Question

5. Query: "Explain the difference between RAM and ROM."
Choose one: Domain-Specific Question

6. Query: "Tell me more about your weekend."
Choose one: General Question

7. Query: "Explain me more"
Choose one: General Question

Now, determine the following query : {{ query }}

Choose one: 'Domain-Specific Question' or 'General Question'
"""
|
|
|
# Jinja2 template (via @outlines.prompt) for open-ended chat: replays the
# prior conversation transcript, appends the new user turn, and leaves the
# assistant turn open for generation.
@outlines.prompt
def open_query_prompt(past_messages: str, query: str):
    """{{ past_messages }}

user: {{ query }}
assistant:
"""
|
|
|
# Jinja2 template (via @outlines.prompt) for the relevance router: the LLM is
# asked for a bare Yes/No on whether the retrieved context matches the query.
# Callers constrain the answer with outlines.generate.choice(['Yes', 'No']).
@outlines.prompt
def route_llm(context: str, query: str):
    """Based on the following context, decide if it is relevant to the given query. Return 'Yes' for relevant and 'No' for irrelevant.

Context : {{ context }}

Query: {{ query }}
"""
|
|
|
# Jinja2 template (via @outlines.prompt) for grounded question answering.
# BUG FIX: the placeholders were written as single-brace {context}/{query},
# which Jinja2 treats as literal text — the rendered prompt contained the
# strings "{context}"/"{query}" instead of the actual values. Every sibling
# template in this file uses the correct {{ ... }} form; this one now does too.
@outlines.prompt
def answer_with_context(context: str, query: str):
    """Context information is below.
---------------------
{{ context }}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {{ query }}
Answer:
"""
|
|
|
# Jinja2 template (via @outlines.prompt) used when the question is
# domain-specific but documents-only mode is on and retrieval found nothing
# relevant: instructs the model to politely decline.
@outlines.prompt
def idk(query: str):
    """When you encounter a question that falls outside your knowledge or expertise, respond in a way that politely conveys you don't have the information needed to answer.

Question: {{ query }}
"""
|
|
|
# Jinja2 template (via @outlines.prompt) used when documents-only mode is off
# and the retrieved context is irrelevant: lets the model answer from its own
# pretrained knowledge.
@outlines.prompt
def self_knowledge(query: str):
    """Answer the following question by using your own knowledge about the topic.

Question: {{ query }}
"""
|
|
|
def main(query: str, client: QdrantClient, collection_name: str, llm, dense_model: AsyncEmbeddingEngine, sparse_model: SparseTextEmbedding, past_messages: str):
    """Answer *query*, routing between retrieved documents and the LLM itself.

    Pipeline: hybrid-search the collection, ask the LLM whether the retrieved
    context is relevant; if yes, answer grounded in the context (appending the
    deduplicated source metadata), otherwise classify the question and either
    chat openly, decline (documents-only mode), or answer from the model's own
    knowledge.

    Returns the answer string shown to the user.
    """
    scored_points = query_hybrid_search(query, client, collection_name, dense_model, sparse_model).points

    docs = [(scored_point.payload['text'], scored_point.payload['metadata']) for scored_point in scored_points]

    # ROBUSTNESS FIX: with zero retrieved points the old zip(*docs) raised
    # "not enough values to unpack"; treat an empty result as irrelevant
    # context instead and skip the relevance call entirely.
    if docs:
        contents, metadatas = [list(t) for t in zip(*docs)]
    else:
        contents, metadatas = [], []

    context = "\n".join(contents)

    gen_text = outlines.generate.text(llm)

    gen_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
    # BUG FIX: the router previously received the hard-coded string
    # 'Is the context relevant to the question ?' as the query, so the
    # relevance decision never saw the user's actual question.
    prompt = route_llm(context, query)
    if docs:
        action = gen_choice(prompt, max_tokens=2, sampling_params=SamplingParams(temperature=0))
    else:
        action = 'No'
    print(f'Choice: {action}')

    if action == 'Yes':
        # Collect each distinct metadata value once, preserving encounter
        # order, to present as the answer's source list.
        seen_values = set()
        unique_sources = []
        for metadata in metadatas:
            for value in metadata.values():
                if value not in seen_values:
                    seen_values.add(value)
                    unique_sources.append(f'{value}')
        result_metadatas = "\n\n".join(unique_sources)

        prompt = answer_with_context(context, query)
        answer = gen_text(prompt, max_tokens=45, sampling_params=SamplingParams(temperature=0.3))
        answer = f"{answer}\n\n\nSource(s) :\n\n{result_metadatas}"

        # Label the provenance only when the user allows non-document answers.
        if not st.session_state.documents_only:
            answer = f'Documents Based :\n\n{answer}'
    else:
        gen_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
        prompt = build_initial_prompt(query)
        action = gen_choice(prompt, max_tokens=3, sampling_params=SamplingParams(temperature=0))
        print(f'Choice 2: {action}')
        if action == 'General Question':
            prompt = open_query_prompt(past_messages, query)
            answer = gen_text(prompt, max_tokens=45, sampling_params=SamplingParams(temperature=0.3))
        else:
            if st.session_state.documents_only:
                # Domain-specific but nothing relevant on file: decline.
                prompt = idk(query)
                answer = gen_text(prompt, max_tokens=20, sampling_params=SamplingParams(temperature=0.3))
            else:
                prompt = self_knowledge(query)
                answer = gen_text(prompt, max_tokens=45, sampling_params=SamplingParams(temperature=0.3))
                answer = f'Internal Knowledge :\n\n{answer}'

    # Free GPU memory held by the generation pass before returning to the UI.
    torch.cuda.empty_cache()

    return answer
|
|
|
def collect_files(conn, cursor, directory, pattern):
    """Gather every saved artifact in *directory* whose filename contains *pattern*.

    Dispatches on file extension: '.msgpack' payload dicts, '.npz' dense
    (numpy archive) or sparse (scipy CSR, converted row-by-row to Qdrant
    SparseVector) embeddings, and '.npy' point-ID arrays — the latter are
    also re-registered into the SQLite tables via insert_data.

    Returns a flat list of everything collected.
    """
    collected = []

    for fname in os.listdir(directory):
        if pattern not in fname:
            continue
        full_path = os.path.join(directory, fname)

        if fname.endswith('.msgpack'):
            with open(full_path, "rb") as payload_file:
                raw_bytes = payload_file.read()
            collected.extend(msgpack.unpackb(raw_bytes, raw=False))
        elif fname.endswith('.npz') and pattern == '_dense':
            collected.extend(list(np.load(full_path).values()))
        elif fname.endswith('.npz') and pattern == '_sparse':
            csr = load_npz(full_path)
            # Rebuild one SparseVector per stored row of the CSR matrix.
            vectors = []
            for row_idx in range(csr.shape[0]):
                row = csr.getrow(row_idx)
                vectors.append(SparseVector(indices=row.indices.tolist(), values=row.data.tolist()))
            collected.extend(vectors)
        elif fname.endswith('.npy'):
            ids_list = np.load(full_path, allow_pickle=True).tolist()
            # Re-seed the in-memory SQLite mapping (doc name -> point IDs).
            insert_data(conn, cursor, os.path.splitext(fname)[0], ids_list)
            collected.extend(ids_list)

    return collected
|
|
|
def int_to_bytes(value):
    """Serialize *value* for SQLite BLOB storage: base64 of its decimal string."""
    decimal_form = str(value).encode()
    return base64.b64encode(decimal_form)
|
|
|
def bytes_to_int(value):
    """Inverse of int_to_bytes: decode a base64 decimal string back to int."""
    decimal_form = base64.b64decode(value).decode()
    return int(decimal_form)
|
|
|
def insert_data(conn, cursor, name, ids_array):
    """Register document *name* and its Qdrant point IDs in SQLite.

    Inserts one row into ``table_names`` and one row per ID into
    ``table_ids`` (IDs stored base64-encoded via int_to_bytes), then commits
    once. Improvement: the per-ID ``execute`` loop is replaced with a single
    ``executemany``, which batches the inserts in one call.
    """
    cursor.execute('INSERT INTO table_names (doc_name) VALUES (?)', (name,))
    cursor.executemany(
        'INSERT INTO table_ids (name, ids_value) VALUES (?, ?)',
        [(name, int_to_bytes(ids)) for ids in ids_array]
    )
    conn.commit()
|
|
|
def retrieve_ids_value(conn, cursor, name):
    """Return every stored point ID for document *name*, decoded back to int."""
    cursor.execute('SELECT ids_value FROM table_ids WHERE name = ?', (name,))
    return [bytes_to_int(encoded) for (encoded,) in cursor.fetchall()]
|
|
|
def delete_document(conn, cursor, name):
    """Atomically drop document *name* from both SQLite tables.

    Both DELETEs run inside one explicit transaction; any sqlite3 error rolls
    the whole thing back and is reported instead of raised.
    """
    conn.execute('BEGIN')

    try:
        for statement in (
            'DELETE FROM table_ids WHERE name = ?',
            'DELETE FROM table_names WHERE doc_name = ?',
        ):
            cursor.execute(statement, (name,))
    except sqlite3.Error as e:
        conn.rollback()
        print(f"An error occurred: {e}")
    else:
        conn.commit()
        print(f"Deleted document '{name}' and its associated IDs.")
|
|
|
# Cached once per Streamlit server process; reruns of the script reuse the
# returned handles instead of reloading multi-GB models.
@st.cache_resource(show_spinner=False)
def load_models_and_documents():
    """Bootstrap every model and the document store for the app.

    Loads the dense embedder (infinity-emb), the sparse BM42 embedder
    (fastembed/ONNX), the AWQ-quantized Mistral LLM (vLLM, wrapped for
    outlines), and spaCy; creates an in-memory Qdrant collection and an
    in-memory SQLite registry; then either builds the Wikipedia video-games
    dataset from scratch (first run) or reloads its serialized artifacts from
    disk, and upserts everything into Qdrant.

    Returns:
        (client, collection_name, llm, model, dense_model, sparse_model,
         nlp, conn, cursor)
    """
    container = st.empty()

    # The status widget shows progress while models download/load.
    with container.status("Load AI Models and Prepare Documents...", expanded=True) as status:
        st.write('Downloading and Loading MixedBread Mxbai Dense Embedding Model with vLLM as backend...')

        dense_model = AsyncEngineArray.from_args(
            [
                EngineArgs(
                    model_name_or_path='GameScribes/mxbai-embed-large-v1',
                    engine='optimum',
                    device='cuda',
                    embedding_dtype='float32',
                    dtype='float16',
                    pooling_method='cls',
                    lengths_via_tokenize=True
                )
            ]
        )

        st.write('Downloading and Loading Qdrant BM42 Sparse Embedding Model under ONNX using the CPU...')

        # Sparse model runs on CPU so the GPU stays free for the LLM.
        sparse_model = SparseTextEmbedding(
            'Qdrant/bm42-all-minilm-l6-v2-attentions',
            cache_dir=os.getenv('HF_HOME'),
            providers=['CPUExecutionProvider']
        )

        st.write('Downloading and Loading Mistral v0.2 by AWS Prototyping quantized with AWQ and using Outlines + vLLM Engine as backend...')

        # gpu_memory_utilization=0.7 leaves headroom for the dense embedder
        # sharing the same GPU.
        llm = LLM(
            model='aws-prototyping/MegaBeam-Mistral-7B-300k-AWQ',
            revision='MegaBeam-Mistral-7B-300k-AWQ-64g-4b-GEMM',
            tensor_parallel_size=1,
            trust_remote_code=True,
            enforce_eager=True,
            quantization="awq",
            gpu_memory_utilization=0.7,
            max_model_len=12288,
            dtype=torch.float16,
            max_num_seqs=128
        )
        # Outlines wrapper: enables constrained generation (choice/text) on
        # top of the raw vLLM engine.
        model = models.VLLM(llm)

        st.write('Loading Spacy Natural Language Processing Model for English...')
        nlp = spacy.load("en_core_web_sm")

        # NLTK data needed by unstructured's tokenizers.
        download_nltk_packages()

        st.write('Creating Collection for our Qdrant Vector Database in RAM memory...')

        # ':memory:' — the collection lives in this process; artifacts on
        # disk are what persists between runs.
        client = QdrantClient(':memory:')
        collection_name = 'collection_demo'

        # Positional args after the name: vectors_config (dense, fp16
        # cosine), sparse_vectors_config (IDF-modified BM42), then what is
        # presumably the shard number (2) — TODO confirm against the
        # qdrant-client create_collection signature.
        client.create_collection(
            collection_name,
            {
                'text-dense': VectorParams(
                    size=1024,
                    distance=Distance.COSINE,
                    datatype=Datatype.FLOAT16,
                    on_disk=False
                )
            },
            {
                'text-sparse': SparseVectorParams(
                    index=SparseIndexParams(
                        on_disk=False
                    ),
                    modifier=Modifier.IDF
                )
            },
            2,
            # indexing_threshold=0 disables indexing during bulk upload; it
            # is re-enabled at the end of this function.
            optimizers_config=OptimizersConfigDiff(
                indexing_threshold=0,
                default_segment_number=4
            ),
            hnsw_config=HnswConfigDiff(
                on_disk=False,
                m=32,
                ef_construct=200
            )
        )

        # In-memory SQLite registry mapping document names to their Qdrant
        # point IDs. check_same_thread=False because Streamlit callbacks may
        # run on other threads.
        conn = sqlite3.connect(':memory:', check_same_thread=False)
        conn.execute('PRAGMA foreign_keys = ON;')
        cursor = conn.cursor()

        cursor.execute('''
        CREATE TABLE table_names (
            doc_name TEXT PRIMARY KEY
        )
        ''')
        cursor.execute('''
        CREATE TABLE table_ids (
            name TEXT,
            ids_value BLOB,
            FOREIGN KEY(name) REFERENCES table_names(doc_name)
        )
        ''')

        cursor.execute('SELECT COUNT(*) FROM table_names')
        count = cursor.fetchone()[0]
        print(f'Is the table empty? {"Yes" if count == 0 else "No"}')

        # Artifact file paths for the bundled demo dataset.
        name = 'action_rpg'
        embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')

        payload_path = os.path.join(embeddings_path, name + '_payload.msgpack')
        dense_path = os.path.join(embeddings_path, name + '_dense.npz')
        sparse_path = os.path.join(embeddings_path, name + '_sparse.npz')
        ids_path = os.path.join(embeddings_path, name + '_ids.npy')

        # First run: build the dataset from Wikipedia and serialize every
        # artifact so later runs can skip this entirely.
        if not os.path.exists(embeddings_path):
            os.mkdir(embeddings_path)

            st.write('Downloading and Loading Video Games Dataset coming from Wikipedia...')

            docs_1 = WikipediaLoader(query='Action-RPG').load()
            docs_2 = WikipediaLoader(query='Real-time strategy').load()
            docs_3 = WikipediaLoader(query='First-person shooter').load()
            docs_4 = WikipediaLoader(query='Multiplayer online battle arena').load()
            docs_5 = WikipediaLoader(query='List of video game genres').load()
            docs = docs_1 + docs_2 + docs_3 + docs_4 + docs_5

            texts, metadatas = [], []
            for doc in docs:
                texts.append(doc.page_content)
                # Drop bulky fields; the remaining metadata (presumably the
                # source URL — TODO confirm) is kept as the citation payload.
                del doc.metadata['title']
                del doc.metadata['summary']
                metadatas.append(doc.metadata)

            st.write('Transforming the Wikipedia Video Games Dataset into ingestable format for our Qdrant Vector Database...')

            payload_docs, dense_embeddings, sparse_embeddings = chunk_documents(texts, metadatas, dense_model, sparse_model)

            st.write('Saving on disk the Wikipedia Video Games Dataset into quickly ingestable format...')

            # Payloads -> msgpack.
            with open(payload_path, "wb") as outfile_texts:
                packed_payload = msgpack.packb(payload_docs, use_bin_type=True)
                outfile_texts.write(packed_payload)

            # Dense embeddings -> compressed .npz archive.
            np.savez_compressed(dense_path, *dense_embeddings)

            # Sparse embeddings -> one stacked CSR matrix; first find the
            # widest index so every row shares one column dimension.
            max_index = 0
            for embedding in sparse_embeddings:
                if len(embedding.indices) > 0:
                    max_index = max(max_index, max(embedding.indices))

            sparse_matrices = []
            for embedding in sparse_embeddings:
                data = np.array(embedding.values)
                indices = np.array(embedding.indices)
                indptr = np.array([0, len(data)])
                matrix = csr_matrix((data, indices, indptr), shape=(1, max_index + 1))
                sparse_matrices.append(matrix)

            combined_sparse_matrix = vstack(sparse_matrices)
            save_npz(sparse_path, combined_sparse_matrix)

            # One random 128-bit integer ID per chunk; the loop re-rolls IDs
            # whose hex form starts with '0' — presumably to keep a
            # fixed-width representation (TODO confirm why this matters).
            unique_ids = []
            while len(unique_ids) < len(payload_docs):
                new_id = uuid.uuid4()
                while new_id.hex[0] == '0':
                    new_id = uuid.uuid4()
                unique_ids.append(new_id.int)

            insert_data(conn, cursor, name, unique_ids)

            np.save(ids_path, np.array(unique_ids), allow_pickle=True)
        else:
            st.write('Loading the saved documents on disk')

            # Order matters: must match the unpacking on the next line.
            patterns = ['_ids', '_payload', '_dense', '_sparse']

            unique_ids, payload_docs, dense_embeddings, sparse_embeddings = [
                collect_files(conn, cursor, embeddings_path, pattern) for pattern in patterns
            ]

        st.write('Ingesting saved documents on disk into our Qdrant Vector Database...')

        print(f'LEN FIRST : {len(unique_ids)}, {len(payload_docs)}, {len(dense_embeddings)}, {len(sparse_embeddings)}')

        # Bulk upload with indexing disabled (see indexing_threshold=0 above).
        client.upsert(
            collection_name,
            points=Batch(
                ids=unique_ids,
                payloads=payload_docs,
                vectors={
                    'text-dense': dense_embeddings,
                    'text-sparse': sparse_embeddings
                }
            )
        )

        # Re-enable indexing now that the bulk upload is done.
        client.update_collection(
            collection_name=collection_name,
            optimizer_config=OptimizersConfigDiff(indexing_threshold=20000)
        )
        status.update(
            label="Processing Complete!", state="complete", expanded=False
        )

    # Leave the success status visible briefly, then clear the placeholder.
    time.sleep(5)
    container.empty()

    return client, collection_name, llm, model, dense_model, sparse_model, nlp, conn, cursor
|
|
|
def chunk_documents(texts: List[str], metadatas: List[dict], dense_model: AsyncEmbeddingEngine, sparse_model: SparseTextEmbedding):
    """Semantically chunk *texts* and embed every chunk both dense and sparse.

    Returns (payload_docs, dense_embeddings, sparse_embeddings) where each
    payload dict carries the chunk text plus its source metadata.
    """
    # Chunk boundaries are chosen by embedding-distance jumps (standard
    # deviation breakpoints) using the same dense model as retrieval.
    splitter = SemanticChunker(
        dense_model,
        breakpoint_threshold_type='standard_deviation'
    )
    chunks = splitter.create_documents(texts, metadatas)

    payload_docs = [{'text': chunk.page_content, 'metadata': chunk.metadata} for chunk in chunks]
    documents = [chunk.page_content for chunk in chunks]

    # Dense pass (async infinity-emb engine), with wall-clock timing.
    dense_started = time.time()
    dense_embeddings, tokens_count = asyncio.run(embed_text(dense_model[0], documents))
    print(f'DENSE EMBED : {dense_embeddings}')
    dense_elapsed = time.time() - dense_started
    print(f'DENSE TIME: {dense_elapsed}')

    # Sparse pass (batch size 32), converting each fastembed result into a
    # Qdrant SparseVector.
    sparse_started = time.time()

    sparse_embeddings = [
        SparseVector(indices=raw.indices.tolist(), values=raw.values.tolist())
        for raw in sparse_model.embed(documents, 32)
    ]

    sparse_elapsed = time.time() - sparse_started
    print(f'SPARSE TIME: {sparse_elapsed}')

    return payload_docs, dense_embeddings, sparse_embeddings
|
|
|
def on_change_documents_only():
    """Callback for the documents-only toggle: refresh its tooltip/flag state."""
    docs_only = st.session_state.documents_only

    if docs_only:
        tooltip = 'The AI answer your questions only considering the documents provided'
    else:
        tooltip = """The AI answer your questions considering the documents provided, and if it doesn't found the answer in them, try to find in its own internal knowledge"""

    # 'display' mirrors the toggle itself (True in documents-only mode).
    st.session_state.toggle_docs = {
        'tooltip': tooltip,
        'display': docs_only
    }
|
|
|
|
|
if __name__ == '__main__': |
|
st.set_page_config(page_title="Multipurpose AI Agent",layout="wide", initial_sidebar_state='auto') |
|
|
|
client, collection_name, llm, model, dense_model, sparse_model, nlp, conn, cursor = load_models_and_documents() |
|
|
|
styles = { |
|
"nav": { |
|
"background-color": "rgb(204, 200, 194)", |
|
}, |
|
"div": { |
|
"max-width": "32rem", |
|
}, |
|
"span": { |
|
"border-radius": "0.5rem", |
|
"color": "rgb(125, 102, 84)", |
|
"margin": "0 0.125rem", |
|
"padding": "0.4375rem 0.625rem", |
|
}, |
|
"active": { |
|
"background-color": "rgba(255, 255, 255, 0.25)", |
|
}, |
|
"hover": { |
|
"background-color": "rgba(255, 255, 255, 0.35)", |
|
}, |
|
} |
|
|
|
if 'menu_id' not in st.session_state: |
|
st.session_state.menu_id = 'ChatBot' |
|
|
|
st.session_state.menu_id = st_navbar( |
|
['ChatBot', 'Documents'], |
|
st.session_state.menu_id, |
|
options={ |
|
'hide_nav': False, |
|
'fix_shadow': False, |
|
'use_padding': False |
|
}, |
|
styles=styles |
|
) |
|
|
|
st.title('Multipurpose AI Agent') |
|
|
|
|
|
data_editor_path = os.path.join(os.getenv('HF_HOME'), 'documents') |
|
|
|
if 'df' not in st.session_state: |
|
if os.path.exists(data_editor_path): |
|
st.session_state.df = pd.read_parquet(os.path.join(data_editor_path, 'data_editor.parquet.sz'), engine='pyarrow') |
|
else: |
|
st.session_state.df = pd.DataFrame() |
|
os.mkdir(data_editor_path) |
|
st.session_state.df.to_parquet( |
|
os.path.join( |
|
data_editor_path, |
|
'data_editor.parquet.sz' |
|
), |
|
compression='snappy', |
|
engine='pyarrow' |
|
) |
|
|
|
if 'filter_ids' not in st.session_state: |
|
st.session_state.filter_ids = [] |
|
|
|
def on_change_data_editor(conn, cursor, client, collection_name):
    """st.data_editor callback for the Documents page.

    Handles two edit kinds reported in st.session_state.key_data_editor:
    - deleted_rows: remove the document's saved ID file, its Qdrant points,
      its SQLite rows, and its row in the session DataFrame;
    - edited_rows: sync the enable/disable toggle into the DataFrame and
      maintain st.session_state.filter_ids (IDs excluded from search).
    """
    print(f'Check : {st.session_state.key_data_editor}')

    if st.session_state.key_data_editor['deleted_rows']:
        st.toast('Wait for deletion to complete...')
        embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')

        for deleted_idx in st.session_state.key_data_editor['deleted_rows']:
            name = st.session_state.df.loc[deleted_idx, 'document']
            print(f'WHAT IS THAT : {name}')
            # NOTE(review): only the '_ids.npy' artifact is removed here; the
            # '_payload'/'_dense'/'_sparse' files appear to remain on disk —
            # confirm whether that is intentional.
            os.remove(os.path.join(embeddings_path, name + '_ids.npy'))
            ids_values = retrieve_ids_value(conn, cursor, name)

            client.delete(
                collection_name=collection_name,
                points_selector=PointIdsList(points=ids_values)
            )
            delete_document(conn, cursor, name)
            # BUG FIX: DataFrame.drop is not in-place and its result was
            # previously discarded, so deleted rows silently stayed in the
            # table. Assign the dropped frame back.
            st.session_state.df = st.session_state.df.drop(deleted_idx)

        st.toast('Deletion Completed !', icon='π')
    elif st.session_state.key_data_editor['edited_rows']:
        edit_dict = st.session_state.key_data_editor['edited_rows']
        for key, value in edit_dict.items():
            toggle = value['toggle']
            st.session_state.df.loc[key, 'toggle'] = toggle
            retrieved_ids = retrieve_ids_value(conn, cursor, st.session_state.df.loc[key, 'document'])

            if not toggle:
                # Disabled document: exclude its points from retrieval.
                st.session_state.filter_ids.extend(retrieved_ids)
            else:
                # Re-enabled: lift the exclusion for its points.
                st.session_state.filter_ids = [i for i in st.session_state.filter_ids if i not in retrieved_ids]
|
|
|
if st.session_state.menu_id == 'Documents': |
|
st.session_state.df = st.data_editor( |
|
st.session_state.df, |
|
num_rows="dynamic", |
|
use_container_width=True, |
|
hide_index=True, |
|
on_change=on_change_data_editor, |
|
args=(conn, cursor, client, collection_name), |
|
key='key_data_editor', |
|
column_config={ |
|
'icon': st.column_config.ImageColumn( |
|
'Document' |
|
), |
|
"document": st.column_config.TextColumn( |
|
"Name", |
|
help="Name of the document", |
|
required=True |
|
), |
|
"type": st.column_config.SelectboxColumn( |
|
'File type', |
|
help='The file format extension of this document', |
|
required=True, |
|
options=[ |
|
'Powerpoint', |
|
'Microsoft Word', |
|
'Excel' |
|
] |
|
), |
|
"path": st.column_config.TextColumn( |
|
'Path', |
|
help='Path to the document', |
|
required=False |
|
), |
|
"time": st.column_config.DatetimeColumn( |
|
'Date and hour', |
|
help='When this document has been ingested here for the last time', |
|
format="D MMM YYYY, h:mm a", |
|
required=True |
|
), |
|
"toggle": st.column_config.CheckboxColumn( |
|
'Enable/Disable', |
|
help='Either to enable or disable the ability for the ai to find this document', |
|
required=True, |
|
default=True |
|
) |
|
} |
|
) |
|
|
|
# Load the persisted conversation store (msgpack dict: title -> message list).
conversations_path = os.path.join(os.getenv('HF_HOME'), 'conversations')
try:
    with open(conversations_path, 'rb') as fp:
        packed_bytes = fp.read()
    conversations: Dict[str, list] = msgpack.unpackb(packed_bytes, raw=False)
except Exception:
    # First run (file missing) or unreadable/corrupt store: start empty.
    # FIX: narrowed from a bare `except:`, which also swallowed
    # SystemExit/KeyboardInterrupt.
    conversations = {}
|
|
|
if st.session_state.menu_id == 'ChatBot': |
|
if 'id_chat' not in st.session_state: |
|
st.session_state.id_chat = 'New Conversation' |
|
|
|
def options_list(conversations: Dict[str, list]):
    """Selectbox options: 'New Conversation' pinned first, then saved titles.

    When the current selection already is 'New Conversation' the pinned entry
    is the selection itself, so it is never duplicated.
    """
    saved_titles = list(conversations.keys())
    current = st.session_state.id_chat
    head = current if current == 'New Conversation' else 'New Conversation'
    return [head] + saved_titles
|
|
|
with st.sidebar: |
|
st.session_state.id_chat = st.selectbox( |
|
label='Choose a conversation', |
|
options=options_list(conversations), |
|
index=0, |
|
placeholder='_', |
|
key='chat_id' |
|
) |
|
|
|
st.session_state.messages = conversations[st.session_state.id_chat] if st.session_state.id_chat != 'New Conversation' else [] |
|
|
|
def update_selectbox_remove(conversations_path, conversations):
    """Delete the selected conversation, persist the rest, reset the selector."""
    # Drop the conversation currently chosen in the sidebar selectbox.
    conversations.pop(st.session_state.chat_id)
    # Persist the remaining conversations back to the msgpack store.
    serialized = msgpack.packb(conversations, use_bin_type=True)
    with open(conversations_path, 'wb') as fp:
        fp.write(serialized)
    # Return the UI to the blank-conversation state.
    st.session_state.chat_id = 'New Conversation'
|
|
|
|
|
st.button( |
|
'Delete Conversation', |
|
use_container_width=True, |
|
disabled=False if st.session_state.id_chat != 'New Conversation' else True, |
|
on_click=update_selectbox_remove, |
|
args=(conversations_path, conversations) |
|
) |
|
|
|
def generate_conv_title(llm):
    """On the first message of a fresh chat, ask the LLM for a short title
    and promote the session from 'New Conversation' to that title.

    Runs as the chat_input on_submit callback; does nothing when the current
    conversation already has a title.
    """
    if st.session_state.chat_id == 'New Conversation':
        result = llm.chat(
            build_prompt_conv(),
            SamplingParams(temperature=0.3, max_tokens=10)
        )
        title = result[0].outputs[0].text
        print(f'OUTPUT : {title}')
        st.session_state.chat_id = title
        st.session_state.messages = []
|
|
|
# Release cached GPU memory (presumably after the title-generation call —
# TODO confirm placement against the file's real indentation).
torch.cuda.empty_cache()

# Persist the conversation under its (possibly just assigned) title.
# NOTE(review): nesting reconstructed — this save appears to belong to
# generate_conv_title; verify against the original file.
conversations.update({st.session_state.chat_id: st.session_state.messages})
with open(conversations_path, 'wb') as fp:
    packed_bytes = msgpack.packb(conversations, use_bin_type=True)
    fp.write(packed_bytes)
|
|
|
# Re-render the full message history on every Streamlit rerun.
for msg in st.session_state.messages:
    role, content = msg["role"], msg["content"]
    with st.chat_message(role):
        st.markdown(content)
|
|
|
if prompt := st.chat_input(
    "Message Video Game Assistant",
    on_submit=generate_conv_title,
    key='user_input',
    args=(llm,)
):
    # Echo the user's message and record it in the running history.
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    ai_response = main(prompt, client, collection_name, model, dense_model, sparse_model, "\n".join([f'{msg["role"]}: {msg["content"]}' for msg in st.session_state.messages]))
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # Simulated streaming: the split pattern CAPTURES the whitespace
        # separators, so the chunks already contain the original spacing and
        # must be concatenated verbatim. (The previous `chunk + " "` added an
        # extra space per token, corrupting both the displayed text and the
        # message saved to history.)
        for chunk in re.split(r'(\s+)', ai_response):
            full_response += chunk
            time.sleep(0.05)
            # Trailing 'β' acts as a typing cursor while streaming.
            message_placeholder.write(full_response + 'β')
        # Final render without the cursor; full_response never contains 'β'.
        message_placeholder.write(full_response)

    st.session_state.messages.append({"role": "assistant", "content": full_response})

    # Persist the updated conversation to disk after each exchange.
    conversations.update({st.session_state.id_chat: st.session_state.messages})
    with open(conversations_path, 'wb') as fp:
        packed_bytes = msgpack.packb(conversations, use_bin_type=True)
        fp.write(packed_bytes)
|
|
|
# Track uploads already ingested this session so reruns skip re-processing.
st.session_state.setdefault("cached_files", [])
|
|
|
with st.sidebar:
    st.divider()

    # One-time init of the 'Documents-Only' toggle state (tooltip text and
    # current value), kept across Streamlit reruns.
    if 'toggle_docs' not in st.session_state:
        st.session_state.toggle_docs = {
            'tooltip': 'The AI answer your questions only considering the documents provided',
            'display': True
        }
|
|
|
# Toggle restricting answers to the ingested documents only; the change
# handler syncs the new value back into st.session_state.toggle_docs.
toggle_state = st.session_state.toggle_docs
st.toggle(
    label="Enable 'Documents-Only' Mode",
    value=toggle_state['display'],
    on_change=on_change_documents_only,
    key="documents_only",
    help=toggle_state['tooltip']
)
|
|
|
st.divider()

# Web ingestion: scrape a URL, chunk + embed its content, upsert to Qdrant.
url = st.text_input("Scrape an URL link :")
if validators.url(url):
    docs = WebBaseLoader(url).load()
    print(f'WebBaseLoader: {docs[0].metadata}')

    texts, metadatas = [], []
    for doc in docs:
        texts.append(doc.page_content)
        # WebBaseLoader does not guarantee these metadata keys for every
        # page; pop() with a default avoids a KeyError where `del` crashed.
        doc.metadata.pop('title', None)
        doc.metadata.pop('description', None)
        doc.metadata.pop('language', None)
        metadatas.append(doc.metadata)

    payload_docs, dense_embeddings, sparse_embeddings = chunk_documents(texts, metadatas, dense_model, sparse_model)

    # NOTE(review): make_points receives the pre-chunk texts/metadatas while
    # the embeddings come from chunk_documents — confirm make_points aligns
    # them identically; otherwise payload_docs should be passed instead.
    client.upsert(
        collection_name,
        make_points(
            texts,
            metadatas,
            dense_embeddings,
            sparse_embeddings
        )
    )

    st.toast('URL Content Ingested !', icon='π')
|
|
|
st.divider()

# PowerPoint ingestion: one or more .pptx/.ppt decks per upload.
uploaded_files = st.file_uploader("Upload a file :", accept_multiple_files=True, type=['pptx', 'ppt'])
|
|
|
# Ingest each newly uploaded deck exactly once per session.
for uploaded_file in uploaded_files:

    if uploaded_file not in st.session_state.cached_files:
        st.session_state.cached_files.append(uploaded_file)

        file_name = os.path.basename(uploaded_file.name)
        base_name, ext = os.path.splitext(file_name)

        # Human-readable ingestion timestamp; parsed back to datetime below
        # for the data-editor's DatetimeColumn.
        processing_time = datetime.now().strftime('%d %b %Y, %I:%M %p')

        full_path = os.path.realpath(uploaded_file.name)
        file_type = ext.lstrip('.')

        # NOTE(review): `d` is built but never used below — the same dict
        # literal is re-created inline for the DataFrame; likely removable.
        d = {
            'icon': icon_to_types[file_type][0],
            'document': base_name,
            'type': icon_to_types[file_type][1],
            'path': full_path,
            'time': [datetime.strptime(processing_time, '%d %b %Y, %I:%M %p')],
            'toggle': True
        }

        if (st.session_state.df.empty) or (base_name not in st.session_state.df['document'].tolist()):
            # First time this document is seen: append a new table row.
            st.session_state.df = pd.concat(
                [
                    st.session_state.df,
                    pd.DataFrame(data={
                        'icon': icon_to_types[file_type][0],
                        'document': base_name,
                        'type': icon_to_types[file_type][1],
                        'path': full_path,
                        'time': [datetime.strptime(processing_time, '%d %b %Y, %I:%M %p')],
                        'toggle': True
                    })
                ],
                ignore_index=True
            )
        else:
            # Re-upload of a known document: overwrite its existing row.
            idx = st.session_state.df.index[st.session_state.df['document']==base_name].tolist()[0]
            st.session_state.df.loc[idx] = {
                'icon': icon_to_types[file_type][0],
                'document': base_name,
                'type': icon_to_types[file_type][1],
                'path': full_path,
                'time': datetime.strptime(processing_time, '%d %b %Y, %I:%M %p'),
                'toggle': True
            }

        # Snapshot the table to disk so it survives app restarts.
        st.session_state.df.to_parquet(
            os.path.join(
                data_editor_path,
                'data_editor.parquet.sz'
            ),
            compression='snappy',
            engine='pyarrow'
        )

        # Split the deck into text chunks (plus extracted tables).
        weakDict, tables = ppt_chunk(uploaded_file, nlp)
        documents = weakDict.all_texts()

        # Dense embeddings via the async engine; sparse via fastembed
        # (batch size 32).
        dense, tokens_count = asyncio.run(embed_text(dense_model[0], documents))
        sparse = [s for s in sparse_model.embed(documents, 32)]

        embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')
|
|
|
def generate_unique_id(existing_ids):
    """Return a random 128-bit integer id that is not in existing_ids.

    Candidates whose leading hex digit is zero are rejected, so the result
    always uses the full 128-bit width (>= 16**31).
    """
    while True:
        candidate = uuid.uuid4()
        if candidate.hex[0] == '0':
            continue
        numeric = candidate.int
        if numeric not in existing_ids:
            return numeric
|
|
|
# Ids for the new chunks, deduplicated against ids already stored on disk.
ids = weakDict.all_ids()

# NOTE(review): nesting reconstructed — the dedup pass is assumed to run
# once per '*_ids' file inside the scan; if it actually sits outside the
# loop, only the LAST file is checked and a first-ever upload (no '_ids'
# files at all) would hit an undefined `list_ids`. Confirm against the file.
for filename in os.listdir(embeddings_path):
    if '_ids' in filename:
        list_ids = np.load(os.path.join(embeddings_path, filename), allow_pickle=True).tolist()

        # Replace any colliding id with a fresh one absent from this file.
        for i, ids_ in enumerate(ids):
            if ids_ in list_ids:
                ids[i] = generate_unique_id(list_ids)
|
|
|
# One payload per chunk: its text plus the slide/source metadata.
payload_docs = [{ 'text': documents[i], 'metadata': metadata } for i, metadata in enumerate(weakDict.all_metadatas())]

# Build the Qdrant sparse vectors once; the original rebuilt this whole
# list twice (once for the debug print, once for the upsert).
sparse_vectors = [SparseVector(indices=s.indices.tolist(), values=s.values.tolist()) for s in sparse]

print(f'LEN : {len(ids)}, {len(payload_docs)}, {len(dense)}, {len(sparse_vectors)}')

client.upsert(
    collection_name=collection_name,
    points=Batch(
        ids=ids,
        payloads=payload_docs,
        vectors={
            'text-dense': dense,
            'text-sparse': sparse_vectors
        }
    )
)
|
|
|
# On-disk locations for this document's payloads, embeddings, and ids.
payload_path = os.path.join(embeddings_path, f'{base_name}_payload.msgpack')
dense_path = os.path.join(embeddings_path, f'{base_name}_dense.npz')
sparse_path = os.path.join(embeddings_path, f'{base_name}_sparse.npz')
ids_path = os.path.join(embeddings_path, f'{base_name}_ids.npy')

# Persist the chunk payloads with msgpack for later re-ingestion.
with open(payload_path, "wb") as outfile_texts:
    outfile_texts.write(msgpack.packb(payload_docs, use_bin_type=True))
|
|
|
# Dense embeddings: one compressed array per chunk.
np.savez_compressed(dense_path, *dense)

# Sparse matrix width = highest token index seen across all chunk embeddings
# (0 when every embedding is empty).
max_index = max(
    (max(emb.indices) for emb in sparse if len(emb.indices) > 0),
    default=0,
)

# One CSR row per chunk embedding, all sharing the same column count,
# stacked into a single matrix for compact storage.
rows = []
for emb in sparse:
    values = np.array(emb.values)
    cols = np.array(emb.indices)
    row_ptr = np.array([0, len(values)])
    rows.append(csr_matrix((values, cols, row_ptr), shape=(1, max_index + 1)))

save_npz(sparse_path, vstack(rows))
|
|
|
# Record the document -> ids mapping in SQLite, mirror the ids to disk
# (read back by the dedup scan on later uploads), then notify the user.
insert_data(conn, cursor, base_name, ids)
np.save(ids_path, np.array(ids), allow_pickle=True)

st.toast('Document(s) Ingested !', icon='π')