Spaces:
Running
Running
Hammaad
commited on
Commit
β’
93bc171
1
Parent(s):
92e5bb7
code refactor part 1 complete need to test
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .dockerignore +28 -28
- .gitignore +137 -135
- CHANGELOG.txt +3 -2
- Dockerfile +41 -41
- LICENSE +21 -21
- app.py +4 -110
- faiss_embeddings_2024/index.faiss +0 -3
- faiss_embeddings_2024/index.pkl +0 -3
- prompts.py +0 -123
- .env.example β reggpt/.env.example +0 -0
- {utils β reggpt}/__init__.py +0 -0
- reggpt/agent/agent.py +33 -0
- reggpt/api/__init__.py +0 -0
- server.py β reggpt/api/router.py +10 -79
- {configs β reggpt/chains}/__init__.py +5 -5
- llmChain.py β reggpt/chains/llmChain.py +99 -96
- {data β reggpt/configs}/__init__.py +5 -5
- reggpt/configs/api.py +28 -0
- config.py β reggpt/configs/config.py +35 -35
- {configs β reggpt/configs}/logger.py +39 -39
- reggpt/controller/__init__.py +0 -0
- qaPipeline.py β reggpt/controller/agent.py +73 -150
- reggpt/controller/router.py +62 -0
- reggpt/data/__init__.py +5 -0
- {data β reggpt/data}/splitted_texts.jsonl +0 -0
- llm.py β reggpt/llms/llm.py +46 -46
- reggpt/memory/__init__.py +0 -0
- conversationBufferWindowMemory.py β reggpt/memory/conversationBufferWindowMemory.py +133 -133
- reggpt/output_parsers/__init__.py +0 -0
- output_parser.py β reggpt/output_parsers/output_parser.py +0 -0
- reggpt/prompts/__init__.py +0 -0
- reggpt/prompts/document_combine.py +7 -0
- reggpt/prompts/general.py +25 -0
- reggpt/prompts/multi_query.py +20 -0
- reggpt/prompts/retrieval.py +33 -0
- reggpt/prompts/router.py +17 -0
- ensemble_retriever.py β reggpt/retriever/ensemble_retriever.py +228 -228
- multi_query_retriever.py β reggpt/retriever/multi_query_retriever.py +253 -253
- reggpt/routers/__init__.py +0 -0
- controller.py β reggpt/routers/controller.py +2 -2
- reggpt/routers/general.py +49 -0
- reggpt/routers/out_of_domain.py +31 -0
- reggpt/routers/qa.py +66 -0
- reggpt/routers/qaPipeline.py +45 -0
- reggpt/schemas/__init__.py +0 -0
- schema.py β reggpt/schemas/schema.py +0 -0
- reggpt/utils/__init__.py +0 -0
- retriever.py β reggpt/utils/retriever.py +136 -136
- {utils β reggpt/utils}/utils.py +40 -40
- reggpt/vectorstores/__init__.py +0 -0
.dockerignore
CHANGED
@@ -1,29 +1,29 @@
|
|
1 |
-
# Ignore node_modules
|
2 |
-
node_modules
|
3 |
-
|
4 |
-
# Ignore logs
|
5 |
-
logs
|
6 |
-
*.log
|
7 |
-
|
8 |
-
# Ignore temporary files
|
9 |
-
tmp
|
10 |
-
*.tmp
|
11 |
-
|
12 |
-
# Ignore build directories
|
13 |
-
dist
|
14 |
-
build
|
15 |
-
|
16 |
-
# Ignore environment variables
|
17 |
-
.env
|
18 |
-
|
19 |
-
# Ignore Docker files
|
20 |
-
Dockerfile
|
21 |
-
docker-compose.yml
|
22 |
-
|
23 |
-
# Ignore IDE specific files
|
24 |
-
.vscode
|
25 |
-
.idea
|
26 |
-
|
27 |
-
# Ignore OS generated files
|
28 |
-
.DS_Store
|
29 |
Thumbs.db
|
|
|
1 |
+
# Ignore node_modules
|
2 |
+
node_modules
|
3 |
+
|
4 |
+
# Ignore logs
|
5 |
+
logs
|
6 |
+
*.log
|
7 |
+
|
8 |
+
# Ignore temporary files
|
9 |
+
tmp
|
10 |
+
*.tmp
|
11 |
+
|
12 |
+
# Ignore build directories
|
13 |
+
dist
|
14 |
+
build
|
15 |
+
|
16 |
+
# Ignore environment variables
|
17 |
+
.env
|
18 |
+
|
19 |
+
# Ignore Docker files
|
20 |
+
Dockerfile
|
21 |
+
docker-compose.yml
|
22 |
+
|
23 |
+
# Ignore IDE specific files
|
24 |
+
.vscode
|
25 |
+
.idea
|
26 |
+
|
27 |
+
# Ignore OS generated files
|
28 |
+
.DS_Store
|
29 |
Thumbs.db
|
.gitignore
CHANGED
@@ -1,135 +1,137 @@
|
|
1 |
-
# Byte-compiled / optimized / DLL files
|
2 |
-
__pycache__/
|
3 |
-
*.py[cod]
|
4 |
-
*$py.class
|
5 |
-
|
6 |
-
# C extensions
|
7 |
-
*.so
|
8 |
-
|
9 |
-
# Distribution / packaging
|
10 |
-
.Python
|
11 |
-
build/
|
12 |
-
develop-eggs/
|
13 |
-
dist/
|
14 |
-
downloads/
|
15 |
-
eggs/
|
16 |
-
.eggs/
|
17 |
-
lib/
|
18 |
-
lib64/
|
19 |
-
parts/
|
20 |
-
sdist/
|
21 |
-
var/
|
22 |
-
wheels/
|
23 |
-
pip-wheel-metadata/
|
24 |
-
share/python-wheels/
|
25 |
-
*.egg-info/
|
26 |
-
.installed.cfg
|
27 |
-
*.egg
|
28 |
-
MANIFEST
|
29 |
-
|
30 |
-
# PyInstaller
|
31 |
-
# Usually these files are written by a python script from a template
|
32 |
-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
-
*.manifest
|
34 |
-
*.spec
|
35 |
-
|
36 |
-
# Installer logs
|
37 |
-
pip-log.txt
|
38 |
-
pip-delete-this-directory.txt
|
39 |
-
|
40 |
-
# Unit test / coverage reports
|
41 |
-
htmlcov/
|
42 |
-
.tox/
|
43 |
-
.nox/
|
44 |
-
.coverage
|
45 |
-
.coverage.*
|
46 |
-
.cache
|
47 |
-
nosetests.xml
|
48 |
-
coverage.xml
|
49 |
-
*.cover
|
50 |
-
*.py,cover
|
51 |
-
.hypothesis/
|
52 |
-
.pytest_cache/
|
53 |
-
|
54 |
-
# Translations
|
55 |
-
*.mo
|
56 |
-
*.pot
|
57 |
-
|
58 |
-
# Django stuff:
|
59 |
-
*.log
|
60 |
-
local_settings.py
|
61 |
-
db.sqlite3
|
62 |
-
db.sqlite3-journal
|
63 |
-
|
64 |
-
# Flask stuff:
|
65 |
-
instance/
|
66 |
-
.webassets-cache
|
67 |
-
|
68 |
-
# Scrapy stuff:
|
69 |
-
.scrapy
|
70 |
-
|
71 |
-
# Sphinx documentation
|
72 |
-
docs/_build/
|
73 |
-
|
74 |
-
# PyBuilder
|
75 |
-
target/
|
76 |
-
|
77 |
-
# Jupyter Notebook
|
78 |
-
.ipynb_checkpoints
|
79 |
-
|
80 |
-
# IPython
|
81 |
-
profile_default/
|
82 |
-
ipython_config.py
|
83 |
-
|
84 |
-
# pyenv
|
85 |
-
.python-version
|
86 |
-
|
87 |
-
# pipenv
|
88 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
-
# install all needed dependencies.
|
92 |
-
#Pipfile.lock
|
93 |
-
|
94 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
-
__pypackages__/
|
96 |
-
|
97 |
-
# Celery stuff
|
98 |
-
celerybeat-schedule
|
99 |
-
celerybeat.pid
|
100 |
-
|
101 |
-
# SageMath parsed files
|
102 |
-
*.sage.py
|
103 |
-
|
104 |
-
# Environments
|
105 |
-
.env
|
106 |
-
.venv
|
107 |
-
env/
|
108 |
-
venv/
|
109 |
-
ENV/
|
110 |
-
env.bak/
|
111 |
-
venv.bak/
|
112 |
-
|
113 |
-
# Spyder project settings
|
114 |
-
.spyderproject
|
115 |
-
.spyproject
|
116 |
-
|
117 |
-
# Rope project settings
|
118 |
-
.ropeproject
|
119 |
-
|
120 |
-
# mkdocs documentation
|
121 |
-
/site
|
122 |
-
|
123 |
-
# mypy
|
124 |
-
.mypy_cache/
|
125 |
-
.dmypy.json
|
126 |
-
dmypy.json
|
127 |
-
|
128 |
-
# Pyre type checker
|
129 |
-
.pyre/
|
130 |
-
|
131 |
-
# testing files generated
|
132 |
-
*.txt.json
|
133 |
-
|
134 |
-
*.ipynb
|
135 |
-
env
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# testing files generated
|
132 |
+
*.txt.json
|
133 |
+
|
134 |
+
*.ipynb
|
135 |
+
env
|
136 |
+
|
137 |
+
.reggpt/vectorstores/faiss_embeddings_2024/
|
CHANGELOG.txt
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
-
2023-11-30 pipeline with only document retrieval
|
2 |
-
2024-08-23 azure app serice , open ai 'gpt4o mini'
|
|
|
|
1 |
+
2023-11-30 pipeline with only document retrieval
|
2 |
+
2024-08-23 azure app serice , open ai 'gpt4o mini'
|
3 |
+
2024-10-16 Code Refactor
|
Dockerfile
CHANGED
@@ -1,42 +1,42 @@
|
|
1 |
-
# Step 1: Use Python 3.11.9 as required
|
2 |
-
FROM python:3.11.9
|
3 |
-
|
4 |
-
# Step 2: Set up environment variables and timezone configuration
|
5 |
-
ENV TZ=Asia/Colombo
|
6 |
-
RUN apt-get update && apt-get install -y libaio1 wget unzip tzdata \
|
7 |
-
&& ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
8 |
-
|
9 |
-
# Step 4: Add a user for running the app (after installations)
|
10 |
-
RUN useradd -m -u 1000 user
|
11 |
-
|
12 |
-
# Step 5: Create the /app directory and set ownership to the new user
|
13 |
-
RUN mkdir -p /app && chown -R user:user /app
|
14 |
-
|
15 |
-
# Step 6: Switch to non-root user after the directory has the right permissions
|
16 |
-
USER user
|
17 |
-
ENV PATH="/home/user/.local/bin:$PATH"
|
18 |
-
|
19 |
-
# Step 7: Set up the working directory for the app
|
20 |
-
WORKDIR /app
|
21 |
-
|
22 |
-
# Step 8: Copy the requirements file and install dependencies
|
23 |
-
COPY --chown=user ./requirements.txt requirements.txt
|
24 |
-
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
25 |
-
|
26 |
-
# Step 9: Install pipenv and handle pipenv environment
|
27 |
-
RUN pip install pipenv
|
28 |
-
COPY --chown=user . /app
|
29 |
-
RUN pipenv install
|
30 |
-
|
31 |
-
# Step 10: Expose the necessary port (7860 for Hugging Face Spaces)
|
32 |
-
EXPOSE 7860
|
33 |
-
|
34 |
-
# Step 11: Set environment variables for the app
|
35 |
-
ENV APP_HOST=0.0.0.0
|
36 |
-
ENV APP_PORT=7860
|
37 |
-
|
38 |
-
# Step 12: Create logs directory (if necessary)
|
39 |
-
RUN mkdir -p /app/logs
|
40 |
-
|
41 |
-
# Step 13: Run the app using Uvicorn, listening on port 7860
|
42 |
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
1 |
+
# Step 1: Use Python 3.11.9 as required
|
2 |
+
FROM python:3.11.9
|
3 |
+
|
4 |
+
# Step 2: Set up environment variables and timezone configuration
|
5 |
+
ENV TZ=Asia/Colombo
|
6 |
+
RUN apt-get update && apt-get install -y libaio1 wget unzip tzdata \
|
7 |
+
&& ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
8 |
+
|
9 |
+
# Step 4: Add a user for running the app (after installations)
|
10 |
+
RUN useradd -m -u 1000 user
|
11 |
+
|
12 |
+
# Step 5: Create the /app directory and set ownership to the new user
|
13 |
+
RUN mkdir -p /app && chown -R user:user /app
|
14 |
+
|
15 |
+
# Step 6: Switch to non-root user after the directory has the right permissions
|
16 |
+
USER user
|
17 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
18 |
+
|
19 |
+
# Step 7: Set up the working directory for the app
|
20 |
+
WORKDIR /app
|
21 |
+
|
22 |
+
# Step 8: Copy the requirements file and install dependencies
|
23 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
24 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
25 |
+
|
26 |
+
# Step 9: Install pipenv and handle pipenv environment
|
27 |
+
RUN pip install pipenv
|
28 |
+
COPY --chown=user . /app
|
29 |
+
RUN pipenv install
|
30 |
+
|
31 |
+
# Step 10: Expose the necessary port (7860 for Hugging Face Spaces)
|
32 |
+
EXPOSE 7860
|
33 |
+
|
34 |
+
# Step 11: Set environment variables for the app
|
35 |
+
ENV APP_HOST=0.0.0.0
|
36 |
+
ENV APP_PORT=7860
|
37 |
+
|
38 |
+
# Step 12: Create logs directory (if necessary)
|
39 |
+
RUN mkdir -p /app/logs
|
40 |
+
|
41 |
+
# Step 13: Run the app using Uvicorn, listening on port 7860
|
42 |
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
CHANGED
@@ -1,22 +1,22 @@
|
|
1 |
-
License
|
2 |
-
|
3 |
-
Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
4 |
-
All Rights Reserved
|
5 |
-
|
6 |
-
This source code is protected under international copyright law. All rights
|
7 |
-
reserved and protected by the copyright holders.
|
8 |
-
This file is confidential and only available to authorized individuals with the
|
9 |
-
permission of the copyright holders.
|
10 |
-
|
11 |
-
Permission is hereby granted, to {User}. for testing and development purposes.
|
12 |
-
|
13 |
-
The above copyright notice and this permission notice shall be included in all
|
14 |
-
copies or substantial portions of the Software.
|
15 |
-
|
16 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
SOFTWARE.
|
|
|
1 |
+
License
|
2 |
+
|
3 |
+
Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
4 |
+
All Rights Reserved
|
5 |
+
|
6 |
+
This source code is protected under international copyright law. All rights
|
7 |
+
reserved and protected by the copyright holders.
|
8 |
+
This file is confidential and only available to authorized individuals with the
|
9 |
+
permission of the copyright holders.
|
10 |
+
|
11 |
+
Permission is hereby granted, to {User}. for testing and development purposes.
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
SOFTWARE.
|
app.py
CHANGED
@@ -15,20 +15,16 @@
|
|
15 |
"""
|
16 |
|
17 |
import os
|
18 |
-
import time
|
19 |
-
import sys
|
20 |
import logging
|
21 |
-
import datetime
|
22 |
import uvicorn
|
23 |
from dotenv import load_dotenv
|
24 |
|
25 |
-
from fastapi import FastAPI
|
26 |
-
from fastapi import HTTPException, status
|
27 |
from fastapi.middleware.cors import CORSMiddleware
|
28 |
|
29 |
-
from
|
30 |
-
from controller import get_QA_Answers, get_avaliable_models
|
31 |
|
|
|
32 |
|
33 |
def filer():
|
34 |
return "logs/log"
|
@@ -55,110 +51,11 @@ load_dotenv()
|
|
55 |
host = os.environ.get("APP_HOST")
|
56 |
port = int(os.environ.get("APP_PORT"))
|
57 |
|
58 |
-
|
59 |
-
class ChatAPI:
|
60 |
-
|
61 |
-
def __init__(self):
|
62 |
-
self.router = APIRouter()
|
63 |
-
self.router.add_api_route("/api/v1/health", self.hello, methods=["GET"])
|
64 |
-
self.router.add_api_route("/api/v1/models", self.avaliable_models, methods=["GET"])
|
65 |
-
self.router.add_api_route(
|
66 |
-
"/api/v1/login", self.login, methods=["POST"], response_model=UserModel
|
67 |
-
)
|
68 |
-
self.router.add_api_route("/api/v1/chat", self.chat, methods=["POST"])
|
69 |
-
|
70 |
-
async def hello(self):
|
71 |
-
return "Hello there!"
|
72 |
-
|
73 |
-
async def avaliable_models(self):
|
74 |
-
logger.info("getting avaliable models")
|
75 |
-
models = get_avaliable_models()
|
76 |
-
|
77 |
-
if not models:
|
78 |
-
logger.exception("models not found")
|
79 |
-
raise HTTPException(
|
80 |
-
status_code=status.HTTP_404_NOT_FOUND, detail="models not found"
|
81 |
-
)
|
82 |
-
|
83 |
-
return models
|
84 |
-
|
85 |
-
async def login(self, login_request: LoginRequest):
|
86 |
-
logger.info(f"username password: {login_request} ")
|
87 |
-
# Dummy user data for demonstration (normally, you'd use a database)
|
88 |
-
dummy_users_db = {
|
89 |
-
"john_doe": {
|
90 |
-
"userId": 1,
|
91 |
-
"firstName": "John",
|
92 |
-
"lastName": "Doe",
|
93 |
-
"userName": "john_doe",
|
94 |
-
"password": "password", # Normally, passwords would be hashed and stored securely
|
95 |
-
"token": "dummy_token_123", # In a real scenario, this would be a JWT or another kind of token
|
96 |
-
}
|
97 |
-
}
|
98 |
-
# Fetch user by username
|
99 |
-
# user = dummy_users_db.get(login_request.username)
|
100 |
-
user = dummy_users_db.get("john_doe")
|
101 |
-
# Validate user credentials
|
102 |
-
if not user or user["password"] != login_request.password:
|
103 |
-
raise HTTPException(status_code=401, detail="Invalid username or password")
|
104 |
-
|
105 |
-
# Return the user model without the password
|
106 |
-
return UserModel(
|
107 |
-
userId=user["userId"],
|
108 |
-
firstName=user["firstName"],
|
109 |
-
lastName=user["lastName"],
|
110 |
-
userName=user["userName"],
|
111 |
-
token=user["token"],
|
112 |
-
)
|
113 |
-
|
114 |
-
async def chat(
|
115 |
-
self, userQuery: UserQuery
|
116 |
-
): #:UserQuery):# -> ResponseModel: #chat: QueryModel): # -> ResponseModel:
|
117 |
-
"""Makes query to doc store via Langchain pipeline.
|
118 |
-
|
119 |
-
:param chat.: question, model, dataset location, history of the chat.
|
120 |
-
:type chat: QueryModel
|
121 |
-
"""
|
122 |
-
logger.info(f"userQuery: {userQuery} ")
|
123 |
-
|
124 |
-
try:
|
125 |
-
start = time.time()
|
126 |
-
res = get_QA_Answers(userQuery)
|
127 |
-
logger.info(
|
128 |
-
f"-------------------------- answer: {res} -------------------------- "
|
129 |
-
)
|
130 |
-
# return res
|
131 |
-
end = time.time()
|
132 |
-
logger.info(
|
133 |
-
f"-------------------------- Server process (took {round(end - start, 2)} s.) \n: {res}"
|
134 |
-
)
|
135 |
-
print(
|
136 |
-
f" \n -------------------------- Server process (took {round(end - start, 2)} s.) ------------------------- \n"
|
137 |
-
)
|
138 |
-
# return res
|
139 |
-
return {
|
140 |
-
"user_input": userQuery.user_question,
|
141 |
-
"bot_response": res["bot_response"],
|
142 |
-
"format_data": res["format_data"],
|
143 |
-
}
|
144 |
-
|
145 |
-
|
146 |
-
except HTTPException as e:
|
147 |
-
logger.exception(e)
|
148 |
-
raise e
|
149 |
-
|
150 |
-
except Exception as e:
|
151 |
-
logger.exception(e)
|
152 |
-
raise HTTPException(status_code=400, detail=f"Error : {e}")
|
153 |
-
|
154 |
-
|
155 |
# initialize API
|
156 |
-
app = FastAPI(title=
|
157 |
api = ChatAPI()
|
158 |
app.include_router(api.router)
|
159 |
|
160 |
-
# origins = ['http://localhost:8000','http://192.168.10.100:8000']
|
161 |
-
|
162 |
app.add_middleware(
|
163 |
CORSMiddleware,
|
164 |
allow_origins=["*"], # origins,
|
@@ -169,9 +66,6 @@ app.add_middleware(
|
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
|
172 |
-
host = "0.0.0.0"
|
173 |
-
port = 8000
|
174 |
-
|
175 |
# config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
|
176 |
config = uvicorn.Config("server:app", host=host, port=port)
|
177 |
server = uvicorn.Server(config)
|
|
|
15 |
"""
|
16 |
|
17 |
import os
|
|
|
|
|
18 |
import logging
|
|
|
19 |
import uvicorn
|
20 |
from dotenv import load_dotenv
|
21 |
|
22 |
+
from fastapi import FastAPI
|
|
|
23 |
from fastapi.middleware.cors import CORSMiddleware
|
24 |
|
25 |
+
from reggpt.configs.api import API_TITLE, API_VERSION, API_DESCRIPTION
|
|
|
26 |
|
27 |
+
from reggpt.api.router import ChatAPI
|
28 |
|
29 |
def filer():
|
30 |
return "logs/log"
|
|
|
51 |
host = os.environ.get("APP_HOST")
|
52 |
port = int(os.environ.get("APP_PORT"))
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# initialize API
|
55 |
+
app = FastAPI(title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION)
|
56 |
api = ChatAPI()
|
57 |
app.include_router(api.router)
|
58 |
|
|
|
|
|
59 |
app.add_middleware(
|
60 |
CORSMiddleware,
|
61 |
allow_origins=["*"], # origins,
|
|
|
66 |
|
67 |
if __name__ == "__main__":
|
68 |
|
|
|
|
|
|
|
69 |
# config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
|
70 |
config = uvicorn.Config("server:app", host=host, port=port)
|
71 |
server = uvicorn.Server(config)
|
faiss_embeddings_2024/index.faiss
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:a3087f02172887cbf8bfb0fb3b371843548619c2a2873fdf4629339e2031a2c1
|
3 |
-
size 10895405
|
|
|
|
|
|
|
|
faiss_embeddings_2024/index.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:3da917c3758e2bfe0aedbd050199f4c80ec372d5b0349b49126b790fb1757db9
|
3 |
-
size 3935715
|
|
|
|
|
|
|
|
prompts.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 19/03/2024
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
from langchain.prompts import PromptTemplate
|
18 |
-
|
19 |
-
# multi query prompt
|
20 |
-
MULTY_QUERY_PROMPT = PromptTemplate(
|
21 |
-
input_variables=["question"],
|
22 |
-
template="""You are an AI language model assistant. Your task is to generate three
|
23 |
-
different versions of the given user question to retrieve relevant documents from a vector
|
24 |
-
database. By generating multiple perspectives on the user question, your goal is to help
|
25 |
-
the user overcome some of the limitations of the distance-based similarity search.
|
26 |
-
Provide these alternative questions separated by newlines.
|
27 |
-
|
28 |
-
Dont add anything extra before or after to the 3 questions. Just give 3 lines with 3 questions.
|
29 |
-
Just provide 3 lines having 3 questions only.
|
30 |
-
Answer should be in following format.
|
31 |
-
|
32 |
-
1. alternative question 1
|
33 |
-
2. alternative question 2
|
34 |
-
3. alternative question 3
|
35 |
-
|
36 |
-
Original question: {question}""",
|
37 |
-
)
|
38 |
-
|
39 |
-
#retrieval prompt
|
40 |
-
B_INST, E_INST = "[INST]", "[/INST]"
|
41 |
-
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
42 |
-
|
43 |
-
retrieval_qa_template = (
|
44 |
-
"""<<SYS>>
|
45 |
-
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
|
46 |
-
|
47 |
-
please answer the question based on the chat history provided below. Answer should be short and simple as possible and on to the point.
|
48 |
-
<chat history>: {chat_history}
|
49 |
-
|
50 |
-
If the question is related to welcomes and greetings answer accordingly.
|
51 |
-
|
52 |
-
Else If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
|
53 |
-
please answer the question based only on the information provided in following central bank documents published in various years.
|
54 |
-
The published year is mentioned as the metadata 'year' of each source document.
|
55 |
-
Please notice that content of a one document of a past year can updated by a new document from a recent year.
|
56 |
-
Always try to answer with latest information and mention the year which information extracted.
|
57 |
-
If you dont know the answer say you dont know, dont try to makeup answers. Dont add any extra details that is not mentioned in the context.
|
58 |
-
|
59 |
-
<</SYS>>
|
60 |
-
|
61 |
-
[INST]
|
62 |
-
<DOCUMENTS>
|
63 |
-
{context}
|
64 |
-
</DOCUMENTS>
|
65 |
-
|
66 |
-
Question : {question}[/INST]"""
|
67 |
-
)
|
68 |
-
|
69 |
-
|
70 |
-
retrieval_qa_chain_prompt = PromptTemplate(
|
71 |
-
input_variables=["question", "context", "chat_history"],
|
72 |
-
template=retrieval_qa_template
|
73 |
-
)
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
#document combine prompt
|
78 |
-
document_combine_prompt = PromptTemplate(
|
79 |
-
input_variables=["source","year", "page","page_content"],
|
80 |
-
template=
|
81 |
-
"""<doc> source: {source}, year: {year}, page: {page}, page content: {page_content} </doc>"""
|
82 |
-
)
|
83 |
-
|
84 |
-
|
85 |
-
router_template_Mixtral_V0= """
|
86 |
-
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
|
87 |
-
|
88 |
-
If a user asks a question you have to classify it to following 3 types Relevant, Greeting, Other.
|
89 |
-
|
90 |
-
"Relevantβ: If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
|
91 |
-
"Greetingβ: If the question is a greeting like good morning, hi my name is., thank you or General Question ask about the AI assistance of a company boardpac.
|
92 |
-
"Otherβ: If the question is not related to research papers.
|
93 |
-
|
94 |
-
Give the correct name of question type. If you are not sure return "Not Sure" instead.
|
95 |
-
|
96 |
-
Question : {question}
|
97 |
-
"""
|
98 |
-
router_prompt=PromptTemplate.from_template(router_template_Mixtral_V0)
|
99 |
-
|
100 |
-
|
101 |
-
general_qa_template_Mixtral_V0= """
|
102 |
-
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
|
103 |
-
you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .
|
104 |
-
|
105 |
-
Is the provided question below a greeting? First, evaluate whether the input resembles a typical greeting or not.
|
106 |
-
|
107 |
-
Greetings are used to say 'hello' and 'how are you?' and to say 'goodbye' and 'nice speaking with you.' and 'hi, I'm (user's name).'
|
108 |
-
Greetings are words used when we want to introduce ourselves to others and when we want to find out how someone is feeling.
|
109 |
-
|
110 |
-
You can only reply to the user's greetings.
|
111 |
-
If the question is a greeting, reply accordingly as the AI assistant of company boardpac.
|
112 |
-
If the question is not related to greetings and research papers, say that it is out of your domain.
|
113 |
-
If the question is not clear enough, ask for more details and don't try to make up answers.
|
114 |
-
|
115 |
-
Answer should be polite, short, and simple.
|
116 |
-
|
117 |
-
Additionally, it's important to note that this AI assistant has access to an internal collection of research papers, and answers can be provided using the information available in those CBSL Dataset.
|
118 |
-
|
119 |
-
Question: {question}
|
120 |
-
"""
|
121 |
-
|
122 |
-
general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template_Mixtral_V0)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.env.example β reggpt/.env.example
RENAMED
File without changes
|
{utils β reggpt}/__init__.py
RENAMED
File without changes
|
reggpt/agent/agent.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
logger = logging.getLogger(__name__)
|
3 |
+
from fastapi import HTTPException
|
4 |
+
import time
|
5 |
+
from reggpt.routers.qaPipeline import run_router_chain, chain_selector
|
6 |
+
|
7 |
+
def run_agent(query):
|
8 |
+
try:
|
9 |
+
logger.info(f"run_agent : Question: {query}")
|
10 |
+
print(f"---------------- run_agent : Question: {query} ----------------")
|
11 |
+
# Get the answer from the chain
|
12 |
+
start = time.time()
|
13 |
+
chain_type = run_router_chain(query)
|
14 |
+
res = chain_selector(chain_type,query)
|
15 |
+
end = time.time()
|
16 |
+
|
17 |
+
# log the result
|
18 |
+
logger.error(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
|
19 |
+
print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")
|
20 |
+
|
21 |
+
return res
|
22 |
+
|
23 |
+
except HTTPException as e:
|
24 |
+
print('HTTPException')
|
25 |
+
print(e)
|
26 |
+
logger.exception(e)
|
27 |
+
raise e
|
28 |
+
|
29 |
+
except Exception as e:
|
30 |
+
print('Exception')
|
31 |
+
print(e)
|
32 |
+
logger.exception(e)
|
33 |
+
raise e
|
reggpt/api/__init__.py
ADDED
File without changes
|
server.py β reggpt/api/router.py
RENAMED
@@ -1,71 +1,28 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2023-2025) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 15/10/2024
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
import os
|
18 |
import time
|
19 |
-
import sys
|
20 |
-
import logging
|
21 |
-
import datetime
|
22 |
-
import uvicorn
|
23 |
-
from dotenv import load_dotenv
|
24 |
-
|
25 |
-
from fastapi import FastAPI, APIRouter, HTTPException, status
|
26 |
-
from fastapi import HTTPException, status
|
27 |
-
from fastapi.middleware.cors import CORSMiddleware
|
28 |
-
|
29 |
-
from schema import UserQuery, ResponseModel, Document, LoginRequest, UserModel
|
30 |
-
from controller import get_QA_Answers, get_avaliable_models
|
31 |
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
# today = datetime.datetime.today()
|
36 |
-
# log_filename = f"logs/{today.year}-{today.month:02d}-{today.day:02d}.log"
|
37 |
-
# return log_filename
|
38 |
-
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
file_handler.setLevel(logging.INFO)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
format="%(asctime)s %(levelname)s (%(name)s) : %(message)s",
|
47 |
-
datefmt="%Y-%m-%d %H:%M:%S",
|
48 |
-
handlers=[file_handler],
|
49 |
-
force=True,
|
50 |
-
)
|
51 |
|
52 |
logger = logging.getLogger(__name__)
|
53 |
|
54 |
-
load_dotenv()
|
55 |
-
host = os.environ.get("APP_HOST")
|
56 |
-
port = int(os.environ.get("APP_PORT"))
|
57 |
-
|
58 |
|
59 |
class ChatAPI:
|
60 |
|
61 |
def __init__(self):
|
62 |
self.router = APIRouter()
|
63 |
-
self.router.add_api_route(
|
64 |
-
self.router.add_api_route(
|
65 |
self.router.add_api_route(
|
66 |
-
|
67 |
)
|
68 |
-
self.router.add_api_route(
|
69 |
|
70 |
async def hello(self):
|
71 |
return "Hello there!"
|
@@ -151,29 +108,3 @@ class ChatAPI:
|
|
151 |
logger.exception(e)
|
152 |
raise HTTPException(status_code=400, detail=f"Error : {e}")
|
153 |
|
154 |
-
|
155 |
-
# initialize API
|
156 |
-
app = FastAPI(title="Boardpac chatbot API")
|
157 |
-
api = ChatAPI()
|
158 |
-
app.include_router(api.router)
|
159 |
-
|
160 |
-
# origins = ['http://localhost:8000','http://192.168.10.100:8000']
|
161 |
-
|
162 |
-
app.add_middleware(
|
163 |
-
CORSMiddleware,
|
164 |
-
allow_origins=["*"], # origins,
|
165 |
-
allow_credentials=True,
|
166 |
-
allow_methods=["*"],
|
167 |
-
allow_headers=["*"],
|
168 |
-
)
|
169 |
-
|
170 |
-
if __name__ == "__main__":
|
171 |
-
|
172 |
-
host = "0.0.0.0"
|
173 |
-
port = 8000
|
174 |
-
|
175 |
-
# config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
|
176 |
-
config = uvicorn.Config("server:app", host=host, port=port)
|
177 |
-
server = uvicorn.Server(config)
|
178 |
-
server.run()
|
179 |
-
# uvicorn.run(app)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
+
from fastapi import APIRouter, HTTPException, status
|
5 |
+
from fastapi import HTTPException, status
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
from reggpt.schemas.schema import UserQuery, LoginRequest, UserModel
|
8 |
+
from reggpt.routers.controller import get_QA_Answers, get_avaliable_models
|
|
|
9 |
|
10 |
+
from reggpt.configs.api import API_ENDPOINT_LOGIN,API_ENDPOINT_CHAT, API_ENDPOINT_HEALTH, API_ENDPOINT_MODEL
|
11 |
+
import logging
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
|
|
|
|
|
|
15 |
|
16 |
class ChatAPI:
|
17 |
|
18 |
def __init__(self):
|
19 |
self.router = APIRouter()
|
20 |
+
self.router.add_api_route(API_ENDPOINT_HEALTH, self.hello, methods=["GET"])
|
21 |
+
self.router.add_api_route(API_ENDPOINT_MODEL, self.avaliable_models, methods=["GET"])
|
22 |
self.router.add_api_route(
|
23 |
+
API_ENDPOINT_LOGIN, self.login, methods=["POST"], response_model=UserModel
|
24 |
)
|
25 |
+
self.router.add_api_route(API_ENDPOINT_CHAT, self.chat, methods=["POST"])
|
26 |
|
27 |
async def hello(self):
|
28 |
return "Hello there!"
|
|
|
108 |
logger.exception(e)
|
109 |
raise HTTPException(status_code=400, detail=f"Error : {e}")
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{configs β reggpt/chains}/__init__.py
RENAMED
@@ -1,5 +1,5 @@
|
|
1 |
-
# import os
|
2 |
-
# import sys
|
3 |
-
|
4 |
-
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
-
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
1 |
+
# import os
|
2 |
+
# import sys
|
3 |
+
|
4 |
+
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
llmChain.py β reggpt/chains/llmChain.py
RENAMED
@@ -1,96 +1,99 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 18/03/2024
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
import os
|
18 |
-
import logging
|
19 |
-
logger = logging.getLogger(__name__)
|
20 |
-
from dotenv import load_dotenv
|
21 |
-
|
22 |
-
load_dotenv()
|
23 |
-
|
24 |
-
verbose = os.environ.get('VERBOSE')
|
25 |
-
|
26 |
-
from llm import get_model
|
27 |
-
from langchain.chains import ConversationalRetrievalChain
|
28 |
-
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
|
29 |
-
|
30 |
-
# from langchain.prompts import PromptTemplate
|
31 |
-
from langchain.chains import LLMChain
|
32 |
-
|
33 |
-
from prompts import
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import logging
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
|
22 |
+
load_dotenv()
|
23 |
+
|
24 |
+
verbose = os.environ.get('VERBOSE')
|
25 |
+
|
26 |
+
from reggpt.llms.llm import get_model
|
27 |
+
from langchain.chains import ConversationalRetrievalChain
|
28 |
+
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
|
29 |
+
|
30 |
+
# from langchain.prompts import PromptTemplate
|
31 |
+
from langchain.chains import LLMChain
|
32 |
+
from reggpt.prompts import document_combine as document_combine_prompt
|
33 |
+
from reggpt.prompts import retrieval as retrieval_qa_chain_prompt
|
34 |
+
from reggpt.prompts import general as general_qa_chain_prompt
|
35 |
+
from reggpt.prompts import router as router_prompt
|
36 |
+
|
37 |
+
|
38 |
+
def get_qa_chain(model_type, retriever):
    """Build a ConversationalRetrievalChain using the "stuff" combine strategy.

    Args:
        model_type: provider key understood by ``get_model``.
        retriever: document retriever that supplies source chunks.

    Returns:
        The configured ConversationalRetrievalChain (returns source documents).

    Raises:
        Exception: re-raised after logging if chain construction fails.
    """
    logger.info("creating qa_chain")

    try:
        llm = get_model(model_type)

        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            get_chat_history=lambda history: history,
            combine_docs_chain_kwargs={
                "prompt": retrieval_qa_chain_prompt,
                "document_prompt": document_combine_prompt,
            },
            verbose=True,
        )

        logger.info("qa_chain created")
        return chain

    except Exception as err:
        logger.exception(f"Error : {err}")
        raise
|
66 |
+
|
67 |
+
|
68 |
+
def get_general_qa_chain(model_type):
    """Build an LLMChain that answers general (non-retrieval) questions.

    Args:
        model_type: provider key understood by ``get_model``.

    Returns:
        The configured LLMChain using the general-QA prompt.

    Raises:
        Exception: re-raised after logging if chain construction fails.
    """
    logger.info("creating general_qa_chain")

    try:
        llm = get_model(model_type)
        chain = LLMChain(llm=llm, prompt=general_qa_chain_prompt)

        logger.info("general_qa_chain created")
        return chain

    except Exception as err:
        logger.exception(f"Error : {err}")
        raise
|
82 |
+
|
83 |
+
|
84 |
+
def get_router_chain(model_type):
    """Build an LLMChain that classifies queries into chain-type labels.

    Args:
        model_type: provider key understood by ``get_model``.

    Returns:
        The configured LLMChain using the router prompt.

    Raises:
        Exception: re-raised after logging if chain construction fails.
    """
    logger.info("creating router_chain")

    try:
        llm = get_model(model_type)
        chain = LLMChain(llm=llm, prompt=router_prompt)

        logger.info("router_chain created")
        return chain

    except Exception as err:
        logger.exception(f"Error : {err}")
        raise
|
98 |
+
|
99 |
+
|
{data β reggpt/configs}/__init__.py
RENAMED
@@ -1,5 +1,5 @@
|
|
1 |
-
# import os
|
2 |
-
# import sys
|
3 |
-
|
4 |
-
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
-
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
1 |
+
# import os
|
2 |
+
# import sys
|
3 |
+
|
4 |
+
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
reggpt/configs/api.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Authors : Hammaad Rizwan, Theekshana Samaradiwakara
|
11 |
+
* Description : API Configurations
|
12 |
+
* CreatedDate : 14/10/2024
|
13 |
+
* LastModifiedDate : 15/10/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
API_TITLE = "RegGPT Back End v1"
|
17 |
+
API_VERSION = "0.1.0"
|
18 |
+
API_DESCRIPTION = "API_DESC"
|
19 |
+
|
20 |
+
API_ENDPOINT_PREFIX = "/api/v1"
|
21 |
+
API_DOCS_URL = f"{API_ENDPOINT_PREFIX}/docs"
|
22 |
+
API_REDOC_URL = f"{API_ENDPOINT_PREFIX}/redoc"
|
23 |
+
API_OPENAPI_URL = f"{API_ENDPOINT_PREFIX}/openapi.json"
|
24 |
+
|
25 |
+
API_ENDPOINT_HEALTH = f"{API_ENDPOINT_PREFIX}/health"
|
26 |
+
API_ENDPOINT_CHAT = f"{API_ENDPOINT_PREFIX}/chat"
|
27 |
+
API_ENDPOINT_MODEL = f"{API_ENDPOINT_PREFIX}/models"
|
28 |
+
API_ENDPOINT_LOGIN = f"{API_ENDPOINT_PREFIX}/login"
|
config.py β reggpt/configs/config.py
RENAMED
@@ -1,36 +1,36 @@
|
|
1 |
-
AVALIABLE_MODELS=[
|
2 |
-
{
|
3 |
-
"id":"gpt-4o-mini",
|
4 |
-
"model_name":"openai/gpt-4o-mini",
|
5 |
-
"description":"gpt-4o-mini model from openai"
|
6 |
-
}
|
7 |
-
]
|
8 |
-
|
9 |
-
MODELS={
|
10 |
-
"DEFAULT":"openai",
|
11 |
-
"gpt-4o-mini":"openai",
|
12 |
-
|
13 |
-
}
|
14 |
-
|
15 |
-
DATASETS={
|
16 |
-
"DEFAULT":"faiss",
|
17 |
-
"a":"A",
|
18 |
-
"b":"B",
|
19 |
-
"c":"C"
|
20 |
-
|
21 |
-
}
|
22 |
-
|
23 |
-
MEMORY_WINDOW_K = 1
|
24 |
-
|
25 |
-
QA_MODEL_TYPE = "openai"
|
26 |
-
GENERAL_QA_MODEL_TYPE = "openai"
|
27 |
-
ROUTER_MODEL_TYPE = "openai"
|
28 |
-
Multi_Query_MODEL_TYPE = "openai"
|
29 |
-
|
30 |
-
|
31 |
-
ANSWER_TYPES = [
|
32 |
-
"relevant",
|
33 |
-
"greeting",
|
34 |
-
"other",
|
35 |
-
"not sure",
|
36 |
]
|
|
|
1 |
+
# Models exposed to API clients (id / provider-qualified name / description).
AVALIABLE_MODELS = [
    {
        "id": "gpt-4o-mini",
        "model_name": "openai/gpt-4o-mini",
        "description": "gpt-4o-mini model from openai",
    }
]

# Maps a model id to its backing provider; "DEFAULT" is the fallback key.
MODELS = {
    "DEFAULT": "openai",
    "gpt-4o-mini": "openai",
}

# Maps a dataset id to its backing store; "DEFAULT" is the fallback key.
DATASETS = {
    "DEFAULT": "faiss",
    "a": "A",
    "b": "B",
    "c": "C",
}

# Number of past exchanges retained by the conversation window memory.
MEMORY_WINDOW_K = 1

# Provider selection for each chain type.
QA_MODEL_TYPE = "openai"
GENERAL_QA_MODEL_TYPE = "openai"
ROUTER_MODEL_TYPE = "openai"
Multi_Query_MODEL_TYPE = "openai"

# Labels the router chain may emit when classifying a query.
ANSWER_TYPES = [
    "relevant",
    "greeting",
    "other",
    "not sure",
]
|
{configs β reggpt/configs}/logger.py
RENAMED
@@ -1,40 +1,40 @@
|
|
1 |
-
import logging
|
2 |
-
import time
|
3 |
-
# from functools import wraps
|
4 |
-
|
5 |
-
logger = logging.getLogger(__name__)
|
6 |
-
|
7 |
-
stream_handler = logging.StreamHandler()
|
8 |
-
log_filename = "output.log"
|
9 |
-
file_handler = logging.FileHandler(filename=log_filename)
|
10 |
-
handlers = [stream_handler, file_handler]
|
11 |
-
|
12 |
-
|
13 |
-
class TimeFilter(logging.Filter):
|
14 |
-
def filter(self, record):
|
15 |
-
return "Running" in record.getMessage()
|
16 |
-
|
17 |
-
|
18 |
-
logger.addFilter(TimeFilter())
|
19 |
-
|
20 |
-
# Configure the logging module
|
21 |
-
logging.basicConfig(
|
22 |
-
level=logging.INFO,
|
23 |
-
format="%(name)s %(asctime)s - %(levelname)s - %(message)s",
|
24 |
-
handlers=handlers,
|
25 |
-
)
|
26 |
-
|
27 |
-
|
28 |
-
def time_logger(func):
|
29 |
-
"""Decorator function to log time taken by any function."""
|
30 |
-
|
31 |
-
# @wraps(func)
|
32 |
-
def wrapper(*args, **kwargs):
|
33 |
-
start_time = time.time() # Start time before function execution
|
34 |
-
result = func(*args, **kwargs) # Function execution
|
35 |
-
end_time = time.time() # End time after function execution
|
36 |
-
execution_time = end_time - start_time # Calculate execution time
|
37 |
-
logger.info(f"Running {func.__name__}: --- {execution_time} seconds ---")
|
38 |
-
return result
|
39 |
-
|
40 |
return wrapper
|
|
|
1 |
+
import logging
import time
from functools import wraps

logger = logging.getLogger(__name__)

# Log both to the console and to a file.
stream_handler = logging.StreamHandler()
log_filename = "output.log"
file_handler = logging.FileHandler(filename=log_filename)
handlers = [stream_handler, file_handler]


class TimeFilter(logging.Filter):
    """Only pass records whose message mentions 'Running' (the timing logs)."""

    def filter(self, record):
        return "Running" in record.getMessage()


# Restrict this module's logger to timing messages emitted by ``time_logger``.
logger.addFilter(TimeFilter())

# Configure the logging module
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s %(asctime)s - %(levelname)s - %(message)s",
    handlers=handlers,
)


def time_logger(func):
    """Decorator that logs the wall-clock execution time of *func*.

    Returns:
        A wrapper that calls *func*, logs the elapsed seconds, and returns
        *func*'s result unchanged.
    """

    # Fix: @wraps was commented out, so wrapped functions lost their
    # __name__/__doc__ and stacked decorators logged the wrong name.
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Start time before function execution
        result = func(*args, **kwargs)  # Function execution
        end_time = time.time()  # End time after function execution
        execution_time = end_time - start_time  # Calculate execution time
        logger.info(f"Running {func.__name__}: --- {execution_time} seconds ---")
        return result

    return wrapper
|
reggpt/controller/__init__.py
ADDED
File without changes
|
qaPipeline.py β reggpt/controller/agent.py
RENAMED
@@ -1,150 +1,73 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
return res
|
76 |
-
|
77 |
-
except HTTPException as e:
|
78 |
-
print('HTTPException eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee')
|
79 |
-
print(e)
|
80 |
-
logger.exception(e)
|
81 |
-
raise e
|
82 |
-
|
83 |
-
except Exception as e:
|
84 |
-
print('Exception eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee')
|
85 |
-
print(e)
|
86 |
-
logger.exception(e)
|
87 |
-
raise e
|
88 |
-
|
89 |
-
|
90 |
-
def run_router_chain(query):
|
91 |
-
try:
|
92 |
-
logger.info(f"run_router_chain : Question: {query}")
|
93 |
-
# Get the answer from the chain
|
94 |
-
start = time.time()
|
95 |
-
chain_type = router_chain.invoke(query)['text']
|
96 |
-
end = time.time()
|
97 |
-
|
98 |
-
# log the result
|
99 |
-
logger.info(f"Answer (took {round(end - start, 2)} s.) chain_type: {chain_type}")
|
100 |
-
|
101 |
-
return chain_type
|
102 |
-
|
103 |
-
except Exception as e:
|
104 |
-
logger.exception(e)
|
105 |
-
raise e
|
106 |
-
|
107 |
-
|
108 |
-
def run_qa_chain(query):
|
109 |
-
try:
|
110 |
-
logger.info(f"run_qa_chain : Question: {query}")
|
111 |
-
# Get the answer from the chain
|
112 |
-
start = time.time()
|
113 |
-
# res = qa_chain(query)
|
114 |
-
res = qa_chain.invoke({"question": query, "chat_history":""})
|
115 |
-
# res = response
|
116 |
-
# answer, docs = res['result'],res['source_documents']
|
117 |
-
end = time.time()
|
118 |
-
|
119 |
-
# log the result
|
120 |
-
logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
|
121 |
-
|
122 |
-
return qa_chain_output_parser(res)
|
123 |
-
|
124 |
-
except Exception as e:
|
125 |
-
logger.exception(e)
|
126 |
-
raise e
|
127 |
-
|
128 |
-
|
129 |
-
def run_general_qa_chain(query):
|
130 |
-
try:
|
131 |
-
logger.info(f"run_general_qa_chain : Question: {query}")
|
132 |
-
|
133 |
-
# Get the answer from the chain
|
134 |
-
start = time.time()
|
135 |
-
res = general_qa_chain.invoke(query)
|
136 |
-
end = time.time()
|
137 |
-
|
138 |
-
# log the result
|
139 |
-
|
140 |
-
logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
|
141 |
-
|
142 |
-
return general_qa_chain_output_parser(res)
|
143 |
-
|
144 |
-
except Exception as e:
|
145 |
-
logger.exception(e)
|
146 |
-
raise e
|
147 |
-
|
148 |
-
|
149 |
-
def run_out_of_domain_chain(query):
|
150 |
-
return out_of_domain_chain_parser(query)
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from fastapi import HTTPException
|
7 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
8 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
9 |
+
|
10 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
11 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
verbose = os.environ.get('VERBOSE')
|
15 |
+
|
16 |
+
qa_model_type=QA_MODEL_TYPE
|
17 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
18 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
19 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
20 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
21 |
+
|
22 |
+
# retriever=load_faiss_retriever()
|
23 |
+
retriever=load_ensemble_retriever()
|
24 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
25 |
+
logger.info("retriever loaded:")
|
26 |
+
|
27 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
28 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
29 |
+
router_chain= get_router_chain(router_model_type)
|
30 |
+
|
31 |
+
|
32 |
+
def chain_selector(chain_type, query):
    """Dispatch *query* to the chain matching the router's *chain_type* label.

    Args:
        chain_type: label emitted by the router chain (case/space tolerant).
        query: the user's question string.

    Returns:
        The selected chain's parsed response.

    Raises:
        ValueError: if *chain_type* matches no known label.
    """
    # NOTE(review): run_general_qa_chain / run_out_of_domain_chain /
    # run_qa_chain do not appear in this module's imports — confirm they are
    # brought into scope elsewhere (e.g. reggpt.routers.*).
    chain_type = chain_type.lower().strip()
    logger.info(f"chain_selector : chain_type: {chain_type} Question: {query}")

    if "greeting" in chain_type:
        return run_general_qa_chain(query)
    if "other" in chain_type:
        return run_out_of_domain_chain(query)
    if "relevant" in chain_type or "not sure" in chain_type:
        return run_qa_chain(query)

    raise ValueError(
        f"Received invalid type '{chain_type}'"
    )
|
45 |
+
|
46 |
+
def run_agent(query):
    """Route *query* through the router chain, then dispatch to the selected chain.

    Args:
        query: the user's question string.

    Returns:
        The parsed response produced by the selected chain.

    Raises:
        HTTPException: propagated unchanged after logging.
        Exception: any other failure, logged and re-raised.
    """
    try:
        logger.info(f"run_agent : Question: {query}")
        print(f"---------------- run_agent : Question: {query} ----------------")
        # Classify the query first, then run the matching chain.
        start = time.time()
        chain_type = run_router_chain(query)
        res = chain_selector(chain_type, query)
        end = time.time()

        # Fix: a successful answer is a normal event — log at INFO, not ERROR.
        logger.info(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
        print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")

        return res

    except HTTPException as e:
        print('HTTPException')
        print(e)
        logger.exception(e)
        # Bare raise preserves the original traceback.
        raise

    except Exception as e:
        print('Exception')
        print(e)
        logger.exception(e)
        raise
|
73 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reggpt/controller/router.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import logging
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
from dotenv import load_dotenv
|
22 |
+
from fastapi import HTTPException
|
23 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
24 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
+
|
26 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
27 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
28 |
+
load_dotenv()
|
29 |
+
|
30 |
+
verbose = os.environ.get('VERBOSE')
|
31 |
+
|
32 |
+
qa_model_type=QA_MODEL_TYPE
|
33 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
34 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
35 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
36 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
37 |
+
|
38 |
+
# retriever=load_faiss_retriever()
|
39 |
+
retriever=load_ensemble_retriever()
|
40 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
41 |
+
logger.info("retriever loaded:")
|
42 |
+
|
43 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
44 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
45 |
+
router_chain= get_router_chain(router_model_type)
|
46 |
+
|
47 |
+
def run_router_chain(query):
    """Classify *query* with the router chain and return its chain-type label.

    Args:
        query: the user's question string.

    Returns:
        The router chain's 'text' output (a chain-type label string).

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        logger.info(f"run_router_chain : Question: {query}")
        # Get the answer from the chain
        started = time.time()
        chain_type = router_chain.invoke(query)['text']
        finished = time.time()

        # log the result
        logger.info(f"Answer (took {round(finished - started, 2)} s.) chain_type: {chain_type}")

        return chain_type

    except Exception as e:
        logger.exception(e)
        raise e
|
reggpt/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import os
|
2 |
+
# import sys
|
3 |
+
|
4 |
+
# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
{data β reggpt/data}/splitted_texts.jsonl
RENAMED
The diff for this file is too large to render.
See raw diff
|
|
llm.py β reggpt/llms/llm.py
RENAMED
@@ -1,47 +1,47 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 18/03/2024
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
import os
|
18 |
-
# import time
|
19 |
-
import logging
|
20 |
-
logger = logging.getLogger(__name__)
|
21 |
-
from dotenv import load_dotenv
|
22 |
-
|
23 |
-
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
24 |
-
from langchain_openai import ChatOpenAI
|
25 |
-
|
26 |
-
load_dotenv()
|
27 |
-
|
28 |
-
openai_api_key = os.environ.get('OPENAI_API_KEY')
|
29 |
-
# openai_api_key = "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"
|
30 |
-
print(f"--- ---- ---- openai_api_key: {openai_api_key}")
|
31 |
-
|
32 |
-
verbose = os.environ.get('VERBOSE')
|
33 |
-
|
34 |
-
def get_model(model_type):
|
35 |
-
|
36 |
-
match model_type:
|
37 |
-
case "openai":
|
38 |
-
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
|
39 |
-
case _default:
|
40 |
-
# raise exception if model_type is not supported
|
41 |
-
msg=f"Model type '{model_type}' is not supported. Please choose a valid one"
|
42 |
-
logger.error(msg)
|
43 |
-
return Exception(msg)
|
44 |
-
|
45 |
-
|
46 |
-
logger.info(f"model_type: {model_type} loaded:")
|
47 |
return llm
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
# import time
|
19 |
+
import logging
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
from dotenv import load_dotenv
|
22 |
+
|
23 |
+
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
24 |
+
from langchain_openai import ChatOpenAI
|
25 |
+
|
26 |
+
load_dotenv()
|
27 |
+
|
28 |
+
openai_api_key = os.environ.get('OPENAI_API_KEY')
|
29 |
+
# openai_api_key = "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"
|
30 |
+
print(f"--- ---- ---- openai_api_key: {openai_api_key}")
|
31 |
+
|
32 |
+
verbose = os.environ.get('VERBOSE')
|
33 |
+
|
34 |
+
def get_model(model_type):
|
35 |
+
|
36 |
+
match model_type:
|
37 |
+
case "openai":
|
38 |
+
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
|
39 |
+
case _default:
|
40 |
+
# raise exception if model_type is not supported
|
41 |
+
msg=f"Model type '{model_type}' is not supported. Please choose a valid one"
|
42 |
+
logger.error(msg)
|
43 |
+
return Exception(msg)
|
44 |
+
|
45 |
+
|
46 |
+
logger.info(f"model_type: {model_type} loaded:")
|
47 |
return llm
|
reggpt/memory/__init__.py
ADDED
File without changes
|
conversationBufferWindowMemory.py β reggpt/memory/conversationBufferWindowMemory.py
RENAMED
@@ -1,134 +1,134 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 18/11/2020
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
from abc import ABC
|
18 |
-
from typing import Any, Dict, Optional, Tuple
|
19 |
-
# import json
|
20 |
-
|
21 |
-
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
22 |
-
from langchain.memory.utils import get_prompt_input_key
|
23 |
-
from langchain.pydantic_v1 import Field
|
24 |
-
from langchain.schema import BaseChatMessageHistory, BaseMemory
|
25 |
-
|
26 |
-
from typing import List, Union
|
27 |
-
|
28 |
-
# from langchain.memory.chat_memory import BaseChatMemory
|
29 |
-
from langchain.schema.messages import BaseMessage, get_buffer_string
|
30 |
-
|
31 |
-
|
32 |
-
class BaseChatMemory(BaseMemory, ABC):
|
33 |
-
"""Abstract base class for chat memory."""
|
34 |
-
|
35 |
-
chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
|
36 |
-
output_key: Optional[str] = None
|
37 |
-
input_key: Optional[str] = None
|
38 |
-
return_messages: bool = False
|
39 |
-
|
40 |
-
def _get_input_output(
|
41 |
-
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
42 |
-
) -> Tuple[str, str]:
|
43 |
-
|
44 |
-
|
45 |
-
if self.input_key is None:
|
46 |
-
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
47 |
-
else:
|
48 |
-
prompt_input_key = self.input_key
|
49 |
-
|
50 |
-
if self.output_key is None:
|
51 |
-
"""
|
52 |
-
output for agent with LLM chain tool = {answer}
|
53 |
-
output for agent with ConversationalRetrievalChain tool = {'question', 'chat_history', 'answer','source_documents'}
|
54 |
-
"""
|
55 |
-
|
56 |
-
LLM_key = 'output'
|
57 |
-
Retrieval_key = 'answer'
|
58 |
-
if isinstance(outputs[LLM_key], dict):
|
59 |
-
Retrieval_dict = outputs[LLM_key]
|
60 |
-
if Retrieval_key in Retrieval_dict.keys():
|
61 |
-
#output keys are 'answer' , 'source_documents'
|
62 |
-
output = Retrieval_dict[Retrieval_key]
|
63 |
-
else:
|
64 |
-
raise ValueError(f"output key: {LLM_key} not a valid dictionary")
|
65 |
-
|
66 |
-
else:
|
67 |
-
#otherwise output key will be 'output'
|
68 |
-
output_key = list(outputs.keys())[0]
|
69 |
-
output = outputs[output_key]
|
70 |
-
|
71 |
-
# if len(outputs) != 1:
|
72 |
-
# raise ValueError(f"One output key expected, got {outputs.keys()}")
|
73 |
-
|
74 |
-
|
75 |
-
else:
|
76 |
-
output_key = self.output_key
|
77 |
-
output = outputs[output_key]
|
78 |
-
|
79 |
-
return inputs[prompt_input_key], output
|
80 |
-
|
81 |
-
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
82 |
-
"""Save context from this conversation to buffer."""
|
83 |
-
input_str, output_str = self._get_input_output(inputs, outputs)
|
84 |
-
self.chat_memory.add_user_message(input_str)
|
85 |
-
self.chat_memory.add_ai_message(output_str)
|
86 |
-
|
87 |
-
def clear(self) -> None:
|
88 |
-
"""Clear memory contents."""
|
89 |
-
self.chat_memory.clear()
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
class ConversationBufferWindowMemory(BaseChatMemory):
|
96 |
-
"""Buffer for storing conversation memory inside a limited size window."""
|
97 |
-
|
98 |
-
human_prefix: str = "Human"
|
99 |
-
ai_prefix: str = "AI"
|
100 |
-
memory_key: str = "history" #: :meta private:
|
101 |
-
k: int = 5
|
102 |
-
"""Number of messages to store in buffer."""
|
103 |
-
|
104 |
-
@property
|
105 |
-
def buffer(self) -> Union[str, List[BaseMessage]]:
|
106 |
-
"""String buffer of memory."""
|
107 |
-
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
108 |
-
|
109 |
-
@property
|
110 |
-
def buffer_as_str(self) -> str:
|
111 |
-
"""Exposes the buffer as a string in case return_messages is True."""
|
112 |
-
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
113 |
-
return get_buffer_string(
|
114 |
-
messages,
|
115 |
-
human_prefix=self.human_prefix,
|
116 |
-
ai_prefix=self.ai_prefix,
|
117 |
-
)
|
118 |
-
|
119 |
-
@property
|
120 |
-
def buffer_as_messages(self) -> List[BaseMessage]:
|
121 |
-
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
122 |
-
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
123 |
-
|
124 |
-
@property
|
125 |
-
def memory_variables(self) -> List[str]:
|
126 |
-
"""Will always return list of memory variables.
|
127 |
-
|
128 |
-
:meta private:
|
129 |
-
"""
|
130 |
-
return [self.memory_key]
|
131 |
-
|
132 |
-
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
133 |
-
"""Return history buffer."""
|
134 |
return {self.memory_key: self.buffer}
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/11/2023
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
from abc import ABC
|
18 |
+
from typing import Any, Dict, Optional, Tuple
|
19 |
+
# import json
|
20 |
+
|
21 |
+
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
|
22 |
+
from langchain.memory.utils import get_prompt_input_key
|
23 |
+
from langchain.pydantic_v1 import Field
|
24 |
+
from langchain.schema import BaseChatMessageHistory, BaseMemory
|
25 |
+
|
26 |
+
from typing import List, Union
|
27 |
+
|
28 |
+
# from langchain.memory.chat_memory import BaseChatMemory
|
29 |
+
from langchain.schema.messages import BaseMessage, get_buffer_string
|
30 |
+
|
31 |
+
|
32 |
+
class BaseChatMemory(BaseMemory, ABC):
    """Abstract base class for chat memory.

    Stores conversation turns in ``chat_memory`` and knows how to extract
    the single (input, output) string pair to persist from the possibly
    nested output dictionaries produced by chains and agents.
    """

    # Backing store for the conversation turns.
    chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
    # Explicit key to read the answer from ``outputs``; autodetected when None.
    output_key: Optional[str] = None
    # Explicit key to read the question from ``inputs``; autodetected when None.
    input_key: Optional[str] = None
    # Whether buffers are exposed as message objects instead of a string.
    return_messages: bool = False

    def _get_input_output(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> Tuple[str, str]:
        """Resolve the (input, output) string pair to store.

        Output shapes observed in this project:
        * agent with an LLM chain tool -> ``{'output': answer}``
        * agent with a ConversationalRetrievalChain tool ->
          ``{'output': {'question', 'chat_history', 'answer', 'source_documents'}}``

        Raises:
            ValueError: if the nested retrieval dict has no 'answer' key.
        """
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key

        if self.output_key is not None:
            return inputs[prompt_input_key], outputs[self.output_key]

        llm_key = "output"
        retrieval_key = "answer"
        # Fix: guard for the 'output' key before subscripting. The original
        # raised KeyError here whenever the outputs dict used a different
        # single key, making the first-key fallback below unreachable.
        if llm_key in outputs and isinstance(outputs[llm_key], dict):
            retrieval_dict = outputs[llm_key]
            if retrieval_key in retrieval_dict:
                # Retrieval-chain shape: keys are 'answer', 'source_documents'.
                output = retrieval_dict[retrieval_key]
            else:
                raise ValueError(f"output key: {llm_key} not a valid dictionary")
        else:
            # Plain LLM-chain shape: take the first (normally only) key.
            first_key = list(outputs.keys())[0]
            output = outputs[first_key]

        return inputs[prompt_input_key], output

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        input_str, output_str = self._get_input_output(inputs, outputs)
        self.chat_memory.add_user_message(input_str)
        self.chat_memory.add_ai_message(output_str)

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
class ConversationBufferWindowMemory(BaseChatMemory):
    """Buffer for storing conversation memory inside a limited size window."""

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"  #: :meta private:
    # Number of exchanges (human/AI pairs) kept in the window.
    k: int = 5

    def _windowed_messages(self) -> List[BaseMessage]:
        # Each exchange is two messages, so keep the last 2*k entries.
        if self.k <= 0:
            return []
        return self.chat_memory.messages[-self.k * 2 :]

    @property
    def buffer(self) -> Union[str, List[BaseMessage]]:
        """String buffer of memory."""
        if self.return_messages:
            return self.buffer_as_messages
        return self.buffer_as_str

    @property
    def buffer_as_str(self) -> str:
        """Exposes the buffer as a string in case return_messages is True."""
        return get_buffer_string(
            self._windowed_messages(),
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Exposes the buffer as a list of messages in case return_messages is False."""
        return self._windowed_messages()

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        return {self.memory_key: self.buffer}
|
reggpt/output_parsers/__init__.py
ADDED
File without changes
|
output_parser.py β reggpt/output_parsers/output_parser.py
RENAMED
File without changes
|
reggpt/prompts/__init__.py
ADDED
File without changes
|
reggpt/prompts/document_combine.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate

# Template used to render each retrieved document (with its provenance
# metadata) before the documents are combined into the QA context.
_DOC_TEMPLATE = (
    """<doc> source: {source}, year: {year}, page: {page}, page content: {page_content} </doc>"""
)

document_combine_prompt = PromptTemplate(
    input_variables=["source", "year", "page", "page_content"],
    template=_DOC_TEMPLATE,
)
|
reggpt/prompts/general.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate

# Prompt for the "general"/greeting route: the assistant may only answer
# greetings and must declare anything else out of domain.
general_qa_template_Mixtral_V0 = """
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .

Is the provided question below a greeting? First, evaluate whether the input resembles a typical greeting or not.

Greetings are used to say 'hello' and 'how are you?' and to say 'goodbye' and 'nice speaking with you.' and 'hi, I'm (user's name).'
Greetings are words used when we want to introduce ourselves to others and when we want to find out how someone is feeling.

You can only reply to the user's greetings.
If the question is a greeting, reply accordingly as the AI assistant of company boardpac.
If the question is not related to greetings and research papers, say that it is out of your domain.
If the question is not clear enough, ask for more details and don't try to make up answers.

Answer should be polite, short, and simple.

Additionally, it's important to note that this AI assistant has access to an internal collection of research papers, and answers can be provided using the information available in those CBSL Dataset.

Question: {question}
"""

# Explicit construction (equivalent to from_template; the only placeholder
# in the template is {question}).
general_qa_chain_prompt = PromptTemplate(
    input_variables=["question"], template=general_qa_template_Mixtral_V0
)
|
reggpt/prompts/multi_query.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate

# Template that expands one user question into three paraphrases, used by
# the multi-query retriever to broaden vector-store recall.
_MULTI_QUERY_TEMPLATE = """You are an AI language model assistant. Your task is to generate three
    different versions of the given user question to retrieve relevant documents from a vector
    database. By generating multiple perspectives on the user question, your goal is to help
    the user overcome some of the limitations of the distance-based similarity search.
    Provide these alternative questions separated by newlines.

    Dont add anything extra before or after to the 3 questions. Just give 3 lines with 3 questions.
    Just provide 3 lines having 3 questions only.
    Answer should be in following format.

    1. alternative question 1
    2. alternative question 2
    3. alternative question 3

    Original question: {question}"""

MULTY_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template=_MULTI_QUERY_TEMPLATE,
)
|
reggpt/prompts/retrieval.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate

# Retrieval-QA prompt in Llama/Mixtral [INST] format: answer only from the
# supplied central-bank documents, preferring the most recent year.
retrieval_qa_template = (
    """<<SYS>>
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.

please answer the question based on the chat history provided below. Answer should be short and simple as possible and on to the point.
<chat history>: {chat_history}

If the question is related to welcomes and greetings answer accordingly.

Else If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
please answer the question based only on the information provided in following central bank documents published in various years.
The published year is mentioned as the metadata 'year' of each source document.
Please notice that content of a one document of a past year can updated by a new document from a recent year.
Always try to answer with latest information and mention the year which information extracted.
If you dont know the answer say you dont know, dont try to makeup answers. Dont add any extra details that is not mentioned in the context.

<</SYS>>

[INST]
<DOCUMENTS>
{context}
</DOCUMENTS>

Question : {question}[/INST]"""
)

# from_template infers the same input variables from the {question},
# {context} and {chat_history} placeholders.
retrieval_qa_chain_prompt = PromptTemplate.from_template(retrieval_qa_template)
|
reggpt/prompts/router.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate

# Router prompt: classifies a user question into Relevant / Greeting / Other.
# Fix: the closing quotes of the three labels were mojibake ("Relevantβ:,
# "Greetingβ:, "Otherβ:) which corrupted the label tokens the downstream
# router matches on; restored to plain double quotes.
router_template_Mixtral_V0 = """
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.

If a user asks a question you have to classify it to following 3 types Relevant, Greeting, Other.

"Relevant": If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
"Greeting": If the question is a greeting like good morning, hi my name is., thank you or General Question ask about the AI assistance of a company boardpac.
"Other": If the question is not related to research papers.

Give the correct name of question type. If you are not sure return "Not Sure" instead.

Question : {question}
"""

router_prompt = PromptTemplate.from_template(router_template_Mixtral_V0)
|
17 |
+
|
ensemble_retriever.py β reggpt/retriever/ensemble_retriever.py
RENAMED
@@ -1,228 +1,228 @@
|
|
1 |
-
"""
|
2 |
-
/*************************************************************************
|
3 |
-
*
|
4 |
-
* CONFIDENTIAL
|
5 |
-
* __________________
|
6 |
-
*
|
7 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
-
* All Rights Reserved
|
9 |
-
*
|
10 |
-
* Author : Theekshana Samaradiwakara
|
11 |
-
* Description :Python Backend API to chat with private data
|
12 |
-
* CreatedDate : 14/11/2023
|
13 |
-
* LastModifiedDate : 18/03/2024
|
14 |
-
*************************************************************************/
|
15 |
-
"""
|
16 |
-
|
17 |
-
"""
|
18 |
-
Ensemble retriever that ensemble the results of
|
19 |
-
multiple retrievers by using weighted Reciprocal Rank Fusion
|
20 |
-
"""
|
21 |
-
|
22 |
-
import os
|
23 |
-
import sys
|
24 |
-
|
25 |
-
from pathlib import Path
|
26 |
-
Path(__file__).resolve().parent.parent
|
27 |
-
|
28 |
-
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
29 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
30 |
-
|
31 |
-
|
32 |
-
import logging
|
33 |
-
logger = logging.getLogger(__name__)
|
34 |
-
from typing import Any, Dict, List
|
35 |
-
|
36 |
-
from langchain.callbacks.manager import (
|
37 |
-
AsyncCallbackManagerForRetrieverRun,
|
38 |
-
CallbackManagerForRetrieverRun,
|
39 |
-
)
|
40 |
-
from langchain.pydantic_v1 import root_validator
|
41 |
-
from langchain.schema import BaseRetriever, Document
|
42 |
-
|
43 |
-
import numpy as np
|
44 |
-
import pandas as pd
|
45 |
-
|
46 |
-
|
47 |
-
class EnsembleRetriever(BaseRetriever):
|
48 |
-
"""Retriever that ensembles the multiple retrievers.
|
49 |
-
|
50 |
-
It uses a rank fusion.
|
51 |
-
|
52 |
-
Args:
|
53 |
-
retrievers: A list of retrievers to ensemble.
|
54 |
-
weights: A list of weights corresponding to the retrievers. Defaults to equal
|
55 |
-
weighting for all retrievers.
|
56 |
-
c: A constant added to the rank, controlling the balance between the importance
|
57 |
-
of high-ranked items and the consideration given to lower-ranked items.
|
58 |
-
Default is 60.
|
59 |
-
"""
|
60 |
-
|
61 |
-
retrievers: List[BaseRetriever]
|
62 |
-
weights: List[float]
|
63 |
-
c: int = 60
|
64 |
-
date_key: str = "year"
|
65 |
-
top_k: int = 4
|
66 |
-
|
67 |
-
@root_validator(pre=True,allow_reuse=True)
|
68 |
-
def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
69 |
-
if not values.get("weights"):
|
70 |
-
n_retrievers = len(values["retrievers"])
|
71 |
-
values["weights"] = [1 / n_retrievers] * n_retrievers
|
72 |
-
return values
|
73 |
-
|
74 |
-
def _get_relevant_documents(
|
75 |
-
self,
|
76 |
-
query: str,
|
77 |
-
*,
|
78 |
-
run_manager: CallbackManagerForRetrieverRun,
|
79 |
-
) -> List[Document]:
|
80 |
-
"""
|
81 |
-
Get the relevant documents for a given query.
|
82 |
-
|
83 |
-
Args:
|
84 |
-
query: The query to search for.
|
85 |
-
|
86 |
-
Returns:
|
87 |
-
A list of reranked documents.
|
88 |
-
"""
|
89 |
-
|
90 |
-
# Get fused result of the retrievers.
|
91 |
-
fused_documents = self.rank_fusion(query, run_manager)
|
92 |
-
|
93 |
-
# check for key exists
|
94 |
-
if fused_documents[0].metadata[self.date_key] != None:
|
95 |
-
doc_dates = pd.to_datetime(
|
96 |
-
[doc.metadata[self.date_key] for doc in fused_documents]
|
97 |
-
)
|
98 |
-
sorted_node_idxs = np.flip(doc_dates.argsort())
|
99 |
-
fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
|
100 |
-
logger.info('Ensemble Retriever Documents sorted by year')
|
101 |
-
|
102 |
-
# return fused_documents[:self.top_k]
|
103 |
-
return fused_documents
|
104 |
-
|
105 |
-
async def _aget_relevant_documents(
|
106 |
-
self,
|
107 |
-
query: str,
|
108 |
-
*,
|
109 |
-
run_manager: AsyncCallbackManagerForRetrieverRun,
|
110 |
-
) -> List[Document]:
|
111 |
-
"""
|
112 |
-
Asynchronously get the relevant documents for a given query.
|
113 |
-
|
114 |
-
Args:
|
115 |
-
query: The query to search for.
|
116 |
-
|
117 |
-
Returns:
|
118 |
-
A list of reranked documents.
|
119 |
-
"""
|
120 |
-
|
121 |
-
# Get fused result of the retrievers.
|
122 |
-
fused_documents = await self.arank_fusion(query, run_manager)
|
123 |
-
|
124 |
-
return fused_documents
|
125 |
-
|
126 |
-
def rank_fusion(
|
127 |
-
self, query: str, run_manager: CallbackManagerForRetrieverRun
|
128 |
-
) -> List[Document]:
|
129 |
-
"""
|
130 |
-
Retrieve the results of the retrievers and use rank_fusion_func to get
|
131 |
-
the final result.
|
132 |
-
|
133 |
-
Args:
|
134 |
-
query: The query to search for.
|
135 |
-
|
136 |
-
Returns:
|
137 |
-
A list of reranked documents.
|
138 |
-
"""
|
139 |
-
|
140 |
-
# Get the results of all retrievers.
|
141 |
-
retriever_docs = [
|
142 |
-
retriever.get_relevant_documents(
|
143 |
-
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
144 |
-
)
|
145 |
-
for i, retriever in enumerate(self.retrievers)
|
146 |
-
]
|
147 |
-
|
148 |
-
# apply rank fusion
|
149 |
-
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
|
150 |
-
|
151 |
-
return fused_documents
|
152 |
-
|
153 |
-
async def arank_fusion(
|
154 |
-
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
155 |
-
) -> List[Document]:
|
156 |
-
"""
|
157 |
-
Asynchronously retrieve the results of the retrievers
|
158 |
-
and use rank_fusion_func to get the final result.
|
159 |
-
|
160 |
-
Args:
|
161 |
-
query: The query to search for.
|
162 |
-
|
163 |
-
Returns:
|
164 |
-
A list of reranked documents.
|
165 |
-
"""
|
166 |
-
|
167 |
-
# Get the results of all retrievers.
|
168 |
-
retriever_docs = [
|
169 |
-
await retriever.aget_relevant_documents(
|
170 |
-
query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
|
171 |
-
)
|
172 |
-
for i, retriever in enumerate(self.retrievers)
|
173 |
-
]
|
174 |
-
|
175 |
-
# apply rank fusion
|
176 |
-
fused_documents = self.weighted_reciprocal_rank(retriever_docs)
|
177 |
-
|
178 |
-
return fused_documents
|
179 |
-
|
180 |
-
def weighted_reciprocal_rank(
|
181 |
-
self, doc_lists: List[List[Document]]
|
182 |
-
) -> List[Document]:
|
183 |
-
"""
|
184 |
-
Perform weighted Reciprocal Rank Fusion on multiple rank lists.
|
185 |
-
You can find more details about RRF here:
|
186 |
-
https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
|
187 |
-
|
188 |
-
Args:
|
189 |
-
doc_lists: A list of rank lists, where each rank list contains unique items.
|
190 |
-
|
191 |
-
Returns:
|
192 |
-
list: The final aggregated list of items sorted by their weighted RRF
|
193 |
-
scores in descending order.
|
194 |
-
"""
|
195 |
-
if len(doc_lists) != len(self.weights):
|
196 |
-
raise ValueError(
|
197 |
-
"Number of rank lists must be equal to the number of weights."
|
198 |
-
)
|
199 |
-
|
200 |
-
# Create a union of all unique documents in the input doc_lists
|
201 |
-
all_documents = set()
|
202 |
-
for doc_list in doc_lists:
|
203 |
-
for doc in doc_list:
|
204 |
-
all_documents.add(doc.page_content)
|
205 |
-
|
206 |
-
# Initialize the RRF score dictionary for each document
|
207 |
-
rrf_score_dic = {doc: 0.0 for doc in all_documents}
|
208 |
-
|
209 |
-
# Calculate RRF scores for each document
|
210 |
-
for doc_list, weight in zip(doc_lists, self.weights):
|
211 |
-
for rank, doc in enumerate(doc_list, start=1):
|
212 |
-
rrf_score = weight * (1 / (rank + self.c))
|
213 |
-
rrf_score_dic[doc.page_content] += rrf_score
|
214 |
-
|
215 |
-
# Sort documents by their RRF scores in descending order
|
216 |
-
sorted_documents = sorted(
|
217 |
-
rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
|
218 |
-
)
|
219 |
-
|
220 |
-
# Map the sorted page_content back to the original document objects
|
221 |
-
page_content_to_doc_map = {
|
222 |
-
doc.page_content: doc for doc_list in doc_lists for doc in doc_list
|
223 |
-
}
|
224 |
-
sorted_docs = [
|
225 |
-
page_content_to_doc_map[page_content] for page_content in sorted_documents
|
226 |
-
]
|
227 |
-
|
228 |
-
return sorted_docs
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
"""
|
18 |
+
Ensemble retriever that ensemble the results of
|
19 |
+
multiple retrievers by using weighted Reciprocal Rank Fusion
|
20 |
+
"""
|
21 |
+
|
22 |
+
import os
|
23 |
+
import sys
|
24 |
+
|
25 |
+
from pathlib import Path
|
26 |
+
Path(__file__).resolve().parent.parent
|
27 |
+
|
28 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
29 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
30 |
+
|
31 |
+
|
32 |
+
import logging
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
from typing import Any, Dict, List
|
35 |
+
|
36 |
+
from langchain.callbacks.manager import (
|
37 |
+
AsyncCallbackManagerForRetrieverRun,
|
38 |
+
CallbackManagerForRetrieverRun,
|
39 |
+
)
|
40 |
+
from langchain.pydantic_v1 import root_validator
|
41 |
+
from langchain.schema import BaseRetriever, Document
|
42 |
+
|
43 |
+
import numpy as np
|
44 |
+
import pandas as pd
|
45 |
+
|
46 |
+
|
47 |
+
class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the multiple retrievers.

    It uses a rank fusion.

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to equal
            weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the importance
            of high-ranked items and the consideration given to lower-ranked items.
            Default is 60.
    """

    retrievers: List[BaseRetriever]
    weights: List[float]
    c: int = 60
    # Metadata key holding each document's publication date/year.
    date_key: str = "year"
    top_k: int = 4

    @root_validator(pre=True, allow_reuse=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default to equal weights when none are supplied."""
        if not values.get("weights"):
            n_retrievers = len(values["retrievers"])
            values["weights"] = [1 / n_retrievers] * n_retrievers
        return values

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents, sorted newest-first when the
            date metadata is present.
        """

        # Get fused result of the retrievers.
        fused_documents = self.rank_fusion(query, run_manager)

        # Fix: actually "check for key exists" as the original comment
        # claimed — guard against an empty result (IndexError on [0]) and
        # a missing date key (KeyError on metadata[...]).
        if fused_documents and fused_documents[0].metadata.get(self.date_key) is not None:
            doc_dates = pd.to_datetime(
                [doc.metadata.get(self.date_key) for doc in fused_documents]
            )
            sorted_node_idxs = np.flip(doc_dates.argsort())
            fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
            logger.info('Ensemble Retriever Documents sorted by year')

        # return fused_documents[:self.top_k]
        return fused_documents

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get fused result of the retrievers.
        fused_documents = await self.arank_fusion(query, run_manager)

        return fused_documents

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Retrieve the results of the retrievers and use rank_fusion_func to get
        the final result.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get the results of all retrievers.
        retriever_docs = [
            retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # apply rank fusion
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)

        return fused_documents

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously retrieve the results of the retrievers
        and use rank_fusion_func to get the final result.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get the results of all retrievers.
        retriever_docs = [
            await retriever.aget_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # apply rank fusion
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)

        return fused_documents

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """
        Perform weighted Reciprocal Rank Fusion on multiple rank lists.
        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, where each rank list contains unique items.

        Returns:
            list: The final aggregated list of items sorted by their weighted RRF
                scores in descending order.

        Raises:
            ValueError: if len(doc_lists) != len(self.weights).
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Create a union of all unique documents in the input doc_lists.
        # Documents are deduplicated by page_content.
        all_documents = set()
        for doc_list in doc_lists:
            for doc in doc_list:
                all_documents.add(doc.page_content)

        # Initialize the RRF score dictionary for each document
        rrf_score_dic = {doc: 0.0 for doc in all_documents}

        # Calculate RRF scores for each document
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score = weight * (1 / (rank + self.c))
                rrf_score_dic[doc.page_content] += rrf_score

        # Sort documents by their RRF scores in descending order
        sorted_documents = sorted(
            rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
        )

        # Map the sorted page_content back to the original document objects
        page_content_to_doc_map = {
            doc.page_content: doc for doc_list in doc_lists for doc in doc_list
        }
        sorted_docs = [
            page_content_to_doc_map[page_content] for page_content in sorted_documents
        ]

        return sorted_docs
|
multi_query_retriever.py β reggpt/retriever/multi_query_retriever.py
RENAMED
@@ -1,254 +1,254 @@
|
|
1 |
-
|
2 |
-
"""
|
3 |
-
/*************************************************************************
|
4 |
-
*
|
5 |
-
* CONFIDENTIAL
|
6 |
-
* __________________
|
7 |
-
*
|
8 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
-
* All Rights Reserved
|
10 |
-
*
|
11 |
-
* Author : Theekshana Samaradiwakara
|
12 |
-
* Description :Python Backend API to chat with private data
|
13 |
-
* CreatedDate : 14/11/2023
|
14 |
-
* LastModifiedDate : 21/03/2024
|
15 |
-
*************************************************************************/
|
16 |
-
"""
|
17 |
-
|
18 |
-
import asyncio
|
19 |
-
import logging
|
20 |
-
from typing import List, Optional, Sequence
|
21 |
-
|
22 |
-
from langchain_core.callbacks import (
|
23 |
-
AsyncCallbackManagerForRetrieverRun,
|
24 |
-
CallbackManagerForRetrieverRun,
|
25 |
-
)
|
26 |
-
from langchain_core.documents import Document
|
27 |
-
from langchain_core.language_models import BaseLanguageModel
|
28 |
-
from langchain_core.output_parsers import BaseOutputParser
|
29 |
-
from langchain_core.prompts.prompt import PromptTemplate
|
30 |
-
from langchain_core.retrievers import BaseRetriever
|
31 |
-
|
32 |
-
from langchain.chains.llm import LLMChain
|
33 |
-
|
34 |
-
import numpy as np
|
35 |
-
import pandas as pd
|
36 |
-
|
37 |
-
logger = logging.getLogger(__name__)
|
38 |
-
|
39 |
-
from prompts import MULTY_QUERY_PROMPT
|
40 |
-
|
41 |
-
class LineListOutputParser(BaseOutputParser[List[str]]):
|
42 |
-
"""Output parser for a list of lines."""
|
43 |
-
|
44 |
-
def parse(self, text: str) -> List[str]:
|
45 |
-
lines = text.strip().split("\n")
|
46 |
-
return lines
|
47 |
-
|
48 |
-
|
49 |
-
# Default prompt
|
50 |
-
# DEFAULT_QUERY_PROMPT = PromptTemplate(
|
51 |
-
# input_variables=["question"],
|
52 |
-
# template="""You are an AI language model assistant. Your task is
|
53 |
-
# to generate 3 different versions of the given user
|
54 |
-
# question to retrieve relevant documents from a vector database.
|
55 |
-
# By generating multiple perspectives on the user question,
|
56 |
-
# your goal is to help the user overcome some of the limitations
|
57 |
-
# of distance-based similarity search. Provide these alternative
|
58 |
-
# questions separated by newlines. Original question: {question}""",
|
59 |
-
# )
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
|
65 |
-
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
66 |
-
|
67 |
-
|
68 |
-
class MultiQueryRetriever(BaseRetriever):
|
69 |
-
"""Given a query, use an LLM to write a set of queries.
|
70 |
-
|
71 |
-
Retrieve docs for each query. Return the unique union of all retrieved docs.
|
72 |
-
"""
|
73 |
-
|
74 |
-
retriever: BaseRetriever
|
75 |
-
llm_chain: LLMChain
|
76 |
-
verbose: bool = True
|
77 |
-
parser_key: str = "lines"
|
78 |
-
"""DEPRECATED. parser_key is no longer used and should not be specified."""
|
79 |
-
include_original: bool = False
|
80 |
-
"""Whether to include the original query in the list of generated queries."""
|
81 |
-
date_key: str = "year"
|
82 |
-
top_k: int = 4
|
83 |
-
|
84 |
-
@classmethod
|
85 |
-
def from_llm(
|
86 |
-
cls,
|
87 |
-
retriever: BaseRetriever,
|
88 |
-
llm: BaseLanguageModel,
|
89 |
-
prompt: PromptTemplate = MULTY_QUERY_PROMPT,
|
90 |
-
parser_key: Optional[str] = None,
|
91 |
-
include_original: bool = False,
|
92 |
-
) -> "MultiQueryRetriever":
|
93 |
-
"""Initialize from llm using default template.
|
94 |
-
|
95 |
-
Args:
|
96 |
-
retriever: retriever to query documents from
|
97 |
-
llm: llm for query generation using DEFAULT_QUERY_PROMPT
|
98 |
-
include_original: Whether to include the original query in the list of
|
99 |
-
generated queries.
|
100 |
-
|
101 |
-
Returns:
|
102 |
-
MultiQueryRetriever
|
103 |
-
"""
|
104 |
-
output_parser = LineListOutputParser()
|
105 |
-
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
|
106 |
-
return cls(
|
107 |
-
retriever=retriever,
|
108 |
-
llm_chain=llm_chain,
|
109 |
-
include_original=include_original,
|
110 |
-
)
|
111 |
-
|
112 |
-
async def _aget_relevant_documents(
|
113 |
-
self,
|
114 |
-
query: str,
|
115 |
-
*,
|
116 |
-
run_manager: AsyncCallbackManagerForRetrieverRun,
|
117 |
-
) -> List[Document]:
|
118 |
-
"""Get relevant documents given a user query.
|
119 |
-
|
120 |
-
Args:
|
121 |
-
question: user query
|
122 |
-
|
123 |
-
Returns:
|
124 |
-
Unique union of relevant documents from all generated queries
|
125 |
-
"""
|
126 |
-
queries = await self.agenerate_queries(query, run_manager)
|
127 |
-
if self.include_original:
|
128 |
-
queries.append(query)
|
129 |
-
documents = await self.aretrieve_documents(queries, run_manager)
|
130 |
-
return self.unique_union(documents)
|
131 |
-
|
132 |
-
|
133 |
-
async def agenerate_queries(
|
134 |
-
self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
135 |
-
) -> List[str]:
|
136 |
-
"""Generate queries based upon user input.
|
137 |
-
|
138 |
-
Args:
|
139 |
-
question: user query
|
140 |
-
|
141 |
-
Returns:
|
142 |
-
List of LLM generated queries that are similar to the user input
|
143 |
-
"""
|
144 |
-
response = await self.llm_chain.ainvoke(
|
145 |
-
inputs={"question": question}, callbacks=run_manager.get_child()
|
146 |
-
)
|
147 |
-
lines = response["text"]
|
148 |
-
if self.verbose:
|
149 |
-
logger.info(f"Generated queries: {lines}")
|
150 |
-
return lines
|
151 |
-
|
152 |
-
async def aretrieve_documents(
|
153 |
-
self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
|
154 |
-
) -> List[Document]:
|
155 |
-
"""Run all LLM generated queries.
|
156 |
-
|
157 |
-
Args:
|
158 |
-
queries: query list
|
159 |
-
|
160 |
-
Returns:
|
161 |
-
List of retrieved Documents
|
162 |
-
"""
|
163 |
-
document_lists = await asyncio.gather(
|
164 |
-
*(
|
165 |
-
self.retriever.aget_relevant_documents(
|
166 |
-
query, callbacks=run_manager.get_child()
|
167 |
-
)
|
168 |
-
for query in queries
|
169 |
-
)
|
170 |
-
)
|
171 |
-
return [doc for docs in document_lists for doc in docs]
|
172 |
-
|
173 |
-
def _get_relevant_documents(
|
174 |
-
self,
|
175 |
-
query: str,
|
176 |
-
*,
|
177 |
-
run_manager: CallbackManagerForRetrieverRun,
|
178 |
-
) -> List[Document]:
|
179 |
-
"""Get relevant documents given a user query.
|
180 |
-
|
181 |
-
Args:
|
182 |
-
question: user query
|
183 |
-
|
184 |
-
Returns:
|
185 |
-
Unique union of relevant documents from all generated queries
|
186 |
-
"""
|
187 |
-
queries = self.generate_queries(query, run_manager)
|
188 |
-
if self.include_original:
|
189 |
-
queries.append(query)
|
190 |
-
documents = self.retrieve_documents(queries, run_manager)
|
191 |
-
fused_documents= self.unique_union(documents)
|
192 |
-
# check for key exists
|
193 |
-
if fused_documents[0].metadata[self.date_key] != None:
|
194 |
-
doc_dates = pd.to_datetime(
|
195 |
-
[doc.metadata[self.date_key] for doc in fused_documents]
|
196 |
-
)
|
197 |
-
sorted_node_idxs = np.flip(doc_dates.argsort())
|
198 |
-
fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
|
199 |
-
logger.info('Documents sorted by year')
|
200 |
-
|
201 |
-
return fused_documents[:self.top_k]
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
def generate_queries(
|
206 |
-
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
207 |
-
) -> List[str]:
|
208 |
-
"""Generate queries based upon user input.
|
209 |
-
|
210 |
-
Args:
|
211 |
-
question: user query
|
212 |
-
|
213 |
-
Returns:
|
214 |
-
List of LLM generated queries that are similar to the user input
|
215 |
-
"""
|
216 |
-
response = self.llm_chain.invoke(
|
217 |
-
{"question": question}, callbacks=run_manager.get_child()
|
218 |
-
)
|
219 |
-
lines = response["text"]
|
220 |
-
if self.verbose:
|
221 |
-
logger.info(f"Generated queries: {lines}")
|
222 |
-
return lines
|
223 |
-
|
224 |
-
def retrieve_documents(
|
225 |
-
self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
|
226 |
-
) -> List[Document]:
|
227 |
-
"""Run all LLM generated queries.
|
228 |
-
|
229 |
-
Args:
|
230 |
-
queries: query list
|
231 |
-
|
232 |
-
Returns:
|
233 |
-
List of retrieved Documents
|
234 |
-
"""
|
235 |
-
documents = []
|
236 |
-
for query in queries:
|
237 |
-
logger.info(f"MQ Retriever question: {query}")
|
238 |
-
docs = self.retriever.get_relevant_documents(
|
239 |
-
query, callbacks=run_manager.get_child()
|
240 |
-
)
|
241 |
-
documents.extend(docs)
|
242 |
-
return documents
|
243 |
-
|
244 |
-
def unique_union(self, documents: List[Document]) -> List[Document]:
|
245 |
-
"""Get unique Documents.
|
246 |
-
|
247 |
-
Args:
|
248 |
-
documents: List of retrieved Documents
|
249 |
-
|
250 |
-
Returns:
|
251 |
-
List of unique retrieved Documents
|
252 |
-
"""
|
253 |
-
return _unique_documents(documents)
|
254 |
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
/*************************************************************************
|
4 |
+
*
|
5 |
+
* CONFIDENTIAL
|
6 |
+
* __________________
|
7 |
+
*
|
8 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
+
* All Rights Reserved
|
10 |
+
*
|
11 |
+
* Author : Theekshana Samaradiwakara
|
12 |
+
* Description :Python Backend API to chat with private data
|
13 |
+
* CreatedDate : 14/11/2023
|
14 |
+
* LastModifiedDate : 21/03/2024
|
15 |
+
*************************************************************************/
|
16 |
+
"""
|
17 |
+
|
18 |
+
import asyncio
|
19 |
+
import logging
|
20 |
+
from typing import List, Optional, Sequence
|
21 |
+
|
22 |
+
from langchain_core.callbacks import (
|
23 |
+
AsyncCallbackManagerForRetrieverRun,
|
24 |
+
CallbackManagerForRetrieverRun,
|
25 |
+
)
|
26 |
+
from langchain_core.documents import Document
|
27 |
+
from langchain_core.language_models import BaseLanguageModel
|
28 |
+
from langchain_core.output_parsers import BaseOutputParser
|
29 |
+
from langchain_core.prompts.prompt import PromptTemplate
|
30 |
+
from langchain_core.retrievers import BaseRetriever
|
31 |
+
|
32 |
+
from langchain.chains.llm import LLMChain
|
33 |
+
|
34 |
+
import numpy as np
|
35 |
+
import pandas as pd
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
from reggpt.prompts.prompts import MULTY_QUERY_PROMPT
|
40 |
+
|
41 |
+
class LineListOutputParser(BaseOutputParser[List[str]]):
|
42 |
+
"""Output parser for a list of lines."""
|
43 |
+
|
44 |
+
def parse(self, text: str) -> List[str]:
|
45 |
+
lines = text.strip().split("\n")
|
46 |
+
return lines
|
47 |
+
|
48 |
+
|
49 |
+
# Default prompt
|
50 |
+
# DEFAULT_QUERY_PROMPT = PromptTemplate(
|
51 |
+
# input_variables=["question"],
|
52 |
+
# template="""You are an AI language model assistant. Your task is
|
53 |
+
# to generate 3 different versions of the given user
|
54 |
+
# question to retrieve relevant documents from a vector database.
|
55 |
+
# By generating multiple perspectives on the user question,
|
56 |
+
# your goal is to help the user overcome some of the limitations
|
57 |
+
# of distance-based similarity search. Provide these alternative
|
58 |
+
# questions separated by newlines. Original question: {question}""",
|
59 |
+
# )
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
|
65 |
+
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
66 |
+
|
67 |
+
|
68 |
+
class MultiQueryRetriever(BaseRetriever):
|
69 |
+
"""Given a query, use an LLM to write a set of queries.
|
70 |
+
|
71 |
+
Retrieve docs for each query. Return the unique union of all retrieved docs.
|
72 |
+
"""
|
73 |
+
|
74 |
+
retriever: BaseRetriever
|
75 |
+
llm_chain: LLMChain
|
76 |
+
verbose: bool = True
|
77 |
+
parser_key: str = "lines"
|
78 |
+
"""DEPRECATED. parser_key is no longer used and should not be specified."""
|
79 |
+
include_original: bool = False
|
80 |
+
"""Whether to include the original query in the list of generated queries."""
|
81 |
+
date_key: str = "year"
|
82 |
+
top_k: int = 4
|
83 |
+
|
84 |
+
@classmethod
|
85 |
+
def from_llm(
|
86 |
+
cls,
|
87 |
+
retriever: BaseRetriever,
|
88 |
+
llm: BaseLanguageModel,
|
89 |
+
prompt: PromptTemplate = MULTY_QUERY_PROMPT,
|
90 |
+
parser_key: Optional[str] = None,
|
91 |
+
include_original: bool = False,
|
92 |
+
) -> "MultiQueryRetriever":
|
93 |
+
"""Initialize from llm using default template.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
retriever: retriever to query documents from
|
97 |
+
llm: llm for query generation using DEFAULT_QUERY_PROMPT
|
98 |
+
include_original: Whether to include the original query in the list of
|
99 |
+
generated queries.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
MultiQueryRetriever
|
103 |
+
"""
|
104 |
+
output_parser = LineListOutputParser()
|
105 |
+
llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
|
106 |
+
return cls(
|
107 |
+
retriever=retriever,
|
108 |
+
llm_chain=llm_chain,
|
109 |
+
include_original=include_original,
|
110 |
+
)
|
111 |
+
|
112 |
+
async def _aget_relevant_documents(
|
113 |
+
self,
|
114 |
+
query: str,
|
115 |
+
*,
|
116 |
+
run_manager: AsyncCallbackManagerForRetrieverRun,
|
117 |
+
) -> List[Document]:
|
118 |
+
"""Get relevant documents given a user query.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
question: user query
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
Unique union of relevant documents from all generated queries
|
125 |
+
"""
|
126 |
+
queries = await self.agenerate_queries(query, run_manager)
|
127 |
+
if self.include_original:
|
128 |
+
queries.append(query)
|
129 |
+
documents = await self.aretrieve_documents(queries, run_manager)
|
130 |
+
return self.unique_union(documents)
|
131 |
+
|
132 |
+
|
133 |
+
async def agenerate_queries(
|
134 |
+
self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
|
135 |
+
) -> List[str]:
|
136 |
+
"""Generate queries based upon user input.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
question: user query
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
List of LLM generated queries that are similar to the user input
|
143 |
+
"""
|
144 |
+
response = await self.llm_chain.ainvoke(
|
145 |
+
inputs={"question": question}, callbacks=run_manager.get_child()
|
146 |
+
)
|
147 |
+
lines = response["text"]
|
148 |
+
if self.verbose:
|
149 |
+
logger.info(f"Generated queries: {lines}")
|
150 |
+
return lines
|
151 |
+
|
152 |
+
async def aretrieve_documents(
|
153 |
+
self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
|
154 |
+
) -> List[Document]:
|
155 |
+
"""Run all LLM generated queries.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
queries: query list
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
List of retrieved Documents
|
162 |
+
"""
|
163 |
+
document_lists = await asyncio.gather(
|
164 |
+
*(
|
165 |
+
self.retriever.aget_relevant_documents(
|
166 |
+
query, callbacks=run_manager.get_child()
|
167 |
+
)
|
168 |
+
for query in queries
|
169 |
+
)
|
170 |
+
)
|
171 |
+
return [doc for docs in document_lists for doc in docs]
|
172 |
+
|
173 |
+
def _get_relevant_documents(
|
174 |
+
self,
|
175 |
+
query: str,
|
176 |
+
*,
|
177 |
+
run_manager: CallbackManagerForRetrieverRun,
|
178 |
+
) -> List[Document]:
|
179 |
+
"""Get relevant documents given a user query.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
question: user query
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
Unique union of relevant documents from all generated queries
|
186 |
+
"""
|
187 |
+
queries = self.generate_queries(query, run_manager)
|
188 |
+
if self.include_original:
|
189 |
+
queries.append(query)
|
190 |
+
documents = self.retrieve_documents(queries, run_manager)
|
191 |
+
fused_documents= self.unique_union(documents)
|
192 |
+
# check for key exists
|
193 |
+
if fused_documents[0].metadata[self.date_key] != None:
|
194 |
+
doc_dates = pd.to_datetime(
|
195 |
+
[doc.metadata[self.date_key] for doc in fused_documents]
|
196 |
+
)
|
197 |
+
sorted_node_idxs = np.flip(doc_dates.argsort())
|
198 |
+
fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
|
199 |
+
logger.info('Documents sorted by year')
|
200 |
+
|
201 |
+
return fused_documents[:self.top_k]
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
def generate_queries(
|
206 |
+
self, question: str, run_manager: CallbackManagerForRetrieverRun
|
207 |
+
) -> List[str]:
|
208 |
+
"""Generate queries based upon user input.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
question: user query
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
List of LLM generated queries that are similar to the user input
|
215 |
+
"""
|
216 |
+
response = self.llm_chain.invoke(
|
217 |
+
{"question": question}, callbacks=run_manager.get_child()
|
218 |
+
)
|
219 |
+
lines = response["text"]
|
220 |
+
if self.verbose:
|
221 |
+
logger.info(f"Generated queries: {lines}")
|
222 |
+
return lines
|
223 |
+
|
224 |
+
def retrieve_documents(
|
225 |
+
self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
|
226 |
+
) -> List[Document]:
|
227 |
+
"""Run all LLM generated queries.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
queries: query list
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
List of retrieved Documents
|
234 |
+
"""
|
235 |
+
documents = []
|
236 |
+
for query in queries:
|
237 |
+
logger.info(f"MQ Retriever question: {query}")
|
238 |
+
docs = self.retriever.get_relevant_documents(
|
239 |
+
query, callbacks=run_manager.get_child()
|
240 |
+
)
|
241 |
+
documents.extend(docs)
|
242 |
+
return documents
|
243 |
+
|
244 |
+
def unique_union(self, documents: List[Document]) -> List[Document]:
|
245 |
+
"""Get unique Documents.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
documents: List of retrieved Documents
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
List of unique retrieved Documents
|
252 |
+
"""
|
253 |
+
return _unique_documents(documents)
|
254 |
|
reggpt/routers/__init__.py
ADDED
File without changes
|
controller.py β reggpt/routers/controller.py
RENAMED
@@ -16,13 +16,13 @@
|
|
16 |
|
17 |
import logging
|
18 |
logger = logging.getLogger(__name__)
|
19 |
-
from config import AVALIABLE_MODELS , MEMORY_WINDOW_K
|
20 |
|
21 |
# from qaPipeline import QAPipeline
|
22 |
# from qaPipeline_retriever_only import QAPipeline
|
23 |
# qaPipeline = QAPipeline()
|
24 |
|
25 |
-
from qaPipeline import run_agent
|
26 |
|
27 |
def get_QA_Answers(userQuery):
|
28 |
# model=userQuery.model
|
|
|
16 |
|
17 |
import logging
|
18 |
logger = logging.getLogger(__name__)
|
19 |
+
from reggpt.configs.config import AVALIABLE_MODELS , MEMORY_WINDOW_K
|
20 |
|
21 |
# from qaPipeline import QAPipeline
|
22 |
# from qaPipeline_retriever_only import QAPipeline
|
23 |
# qaPipeline = QAPipeline()
|
24 |
|
25 |
+
from reggpt.routers.qaPipeline import run_agent
|
26 |
|
27 |
def get_QA_Answers(userQuery):
|
28 |
# model=userQuery.model
|
reggpt/routers/general.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import logging
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from fastapi import HTTPException
|
8 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
9 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
10 |
+
|
11 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
12 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
verbose = os.environ.get('VERBOSE')
|
16 |
+
|
17 |
+
qa_model_type=QA_MODEL_TYPE
|
18 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
19 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
20 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
21 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
22 |
+
|
23 |
+
# retriever=load_faiss_retriever()
|
24 |
+
retriever=load_ensemble_retriever()
|
25 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
26 |
+
logger.info("retriever loaded:")
|
27 |
+
|
28 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
29 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
30 |
+
router_chain= get_router_chain(router_model_type)
|
31 |
+
|
32 |
+
def run_general_qa_chain(query):
|
33 |
+
try:
|
34 |
+
logger.info(f"run_general_qa_chain : Question: {query}")
|
35 |
+
|
36 |
+
# Get the answer from the chain
|
37 |
+
start = time.time()
|
38 |
+
res = general_qa_chain.invoke(query)
|
39 |
+
end = time.time()
|
40 |
+
|
41 |
+
# log the result
|
42 |
+
|
43 |
+
logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
|
44 |
+
|
45 |
+
return general_qa_chain_output_parser(res)
|
46 |
+
|
47 |
+
except Exception as e:
|
48 |
+
logger.exception(e)
|
49 |
+
raise e
|
reggpt/routers/out_of_domain.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from fastapi import HTTPException
|
7 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
8 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
9 |
+
|
10 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
11 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
verbose = os.environ.get('VERBOSE')
|
15 |
+
|
16 |
+
qa_model_type=QA_MODEL_TYPE
|
17 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
18 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
19 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
20 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
21 |
+
|
22 |
+
# retriever=load_faiss_retriever()
|
23 |
+
retriever=load_ensemble_retriever()
|
24 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
25 |
+
logger.info("retriever loaded:")
|
26 |
+
|
27 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
28 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
29 |
+
router_chain= get_router_chain(router_model_type)
|
30 |
+
def run_out_of_domain_chain(query):
|
31 |
+
return out_of_domain_chain_parser(query)
|
reggpt/routers/qa.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import logging
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
from dotenv import load_dotenv
|
22 |
+
from fastapi import HTTPException
|
23 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
24 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
+
|
26 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
27 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
28 |
+
load_dotenv()
|
29 |
+
|
30 |
+
verbose = os.environ.get('VERBOSE')
|
31 |
+
|
32 |
+
qa_model_type=QA_MODEL_TYPE
|
33 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
34 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
35 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
36 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
37 |
+
|
38 |
+
# retriever=load_faiss_retriever()
|
39 |
+
retriever=load_ensemble_retriever()
|
40 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
41 |
+
logger.info("retriever loaded:")
|
42 |
+
|
43 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
44 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
45 |
+
router_chain= get_router_chain(router_model_type)
|
46 |
+
|
47 |
+
def run_qa_chain(query):
|
48 |
+
try:
|
49 |
+
logger.info(f"run_qa_chain : Question: {query}")
|
50 |
+
# Get the answer from the chain
|
51 |
+
start = time.time()
|
52 |
+
# res = qa_chain(query)
|
53 |
+
res = qa_chain.invoke({"question": query, "chat_history":""})
|
54 |
+
# res = response
|
55 |
+
# answer, docs = res['result'],res['source_documents']
|
56 |
+
end = time.time()
|
57 |
+
|
58 |
+
# log the result
|
59 |
+
logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
|
60 |
+
|
61 |
+
return qa_chain_output_parser(res)
|
62 |
+
|
63 |
+
except Exception as e:
|
64 |
+
logger.exception(e)
|
65 |
+
raise e
|
66 |
+
|
reggpt/routers/qaPipeline.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
/*************************************************************************
|
3 |
+
*
|
4 |
+
* CONFIDENTIAL
|
5 |
+
* __________________
|
6 |
+
*
|
7 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
8 |
+
* All Rights Reserved
|
9 |
+
*
|
10 |
+
* Author : Theekshana Samaradiwakara
|
11 |
+
* Description :Python Backend API to chat with private data
|
12 |
+
* CreatedDate : 14/11/2023
|
13 |
+
* LastModifiedDate : 18/03/2024
|
14 |
+
*************************************************************************/
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import logging
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
from dotenv import load_dotenv
|
22 |
+
from fastapi import HTTPException
|
23 |
+
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
24 |
+
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
+
|
26 |
+
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
27 |
+
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
28 |
+
load_dotenv()
|
29 |
+
|
30 |
+
verbose = os.environ.get('VERBOSE')
|
31 |
+
|
32 |
+
qa_model_type=QA_MODEL_TYPE
|
33 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
34 |
+
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
35 |
+
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
36 |
+
# model_type="tiiuae/falcon-7b-instruct"
|
37 |
+
|
38 |
+
# retriever=load_faiss_retriever()
|
39 |
+
retriever=load_ensemble_retriever()
|
40 |
+
# retriever=load_multi_query_retriever(multi_query_model_type)
|
41 |
+
logger.info("retriever loaded:")
|
42 |
+
|
43 |
+
qa_chain= get_qa_chain(qa_model_type,retriever)
|
44 |
+
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
45 |
+
router_chain= get_router_chain(router_model_type)
|
reggpt/schemas/__init__.py
ADDED
File without changes
|
schema.py β reggpt/schemas/schema.py
RENAMED
File without changes
|
reggpt/utils/__init__.py
ADDED
File without changes
|
retriever.py β reggpt/utils/retriever.py
RENAMED
@@ -1,137 +1,137 @@
|
|
1 |
-
|
2 |
-
"""
|
3 |
-
/*************************************************************************
|
4 |
-
*
|
5 |
-
* CONFIDENTIAL
|
6 |
-
* __________________
|
7 |
-
*
|
8 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
-
* All Rights Reserved
|
10 |
-
*
|
11 |
-
* Author : Theekshana Samaradiwakara
|
12 |
-
* Description :Python Backend API to chat with private data
|
13 |
-
* CreatedDate : 19/03/2023
|
14 |
-
* LastModifiedDate : 19/03/2024
|
15 |
-
*************************************************************************/
|
16 |
-
"""
|
17 |
-
|
18 |
-
"""
|
19 |
-
Ensemble retriever that ensemble the results of
|
20 |
-
multiple retrievers by using weighted Reciprocal Rank Fusion
|
21 |
-
"""
|
22 |
-
|
23 |
-
import logging
|
24 |
-
logger = logging.getLogger(__name__)
|
25 |
-
|
26 |
-
from faissDb import load_FAISS_store
|
27 |
-
|
28 |
-
from langchain_community.retrievers import BM25Retriever
|
29 |
-
from langchain_community.document_loaders import PyPDFLoader
|
30 |
-
from langchain_community.document_loaders import DirectoryLoader
|
31 |
-
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
32 |
-
|
33 |
-
from langchain.schema import Document
|
34 |
-
from typing import Iterable
|
35 |
-
import json
|
36 |
-
|
37 |
-
def save_docs_to_jsonl(array:Iterable[Document], file_path:str)->None:
|
38 |
-
with open(file_path, 'w') as jsonl_file:
|
39 |
-
for doc in array:
|
40 |
-
jsonl_file.write(doc.json() + '\n')
|
41 |
-
|
42 |
-
def load_docs_from_jsonl(file_path)->Iterable[Document]:
|
43 |
-
array = []
|
44 |
-
with open(file_path, 'r') as jsonl_file:
|
45 |
-
for line in jsonl_file:
|
46 |
-
data = json.loads(line)
|
47 |
-
obj = Document(**data)
|
48 |
-
array.append(obj)
|
49 |
-
return array
|
50 |
-
|
51 |
-
def split_documents():
|
52 |
-
chunk_size=2000
|
53 |
-
chunk_overlap=100
|
54 |
-
|
55 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
56 |
-
|
57 |
-
years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
58 |
-
docs_list=[]
|
59 |
-
splits_list=[]
|
60 |
-
|
61 |
-
|
62 |
-
for year in years:
|
63 |
-
data_path= f"data/CBSL/{year}"
|
64 |
-
logger.info(f"Loading year : {data_path}")
|
65 |
-
|
66 |
-
documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()
|
67 |
-
|
68 |
-
for doc in documents:
|
69 |
-
doc.metadata['year']=year
|
70 |
-
logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}" )
|
71 |
-
docs_list.append(doc)
|
72 |
-
|
73 |
-
texts = text_splitter.split_documents(documents)
|
74 |
-
for text in texts:
|
75 |
-
splits_list.append(text)
|
76 |
-
|
77 |
-
splitted_texts_file='data/splitted_texts.jsonl'
|
78 |
-
save_docs_to_jsonl(splits_list,splitted_texts_file)
|
79 |
-
|
80 |
-
from ensemble_retriever import EnsembleRetriever
|
81 |
-
from multi_query_retriever import MultiQueryRetriever
|
82 |
-
|
83 |
-
def load_faiss_retriever():
|
84 |
-
try:
|
85 |
-
vectorstore=load_FAISS_store()
|
86 |
-
retriever = vectorstore.as_retriever(
|
87 |
-
# search_type="mmr",
|
88 |
-
search_kwargs={'k': 5, 'fetch_k': 10}
|
89 |
-
)
|
90 |
-
logger.info("FAISS Retriever loaded:")
|
91 |
-
return retriever
|
92 |
-
|
93 |
-
except Exception as e:
|
94 |
-
logger.exception(e)
|
95 |
-
raise e
|
96 |
-
|
97 |
-
def load_ensemble_retriever():
|
98 |
-
try:
|
99 |
-
# splitted_texts_file=os.path.dirname(os.path.abspath(__file__).join('/data/splitted_texts.jsonl'))
|
100 |
-
splitted_texts_file='./data/splitted_texts.jsonl'
|
101 |
-
sementic_k = 4
|
102 |
-
bm25_k = 2
|
103 |
-
splits_list = load_docs_from_jsonl(splitted_texts_file)
|
104 |
-
|
105 |
-
bm25_retriever = BM25Retriever.from_documents(splits_list)
|
106 |
-
bm25_retriever.k = bm25_k
|
107 |
-
|
108 |
-
faiss_vectorstore = load_FAISS_store()
|
109 |
-
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={'k': sementic_k,})
|
110 |
-
|
111 |
-
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
|
112 |
-
ensemble_retriever.top_k=4
|
113 |
-
|
114 |
-
logger.info("EnsembleRetriever loaded:")
|
115 |
-
return ensemble_retriever
|
116 |
-
|
117 |
-
except Exception as e:
|
118 |
-
logger.exception(e)
|
119 |
-
raise e
|
120 |
-
|
121 |
-
from llm import get_model
|
122 |
-
|
123 |
-
def load_multi_query_retriever(multi_query_model_type):
|
124 |
-
#multi query
|
125 |
-
try:
|
126 |
-
llm = get_model(multi_query_model_type)
|
127 |
-
ensembleRetriever = load_ensemble_retriever()
|
128 |
-
retriever = MultiQueryRetriever.from_llm(
|
129 |
-
retriever=ensembleRetriever,
|
130 |
-
llm=llm
|
131 |
-
)
|
132 |
-
logger.info("MultiQueryRetriever loaded:")
|
133 |
-
return retriever
|
134 |
-
|
135 |
-
except Exception as e:
|
136 |
-
logger.exception(e)
|
137 |
raise e
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
/*************************************************************************
|
4 |
+
*
|
5 |
+
* CONFIDENTIAL
|
6 |
+
* __________________
|
7 |
+
*
|
8 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
9 |
+
* All Rights Reserved
|
10 |
+
*
|
11 |
+
* Author : Theekshana Samaradiwakara
|
12 |
+
* Description :Python Backend API to chat with private data
|
13 |
+
* CreatedDate : 19/03/2023
|
14 |
+
* LastModifiedDate : 19/03/2024
|
15 |
+
*************************************************************************/
|
16 |
+
"""
|
17 |
+
|
18 |
+
"""
|
19 |
+
Ensemble retriever that ensemble the results of
|
20 |
+
multiple retrievers by using weighted Reciprocal Rank Fusion
|
21 |
+
"""
|
22 |
+
|
23 |
+
import logging
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
from reggpt.vectorstores.faissDb import load_FAISS_store
|
27 |
+
|
28 |
+
from langchain_community.retrievers import BM25Retriever
|
29 |
+
from langchain_community.document_loaders import PyPDFLoader
|
30 |
+
from langchain_community.document_loaders import DirectoryLoader
|
31 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
32 |
+
|
33 |
+
from langchain.schema import Document
|
34 |
+
from typing import Iterable
|
35 |
+
import json
|
36 |
+
|
37 |
+
def save_docs_to_jsonl(array:Iterable[Document], file_path:str)->None:
|
38 |
+
with open(file_path, 'w') as jsonl_file:
|
39 |
+
for doc in array:
|
40 |
+
jsonl_file.write(doc.json() + '\n')
|
41 |
+
|
42 |
+
def load_docs_from_jsonl(file_path)->Iterable[Document]:
|
43 |
+
array = []
|
44 |
+
with open(file_path, 'r') as jsonl_file:
|
45 |
+
for line in jsonl_file:
|
46 |
+
data = json.loads(line)
|
47 |
+
obj = Document(**data)
|
48 |
+
array.append(obj)
|
49 |
+
return array
|
50 |
+
|
51 |
+
def split_documents():
|
52 |
+
chunk_size=2000
|
53 |
+
chunk_overlap=100
|
54 |
+
|
55 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
56 |
+
|
57 |
+
years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
|
58 |
+
docs_list=[]
|
59 |
+
splits_list=[]
|
60 |
+
|
61 |
+
|
62 |
+
for year in years:
|
63 |
+
data_path= f"data/CBSL/{year}"
|
64 |
+
logger.info(f"Loading year : {data_path}")
|
65 |
+
|
66 |
+
documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()
|
67 |
+
|
68 |
+
for doc in documents:
|
69 |
+
doc.metadata['year']=year
|
70 |
+
logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}" )
|
71 |
+
docs_list.append(doc)
|
72 |
+
|
73 |
+
texts = text_splitter.split_documents(documents)
|
74 |
+
for text in texts:
|
75 |
+
splits_list.append(text)
|
76 |
+
|
77 |
+
splitted_texts_file='data/splitted_texts.jsonl'
|
78 |
+
save_docs_to_jsonl(splits_list,splitted_texts_file)
|
79 |
+
|
80 |
+
from ensemble_retriever import EnsembleRetriever
|
81 |
+
from multi_query_retriever import MultiQueryRetriever
|
82 |
+
|
83 |
+
def load_faiss_retriever():
|
84 |
+
try:
|
85 |
+
vectorstore=load_FAISS_store()
|
86 |
+
retriever = vectorstore.as_retriever(
|
87 |
+
# search_type="mmr",
|
88 |
+
search_kwargs={'k': 5, 'fetch_k': 10}
|
89 |
+
)
|
90 |
+
logger.info("FAISS Retriever loaded:")
|
91 |
+
return retriever
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logger.exception(e)
|
95 |
+
raise e
|
96 |
+
|
97 |
+
def load_ensemble_retriever():
|
98 |
+
try:
|
99 |
+
# splitted_texts_file=os.path.dirname(os.path.abspath(__file__).join('/data/splitted_texts.jsonl'))
|
100 |
+
splitted_texts_file='./data/splitted_texts.jsonl'
|
101 |
+
sementic_k = 4
|
102 |
+
bm25_k = 2
|
103 |
+
splits_list = load_docs_from_jsonl(splitted_texts_file)
|
104 |
+
|
105 |
+
bm25_retriever = BM25Retriever.from_documents(splits_list)
|
106 |
+
bm25_retriever.k = bm25_k
|
107 |
+
|
108 |
+
faiss_vectorstore = load_FAISS_store()
|
109 |
+
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={'k': sementic_k,})
|
110 |
+
|
111 |
+
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
|
112 |
+
ensemble_retriever.top_k=4
|
113 |
+
|
114 |
+
logger.info("EnsembleRetriever loaded:")
|
115 |
+
return ensemble_retriever
|
116 |
+
|
117 |
+
except Exception as e:
|
118 |
+
logger.exception(e)
|
119 |
+
raise e
|
120 |
+
|
121 |
+
from reggpt.llms.llm import get_model
|
122 |
+
|
123 |
+
def load_multi_query_retriever(multi_query_model_type):
|
124 |
+
#multi query
|
125 |
+
try:
|
126 |
+
llm = get_model(multi_query_model_type)
|
127 |
+
ensembleRetriever = load_ensemble_retriever()
|
128 |
+
retriever = MultiQueryRetriever.from_llm(
|
129 |
+
retriever=ensembleRetriever,
|
130 |
+
llm=llm
|
131 |
+
)
|
132 |
+
logger.info("MultiQueryRetriever loaded:")
|
133 |
+
return retriever
|
134 |
+
|
135 |
+
except Exception as e:
|
136 |
+
logger.exception(e)
|
137 |
raise e
|
{utils β reggpt/utils}/utils.py
RENAMED
@@ -1,40 +1,40 @@
|
|
1 |
-
"""
|
2 |
-
Python Backend API to chat with private data
|
3 |
-
15/11/2023
|
4 |
-
Theekshana Samaradiwakara
|
5 |
-
"""
|
6 |
-
"""
|
7 |
-
/*************************************************************************
|
8 |
-
*
|
9 |
-
* CONFIDENTIAL
|
10 |
-
* __________________
|
11 |
-
*
|
12 |
-
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
13 |
-
* All Rights Reserved
|
14 |
-
*
|
15 |
-
* Author : Theekshana Samaradiwakara
|
16 |
-
* Description :Python Backend API to chat with private data
|
17 |
-
* CreatedDate : 15/11/2023
|
18 |
-
* LastModifiedDate : 10/12/2020
|
19 |
-
*************************************************************************/
|
20 |
-
"""
|
21 |
-
|
22 |
-
# from passlib.context import CryptContext
|
23 |
-
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
24 |
-
|
25 |
-
|
26 |
-
# def hash(password: str):
|
27 |
-
# return pwd_context.hash(password)
|
28 |
-
|
29 |
-
|
30 |
-
# def verify(plain_password, hashed_password):
|
31 |
-
# return pwd_context.verify(plain_password, hashed_password)
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
import re
|
36 |
-
def is_valid_open_ai_api_key(secretKey):
|
37 |
-
if re.search("^sk-[a-zA-Z0-9]{32,}$", secretKey ):
|
38 |
-
return True
|
39 |
-
else: return False
|
40 |
-
|
|
|
1 |
+
"""
|
2 |
+
Python Backend API to chat with private data
|
3 |
+
15/11/2023
|
4 |
+
Theekshana Samaradiwakara
|
5 |
+
"""
|
6 |
+
"""
|
7 |
+
/*************************************************************************
|
8 |
+
*
|
9 |
+
* CONFIDENTIAL
|
10 |
+
* __________________
|
11 |
+
*
|
12 |
+
* Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
|
13 |
+
* All Rights Reserved
|
14 |
+
*
|
15 |
+
* Author : Theekshana Samaradiwakara
|
16 |
+
* Description :Python Backend API to chat with private data
|
17 |
+
* CreatedDate : 15/11/2023
|
18 |
+
* LastModifiedDate : 10/12/2020
|
19 |
+
*************************************************************************/
|
20 |
+
"""
|
21 |
+
|
22 |
+
# from passlib.context import CryptContext
|
23 |
+
# pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
24 |
+
|
25 |
+
|
26 |
+
# def hash(password: str):
|
27 |
+
# return pwd_context.hash(password)
|
28 |
+
|
29 |
+
|
30 |
+
# def verify(plain_password, hashed_password):
|
31 |
+
# return pwd_context.verify(plain_password, hashed_password)
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
import re
|
36 |
+
def is_valid_open_ai_api_key(secretKey):
|
37 |
+
if re.search("^sk-[a-zA-Z0-9]{32,}$", secretKey ):
|
38 |
+
return True
|
39 |
+
else: return False
|
40 |
+
|
reggpt/vectorstores/__init__.py
ADDED
File without changes
|