Hammaad committed on
Commit 93bc171 • 1 Parent(s): 92e5bb7

code refactor part 1 complete need to test

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .dockerignore +28 -28
  2. .gitignore +137 -135
  3. CHANGELOG.txt +3 -2
  4. Dockerfile +41 -41
  5. LICENSE +21 -21
  6. app.py +4 -110
  7. faiss_embeddings_2024/index.faiss +0 -3
  8. faiss_embeddings_2024/index.pkl +0 -3
  9. prompts.py +0 -123
  10. .env.example → reggpt/.env.example +0 -0
  11. {utils → reggpt}/__init__.py +0 -0
  12. reggpt/agent/agent.py +33 -0
  13. reggpt/api/__init__.py +0 -0
  14. server.py → reggpt/api/router.py +10 -79
  15. {configs → reggpt/chains}/__init__.py +5 -5
  16. llmChain.py → reggpt/chains/llmChain.py +99 -96
  17. {data → reggpt/configs}/__init__.py +5 -5
  18. reggpt/configs/api.py +28 -0
  19. config.py → reggpt/configs/config.py +35 -35
  20. {configs → reggpt/configs}/logger.py +39 -39
  21. reggpt/controller/__init__.py +0 -0
  22. qaPipeline.py → reggpt/controller/agent.py +73 -150
  23. reggpt/controller/router.py +62 -0
  24. reggpt/data/__init__.py +5 -0
  25. {data → reggpt/data}/splitted_texts.jsonl +0 -0
  26. llm.py → reggpt/llms/llm.py +46 -46
  27. reggpt/memory/__init__.py +0 -0
  28. conversationBufferWindowMemory.py → reggpt/memory/conversationBufferWindowMemory.py +133 -133
  29. reggpt/output_parsers/__init__.py +0 -0
  30. output_parser.py → reggpt/output_parsers/output_parser.py +0 -0
  31. reggpt/prompts/__init__.py +0 -0
  32. reggpt/prompts/document_combine.py +7 -0
  33. reggpt/prompts/general.py +25 -0
  34. reggpt/prompts/multi_query.py +20 -0
  35. reggpt/prompts/retrieval.py +33 -0
  36. reggpt/prompts/router.py +17 -0
  37. ensemble_retriever.py → reggpt/retriever/ensemble_retriever.py +228 -228
  38. multi_query_retriever.py → reggpt/retriever/multi_query_retriever.py +253 -253
  39. reggpt/routers/__init__.py +0 -0
  40. controller.py → reggpt/routers/controller.py +2 -2
  41. reggpt/routers/general.py +49 -0
  42. reggpt/routers/out_of_domain.py +31 -0
  43. reggpt/routers/qa.py +66 -0
  44. reggpt/routers/qaPipeline.py +45 -0
  45. reggpt/schemas/__init__.py +0 -0
  46. schema.py → reggpt/schemas/schema.py +0 -0
  47. reggpt/utils/__init__.py +0 -0
  48. retriever.py → reggpt/utils/retriever.py +136 -136
  49. reggpt/vectorstores/__init__.py +0 -0
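
The refactor amounts to moving the flat top-level modules into a reggpt package. A before/after import sketch (paths taken from the rename list above; the specific call sites are illustrative):

    # Before (flat layout):
    # from llmChain import get_qa_chain
    # from retriever import load_ensemble_retriever
    # from config import QA_MODEL_TYPE

    # After this commit (package layout):
    from reggpt.chains.llmChain import get_qa_chain
    from reggpt.utils.retriever import load_ensemble_retriever
    from reggpt.configs.config import QA_MODEL_TYPE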
.dockerignore CHANGED
@@ -1,29 +1,29 @@
(Every line is removed and re-added with identical text, likely a whitespace or line-ending normalization. The file's content, before and after:)

# Ignore node_modules
node_modules

# Ignore logs
logs
*.log

# Ignore temporary files
tmp
*.tmp

# Ignore build directories
dist
build

# Ignore environment variables
.env

# Ignore Docker files
Dockerfile
docker-compose.yml

# Ignore IDE specific files
.vscode
.idea

# Ignore OS generated files
.DS_Store
Thumbs.db
.gitignore CHANGED
@@ -1,135 +1,137 @@
(Lines 1-135 are removed and re-added with identical text, likely a whitespace or line-ending normalization, and two lines are appended at the end. The unchanged content:)

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# testing files generated
*.txt.json

*.ipynb
env

(Appended in this commit:)
+
+.reggpt/vectorstores/faiss_embeddings_2024/
CHANGELOG.txt CHANGED
@@ -1,2 +1,3 @@
 2023-11-30 pipeline with only document retrieval
 2024-08-23 azure app serice , open ai 'gpt4o mini'
+2024-10-16 Code Refactor
Dockerfile CHANGED
@@ -1,42 +1,42 @@
(Every line is removed and re-added with identical text, likely a whitespace or line-ending normalization. The file's content, before and after:)

# Step 1: Use Python 3.11.9 as required
FROM python:3.11.9

# Step 2: Set up environment variables and timezone configuration
ENV TZ=Asia/Colombo
RUN apt-get update && apt-get install -y libaio1 wget unzip tzdata \
    && ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# Step 4: Add a user for running the app (after installations)
RUN useradd -m -u 1000 user

# Step 5: Create the /app directory and set ownership to the new user
RUN mkdir -p /app && chown -R user:user /app

# Step 6: Switch to non-root user after the directory has the right permissions
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# Step 7: Set up the working directory for the app
WORKDIR /app

# Step 8: Copy the requirements file and install dependencies
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Step 9: Install pipenv and handle pipenv environment
RUN pip install pipenv
COPY --chown=user . /app
RUN pipenv install

# Step 10: Expose the necessary port (7860 for Hugging Face Spaces)
EXPOSE 7860

# Step 11: Set environment variables for the app
ENV APP_HOST=0.0.0.0
ENV APP_PORT=7860

# Step 12: Create logs directory (if necessary)
RUN mkdir -p /app/logs

# Step 13: Run the app using Uvicorn, listening on port 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE CHANGED
@@ -1,22 +1,22 @@
(Every line is removed and re-added with identical text, likely a whitespace or line-ending normalization. The file's content, before and after:)

License

Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
All Rights Reserved

This source code is protected under international copyright law. All rights
reserved and protected by the copyright holders.
This file is confidential and only available to authorized individuals with the
permission of the copyright holders.

Permission is hereby granted, to {User}. for testing and development purposes.

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py CHANGED
@@ -15,20 +15,16 @@
 """
 
 import os
-import time
-import sys
 import logging
-import datetime
 import uvicorn
 from dotenv import load_dotenv
 
-from fastapi import FastAPI, APIRouter, HTTPException, status
-from fastapi import HTTPException, status
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from schema import UserQuery, ResponseModel, Document, LoginRequest, UserModel
-from controller import get_QA_Answers, get_avaliable_models
+from reggpt.configs.api import API_TITLE, API_VERSION, API_DESCRIPTION
+from reggpt.api.router import ChatAPI
 
 def filer():
     return "logs/log"
@@ -55,110 +51,11 @@ load_dotenv()
 host = os.environ.get("APP_HOST")
 port = int(os.environ.get("APP_PORT"))
 
-
-class ChatAPI:
-
-    def __init__(self):
-        self.router = APIRouter()
-        self.router.add_api_route("/api/v1/health", self.hello, methods=["GET"])
-        self.router.add_api_route("/api/v1/models", self.avaliable_models, methods=["GET"])
-        self.router.add_api_route(
-            "/api/v1/login", self.login, methods=["POST"], response_model=UserModel
-        )
-        self.router.add_api_route("/api/v1/chat", self.chat, methods=["POST"])
-
-    async def hello(self):
-        return "Hello there!"
-
-    async def avaliable_models(self):
-        logger.info("getting avaliable models")
-        models = get_avaliable_models()
-
-        if not models:
-            logger.exception("models not found")
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="models not found"
-            )
-
-        return models
-
-    async def login(self, login_request: LoginRequest):
-        logger.info(f"username password: {login_request} ")
-        # Dummy user data for demonstration (normally, you'd use a database)
-        dummy_users_db = {
-            "john_doe": {
-                "userId": 1,
-                "firstName": "John",
-                "lastName": "Doe",
-                "userName": "john_doe",
-                "password": "password",  # Normally, passwords would be hashed and stored securely
-                "token": "dummy_token_123",  # In a real scenario, this would be a JWT or another kind of token
-            }
-        }
-        # Fetch user by username
-        # user = dummy_users_db.get(login_request.username)
-        user = dummy_users_db.get("john_doe")
-        # Validate user credentials
-        if not user or user["password"] != login_request.password:
-            raise HTTPException(status_code=401, detail="Invalid username or password")
-
-        # Return the user model without the password
-        return UserModel(
-            userId=user["userId"],
-            firstName=user["firstName"],
-            lastName=user["lastName"],
-            userName=user["userName"],
-            token=user["token"],
-        )
-
-    async def chat(self, userQuery: UserQuery):
-        """Makes query to doc store via Langchain pipeline.
-
-        :param userQuery: question, model, dataset location, history of the chat.
-        """
-        logger.info(f"userQuery: {userQuery} ")
-
-        try:
-            start = time.time()
-            res = get_QA_Answers(userQuery)
-            logger.info(
-                f"-------------------------- answer: {res} -------------------------- "
-            )
-            end = time.time()
-            logger.info(
-                f"-------------------------- Server process (took {round(end - start, 2)} s.) \n: {res}"
-            )
-            print(
-                f" \n -------------------------- Server process (took {round(end - start, 2)} s.) ------------------------- \n"
-            )
-            return {
-                "user_input": userQuery.user_question,
-                "bot_response": res["bot_response"],
-                "format_data": res["format_data"],
-            }
-
-        except HTTPException as e:
-            logger.exception(e)
-            raise e
-
-        except Exception as e:
-            logger.exception(e)
-            raise HTTPException(status_code=400, detail=f"Error : {e}")
-
-
 # initialize API
-app = FastAPI(title="Boardpac chatbot API")
+app = FastAPI(title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION)
 api = ChatAPI()
 app.include_router(api.router)
 
-# origins = ['http://localhost:8000','http://192.168.10.100:8000']
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],  # origins,
@@ -169,9 +66,6 @@ app.add_middleware(
 
 if __name__ == "__main__":
 
-    host = "0.0.0.0"
-    port = 8000
-
     # config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
     config = uvicorn.Config("server:app", host=host, port=port)
     server = uvicorn.Server(config)
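
A quick client sketch against the refactored service (a minimal sketch, assuming the server is running locally on port 7860 as in the Dockerfile; the question text is illustrative, and the request field names are taken from the handlers shown above):

    import requests

    BASE = "http://localhost:7860/api/v1"

    # Health check (GET /api/v1/health returns "Hello there!")
    print(requests.get(f"{BASE}/health").json())

    # Login (POST /api/v1/login) matches the dummy in-memory user kept by
    # ChatAPI.login; the response carries the UserModel fields
    # (userId, firstName, lastName, userName, token).
    user = requests.post(
        f"{BASE}/login",
        json={"username": "john_doe", "password": "password"},
    ).json()

    # Chat (POST /api/v1/chat): the handler returns user_input, bot_response
    # and format_data.
    res = requests.post(
        f"{BASE}/chat",
        json={"user_question": "What is the latest CBSL directive on capital adequacy?"},
    ).json()
    print(res["bot_response"])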
faiss_embeddings_2024/index.faiss DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a3087f02172887cbf8bfb0fb3b371843548619c2a2873fdf4629339e2031a2c1
-size 10895405
faiss_embeddings_2024/index.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3da917c3758e2bfe0aedbd050199f4c80ec372d5b0349b49126b790fb1757db9
-size 3935715
prompts.py DELETED
@@ -1,123 +0,0 @@
-"""
-/*************************************************************************
- *
- * CONFIDENTIAL
- * __________________
- *
- * Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
- * All Rights Reserved
- *
- * Author : Theekshana Samaradiwakara
- * Description :Python Backend API to chat with private data
- * CreatedDate : 14/11/2023
- * LastModifiedDate : 19/03/2024
- *************************************************************************/
-"""
-
-from langchain.prompts import PromptTemplate
-
-# multi query prompt
-MULTY_QUERY_PROMPT = PromptTemplate(
-    input_variables=["question"],
-    template="""You are an AI language model assistant. Your task is to generate three
-    different versions of the given user question to retrieve relevant documents from a vector
-    database. By generating multiple perspectives on the user question, your goal is to help
-    the user overcome some of the limitations of the distance-based similarity search.
-    Provide these alternative questions separated by newlines.
-
-    Dont add anything extra before or after to the 3 questions. Just give 3 lines with 3 questions.
-    Just provide 3 lines having 3 questions only.
-    Answer should be in following format.
-
-    1. alternative question 1
-    2. alternative question 2
-    3. alternative question 3
-
-    Original question: {question}""",
-)
-
-# retrieval prompt
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-retrieval_qa_template = (
-    """<<SYS>>
-    You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
-
-    please answer the question based on the chat history provided below. Answer should be short and simple as possible and on to the point.
-    <chat history>: {chat_history}
-
-    If the question is related to welcomes and greetings answer accordingly.
-
-    Else If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
-    please answer the question based only on the information provided in following central bank documents published in various years.
-    The published year is mentioned as the metadata 'year' of each source document.
-    Please notice that content of a one document of a past year can updated by a new document from a recent year.
-    Always try to answer with latest information and mention the year which information extracted.
-    If you dont know the answer say you dont know, dont try to makeup answers. Dont add any extra details that is not mentioned in the context.
-
-    <</SYS>>
-
-    [INST]
-    <DOCUMENTS>
-    {context}
-    </DOCUMENTS>
-
-    Question : {question}[/INST]"""
-)
-
-
-retrieval_qa_chain_prompt = PromptTemplate(
-    input_variables=["question", "context", "chat_history"],
-    template=retrieval_qa_template
-)
-
-
-
-# document combine prompt
-document_combine_prompt = PromptTemplate(
-    input_variables=["source", "year", "page", "page_content"],
-    template=
-    """<doc> source: {source}, year: {year}, page: {page}, page content: {page_content} </doc>"""
-)
-
-
-router_template_Mixtral_V0 = """
-You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
-
-If a user asks a question you have to classify it to following 3 types Relevant, Greeting, Other.
-
-"Relevant”: If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
-"Greeting”: If the question is a greeting like good morning, hi my name is., thank you or General Question ask about the AI assistance of a company boardpac.
-"Other”: If the question is not related to research papers.
-
-Give the correct name of question type. If you are not sure return "Not Sure" instead.
-
-Question : {question}
-"""
-router_prompt = PromptTemplate.from_template(router_template_Mixtral_V0)
-
-
-general_qa_template_Mixtral_V0 = """
-You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
-you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .
-
-Is the provided question below a greeting? First, evaluate whether the input resembles a typical greeting or not.
-
-Greetings are used to say 'hello' and 'how are you?' and to say 'goodbye' and 'nice speaking with you.' and 'hi, I'm (user's name).'
-Greetings are words used when we want to introduce ourselves to others and when we want to find out how someone is feeling.
-
-You can only reply to the user's greetings.
-If the question is a greeting, reply accordingly as the AI assistant of company boardpac.
-If the question is not related to greetings and research papers, say that it is out of your domain.
-If the question is not clear enough, ask for more details and don't try to make up answers.
-
-Answer should be polite, short, and simple.
-
-Additionally, it's important to note that this AI assistant has access to an internal collection of research papers, and answers can be provided using the information available in those CBSL Dataset.
-
-Question: {question}
-"""
-
-general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template_Mixtral_V0)
.env.example → reggpt/.env.example RENAMED
File without changes
{utils → reggpt}/__init__.py RENAMED
File without changes
reggpt/agent/agent.py ADDED
@@ -0,0 +1,33 @@
+import logging
+logger = logging.getLogger(__name__)
+from fastapi import HTTPException
+import time
+from reggpt.routers.qaPipeline import run_router_chain, chain_selector
+
+def run_agent(query):
+    try:
+        logger.info(f"run_agent : Question: {query}")
+        print(f"---------------- run_agent : Question: {query} ----------------")
+        # Get the answer from the chain
+        start = time.time()
+        chain_type = run_router_chain(query)
+        res = chain_selector(chain_type, query)
+        end = time.time()
+
+        # log the result
+        logger.error(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
+        print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")
+
+        return res
+
+    except HTTPException as e:
+        print('HTTPException')
+        print(e)
+        logger.exception(e)
+        raise e
+
+    except Exception as e:
+        print('Exception')
+        print(e)
+        logger.exception(e)
+        raise e
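
A minimal usage sketch for the new top-level agent entry point (assumes the package's retriever data and OPENAI_API_KEY are configured; the question text is illustrative). Note that run_agent pulls its router and selector from reggpt.routers.qaPipeline:

    from reggpt.agent.agent import run_agent

    # Classifies the query via run_router_chain, dispatches it with
    # chain_selector, and returns the selected chain's parsed output.
    answer = run_agent("What are the current liquidity coverage requirements?")
    print(answer)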
reggpt/api/__init__.py ADDED
File without changes
server.py → reggpt/api/router.py RENAMED
@@ -1,71 +1,28 @@
-"""
-/*************************************************************************
- *
- * CONFIDENTIAL
- * __________________
- *
- * Copyright (2023-2025) AI Labs, IronOne Technologies, LLC
- * All Rights Reserved
- *
- * Author : Theekshana Samaradiwakara
- * Description :Python Backend API to chat with private data
- * CreatedDate : 14/11/2023
- * LastModifiedDate : 15/10/2024
- *************************************************************************/
-"""
-
-import os
 import time
-import sys
-import logging
-import datetime
-import uvicorn
-from dotenv import load_dotenv
 
-from fastapi import FastAPI, APIRouter, HTTPException, status
-from fastapi import HTTPException, status
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi import APIRouter, HTTPException, status
+from fastapi import HTTPException, status
 
-from schema import UserQuery, ResponseModel, Document, LoginRequest, UserModel
-from controller import get_QA_Answers, get_avaliable_models
+from reggpt.schemas.schema import UserQuery, LoginRequest, UserModel
+from reggpt.routers.controller import get_QA_Answers, get_avaliable_models
 
-
-def filer():
-    return "logs/log"
-    # today = datetime.datetime.today()
-    # log_filename = f"logs/{today.year}-{today.month:02d}-{today.day:02d}.log"
-    # return log_filename
-
-file_handler = logging.FileHandler(filer())
-# file_handler = logging.handlers.TimedRotatingFileHandler(filer(),when="D")
-file_handler.setLevel(logging.INFO)
-
-logging.basicConfig(
-    level=logging.DEBUG,
-    format="%(asctime)s %(levelname)s (%(name)s) : %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S",
-    handlers=[file_handler],
-    force=True,
-)
+from reggpt.configs.api import API_ENDPOINT_LOGIN, API_ENDPOINT_CHAT, API_ENDPOINT_HEALTH, API_ENDPOINT_MODEL
+import logging
 
 logger = logging.getLogger(__name__)
 
-load_dotenv()
-host = os.environ.get("APP_HOST")
-port = int(os.environ.get("APP_PORT"))
-
 
 class ChatAPI:
 
     def __init__(self):
         self.router = APIRouter()
-        self.router.add_api_route("/api/v1/health", self.hello, methods=["GET"])
-        self.router.add_api_route("/api/v1/models", self.avaliable_models, methods=["GET"])
+        self.router.add_api_route(API_ENDPOINT_HEALTH, self.hello, methods=["GET"])
+        self.router.add_api_route(API_ENDPOINT_MODEL, self.avaliable_models, methods=["GET"])
         self.router.add_api_route(
-            "/api/v1/login", self.login, methods=["POST"], response_model=UserModel
+            API_ENDPOINT_LOGIN, self.login, methods=["POST"], response_model=UserModel
        )
-        self.router.add_api_route("/api/v1/chat", self.chat, methods=["POST"])
+        self.router.add_api_route(API_ENDPOINT_CHAT, self.chat, methods=["POST"])
 
     async def hello(self):
         return "Hello there!"
@@ -151,29 +108,3 @@ class ChatAPI:
         logger.exception(e)
         raise HTTPException(status_code=400, detail=f"Error : {e}")
 
-
-# initialize API
-app = FastAPI(title="Boardpac chatbot API")
-api = ChatAPI()
-app.include_router(api.router)
-
-# origins = ['http://localhost:8000','http://192.168.10.100:8000']
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-if __name__ == "__main__":
-
-    host = "0.0.0.0"
-    port = 8000
-
-    # config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
-    config = uvicorn.Config("server:app", host=host, port=port)
-    server = uvicorn.Server(config)
-    server.run()
-    # uvicorn.run(app)
{configs → reggpt/chains}/__init__.py RENAMED
@@ -1,5 +1,5 @@
(All five commented-out lines are removed and re-added with identical text, likely a whitespace or line-ending normalization:)

# import os
# import sys

# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
#     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
llmChain.py → reggpt/chains/llmChain.py RENAMED
@@ -1,96 +1,99 @@
 """
 /*************************************************************************
  *
  * CONFIDENTIAL
  * __________________
  *
  * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
  * All Rights Reserved
  *
  * Author : Theekshana Samaradiwakara
  * Description :Python Backend API to chat with private data
  * CreatedDate : 14/11/2023
  * LastModifiedDate : 18/03/2024
  *************************************************************************/
 """
 
 import os
 import logging
 logger = logging.getLogger(__name__)
 from dotenv import load_dotenv
 
 load_dotenv()
 
 verbose = os.environ.get('VERBOSE')
 
-from llm import get_model
+from reggpt.llms.llm import get_model
 from langchain.chains import ConversationalRetrievalChain
 # from conversationBufferWindowMemory import ConversationBufferWindowMemory
 
 # from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
-
-from prompts import retrieval_qa_chain_prompt, document_combine_prompt, general_qa_chain_prompt, router_prompt
+from reggpt.prompts import document_combine as document_combine_prompt
+from reggpt.prompts import retrieval as retrieval_qa_chain_prompt
+from reggpt.prompts import general as general_qa_chain_prompt
+from reggpt.prompts import router as router_prompt
+
 
 def get_qa_chain(model_type, retriever):
     logger.info("creating qa_chain")
 
     try:
         qa_llm = get_model(model_type)
 
         qa_chain = ConversationalRetrievalChain.from_llm(
             llm=qa_llm,
             chain_type="stuff",
             retriever=retriever,
             # retriever = self.retriever(search_kwargs={"k": target_source_chunks}
             return_source_documents=True,
             get_chat_history=lambda h: h,
             combine_docs_chain_kwargs={
                 "prompt": retrieval_qa_chain_prompt,
                 "document_prompt": document_combine_prompt,
             },
             verbose=True,
             # memory=memory,
         )
 
         logger.info("qa_chain created")
         return qa_chain
 
     except Exception as e:
         msg = f"Error : {e}"
         logger.exception(msg)
         raise e
 
 
 def get_general_qa_chain(model_type):
     logger.info("creating general_qa_chain")
 
     try:
         general_qa_llm = get_model(model_type)
         general_qa_chain = LLMChain(llm=general_qa_llm, prompt=general_qa_chain_prompt)
 
         logger.info("general_qa_chain created")
         return general_qa_chain
 
     except Exception as e:
         msg = f"Error : {e}"
         logger.exception(msg)
         raise e
 
 
 def get_router_chain(model_type):
     logger.info("creating router_chain")
 
     try:
         router_llm = get_model(model_type)
         router_chain = LLMChain(llm=router_llm, prompt=router_prompt)
 
         logger.info("router_chain created")
         return router_chain
 
     except Exception as e:
         msg = f"Error : {e}"
         logger.exception(msg)
         raise e
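
A construction sketch for the refactored chain factories (a minimal sketch; note that the new `from reggpt.prompts import ... as ...` lines bind module objects to the prompt names, so this assumes those modules expose the actual PromptTemplate objects, which is presumably part of what the commit message's "need to test" covers):

    from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
    from reggpt.utils.retriever import load_ensemble_retriever

    retriever = load_ensemble_retriever()            # the retriever the controller loads
    qa_chain = get_qa_chain("openai", retriever)     # ConversationalRetrievalChain, "stuff" combine
    general_chain = get_general_qa_chain("openai")   # LLMChain over the general-QA prompt
    router_chain = get_router_chain("openai")        # LLMChain that labels the query type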
{data → reggpt/configs}/__init__.py RENAMED
@@ -1,5 +1,5 @@
(All five commented-out lines are removed and re-added with identical text, likely a whitespace or line-ending normalization:)

# import os
# import sys

# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
#     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
reggpt/configs/api.py ADDED
@@ -0,0 +1,28 @@
+"""
+/*************************************************************************
+ *
+ * CONFIDENTIAL
+ * __________________
+ *
+ * Copyright (2024-2025) AI Labs, IronOne Technologies, LLC
+ * All Rights Reserved
+ *
+ * Authors : Hammaad Rizwan, Theekshana Samaradiwakara
+ * Description : API Configurations
+ * CreatedDate : 14/10/2024
+ * LastModifiedDate : 15/10/2024
+ *************************************************************************/
+"""
+API_TITLE = "RegGPT Back End v1"
+API_VERSION = "0.1.0"
+API_DESCRIPTION = "API_DESC"
+
+API_ENDPOINT_PREFIX = "/api/v1"
+API_DOCS_URL = f"{API_ENDPOINT_PREFIX}/docs"
+API_REDOC_URL = f"{API_ENDPOINT_PREFIX}/redoc"
+API_OPENAPI_URL = f"{API_ENDPOINT_PREFIX}/openapi.json"
+
+API_ENDPOINT_HEALTH = f"{API_ENDPOINT_PREFIX}/health"
+API_ENDPOINT_CHAT = f"{API_ENDPOINT_PREFIX}/chat"
+API_ENDPOINT_MODEL = f"{API_ENDPOINT_PREFIX}/models"
+API_ENDPOINT_LOGIN = f"{API_ENDPOINT_PREFIX}/login"
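
The docs/redoc/openapi constants are defined here but not yet consumed by app.py in this commit; a sketch of how they could be wired (an assumption, not something the diff shows):

    from fastapi import FastAPI
    from reggpt.configs.api import (
        API_TITLE, API_VERSION, API_DESCRIPTION,
        API_DOCS_URL, API_REDOC_URL, API_OPENAPI_URL,
    )

    # Moves the interactive docs under the /api/v1 prefix alongside the endpoints.
    app = FastAPI(
        title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION,
        docs_url=API_DOCS_URL, redoc_url=API_REDOC_URL, openapi_url=API_OPENAPI_URL,
    )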
config.py → reggpt/configs/config.py RENAMED
@@ -1,36 +1,36 @@
(All 36 lines are removed and re-added with identical text, likely a whitespace or line-ending normalization. The file's content, before and after:)

AVALIABLE_MODELS = [
    {
        "id": "gpt-4o-mini",
        "model_name": "openai/gpt-4o-mini",
        "description": "gpt-4o-mini model from openai"
    }
]

MODELS = {
    "DEFAULT": "openai",
    "gpt-4o-mini": "openai",
}

DATASETS = {
    "DEFAULT": "faiss",
    "a": "A",
    "b": "B",
    "c": "C"
}

MEMORY_WINDOW_K = 1

QA_MODEL_TYPE = "openai"
GENERAL_QA_MODEL_TYPE = "openai"
ROUTER_MODEL_TYPE = "openai"
Multi_Query_MODEL_TYPE = "openai"


ANSWER_TYPES = [
    "relevant",
    "greeting",
    "other",
    "not sure",
]
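
A lookup sketch for these tables (illustrative; `model_id` is a hypothetical variable, and the normalization mirrors what chain_selector does with these labels):

    from reggpt.configs.config import MODELS, ANSWER_TYPES

    model_id = "gpt-4o-mini"
    backend = MODELS.get(model_id, MODELS["DEFAULT"])    # -> "openai"

    router_label = "Relevant"                            # raw router-chain output
    assert router_label.lower().strip() in ANSWER_TYPES  # labels the pipeline understands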
{configs → reggpt/configs}/logger.py RENAMED
@@ -1,40 +1,40 @@
(All 40 lines are removed and re-added with identical text, likely a whitespace or line-ending normalization. The file's content, before and after:)

import logging
import time
# from functools import wraps

logger = logging.getLogger(__name__)

stream_handler = logging.StreamHandler()
log_filename = "output.log"
file_handler = logging.FileHandler(filename=log_filename)
handlers = [stream_handler, file_handler]


class TimeFilter(logging.Filter):
    def filter(self, record):
        return "Running" in record.getMessage()


logger.addFilter(TimeFilter())

# Configure the logging module
logging.basicConfig(
    level=logging.INFO,
    format="%(name)s %(asctime)s - %(levelname)s - %(message)s",
    handlers=handlers,
)


def time_logger(func):
    """Decorator function to log time taken by any function."""

    # @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Start time before function execution
        result = func(*args, **kwargs)  # Function execution
        end_time = time.time()  # End time after function execution
        execution_time = end_time - start_time  # Calculate execution time
        logger.info(f"Running {func.__name__}: --- {execution_time} seconds ---")
        return result

    return wrapper
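
A usage sketch for the time_logger decorator (the decorated function is illustrative). The TimeFilter above only passes records whose message contains "Running", which is exactly the word the decorator logs:

    from reggpt.configs.logger import time_logger

    @time_logger
    def build_index():
        ...  # any expensive call

    build_index()  # logs: "Running build_index: --- <elapsed> seconds ---"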
reggpt/controller/__init__.py ADDED
File without changes
qaPipeline.py → reggpt/controller/agent.py RENAMED
@@ -1,150 +1,73 @@
-"""
-/*************************************************************************
- *
- * CONFIDENTIAL
- * __________________
- *
- * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
- * All Rights Reserved
- *
- * Author : Theekshana Samaradiwakara
- * Description :Python Backend API to chat with private data
- * CreatedDate : 14/11/2023
- * LastModifiedDate : 18/03/2024
- *************************************************************************/
-"""
-
 import os
 import time
 import logging
 logger = logging.getLogger(__name__)
 from dotenv import load_dotenv
 from fastapi import HTTPException
-from llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
-from output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
+from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
+from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
 
-from config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
-from retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
+from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
+from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
 load_dotenv()
 
 verbose = os.environ.get('VERBOSE')
 
 qa_model_type = QA_MODEL_TYPE
 general_qa_model_type = GENERAL_QA_MODEL_TYPE
 router_model_type = ROUTER_MODEL_TYPE  # "google/flan-t5-xxl"
 multi_query_model_type = Multi_Query_MODEL_TYPE  # "google/flan-t5-xxl"
 # model_type="tiiuae/falcon-7b-instruct"
 
 # retriever = load_faiss_retriever()
 retriever = load_ensemble_retriever()
 # retriever = load_multi_query_retriever(multi_query_model_type)
 logger.info("retriever loaded:")
 
 qa_chain = get_qa_chain(qa_model_type, retriever)
 general_qa_chain = get_general_qa_chain(general_qa_model_type)
 router_chain = get_router_chain(router_model_type)
 
+
 def chain_selector(chain_type, query):
     chain_type = chain_type.lower().strip()
     logger.info(f"chain_selector : chain_type: {chain_type} Question: {query}")
     if "greeting" in chain_type:
         return run_general_qa_chain(query)
     elif "other" in chain_type:
         return run_out_of_domain_chain(query)
     elif ("relevant" in chain_type) or ("not sure" in chain_type):
         return run_qa_chain(query)
     else:
         raise ValueError(
             f"Received invalid type '{chain_type}'"
         )
 
 def run_agent(query):
     try:
         logger.info(f"run_agent : Question: {query}")
         print(f"---------------- run_agent : Question: {query} ----------------")
         # Get the answer from the chain
         start = time.time()
         chain_type = run_router_chain(query)
         res = chain_selector(chain_type, query)
         end = time.time()
 
         # log the result
         logger.error(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
         print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")
 
         return res
 
     except HTTPException as e:
-        print('HTTPException eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee')
+        print('HTTPException')
         print(e)
         logger.exception(e)
         raise e
 
     except Exception as e:
-        print('Exception eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee')
+        print('Exception')
         print(e)
         logger.exception(e)
         raise e
 
-
-def run_router_chain(query):
-    try:
-        logger.info(f"run_router_chain : Question: {query}")
-        # Get the answer from the chain
-        start = time.time()
-        chain_type = router_chain.invoke(query)['text']
-        end = time.time()
-
-        # log the result
-        logger.info(f"Answer (took {round(end - start, 2)} s.) chain_type: {chain_type}")
-
-        return chain_type
-
-    except Exception as e:
-        logger.exception(e)
-        raise e
-
-
-def run_qa_chain(query):
-    try:
-        logger.info(f"run_qa_chain : Question: {query}")
-        # Get the answer from the chain
-        start = time.time()
-        # res = qa_chain(query)
-        res = qa_chain.invoke({"question": query, "chat_history": ""})
-        end = time.time()
-
-        # log the result
-        logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
-
-        return qa_chain_output_parser(res)
-
-    except Exception as e:
-        logger.exception(e)
-        raise e
-
-
-def run_general_qa_chain(query):
-    try:
-        logger.info(f"run_general_qa_chain : Question: {query}")
-
-        # Get the answer from the chain
-        start = time.time()
-        res = general_qa_chain.invoke(query)
-        end = time.time()
-
-        # log the result
-        logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
-
-        return general_qa_chain_output_parser(res)
-
-    except Exception as e:
-        logger.exception(e)
-        raise e
-
-
-def run_out_of_domain_chain(query):
-    return out_of_domain_chain_parser(query)
reggpt/controller/router.py ADDED
@@ -0,0 +1,62 @@
+"""
+/*************************************************************************
+ *
+ * CONFIDENTIAL
+ * __________________
+ *
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
+ * All Rights Reserved
+ *
+ * Author : Theekshana Samaradiwakara
+ * Description :Python Backend API to chat with private data
+ * CreatedDate : 14/11/2023
+ * LastModifiedDate : 18/03/2024
+ *************************************************************************/
+"""
+
+import os
+import time
+import logging
+logger = logging.getLogger(__name__)
+from dotenv import load_dotenv
+from fastapi import HTTPException
+from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
+from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
+
+from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
+from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
+load_dotenv()
+
+verbose = os.environ.get('VERBOSE')
+
+qa_model_type = QA_MODEL_TYPE
+general_qa_model_type = GENERAL_QA_MODEL_TYPE
+router_model_type = ROUTER_MODEL_TYPE  # "google/flan-t5-xxl"
+multi_query_model_type = Multi_Query_MODEL_TYPE  # "google/flan-t5-xxl"
+# model_type="tiiuae/falcon-7b-instruct"
+
+# retriever = load_faiss_retriever()
+retriever = load_ensemble_retriever()
+# retriever = load_multi_query_retriever(multi_query_model_type)
+logger.info("retriever loaded:")
+
+qa_chain = get_qa_chain(qa_model_type, retriever)
+general_qa_chain = get_general_qa_chain(general_qa_model_type)
+router_chain = get_router_chain(router_model_type)
+
+def run_router_chain(query):
+    try:
+        logger.info(f"run_router_chain : Question: {query}")
+        # Get the answer from the chain
+        start = time.time()
+        chain_type = router_chain.invoke(query)['text']
+        end = time.time()
+
+        # log the result
+        logger.info(f"Answer (took {round(end - start, 2)} s.) chain_type: {chain_type}")
+
+        return chain_type
+
+    except Exception as e:
+        logger.exception(e)
+        raise e
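
run_router_chain hands back the raw LLMChain text, so callers normalize it before matching; a sketch of that contract (the label set comes from the router prompt: Relevant, Greeting, Other, or Not Sure):

    from reggpt.controller.router import run_router_chain

    label = run_router_chain("What are the CBSL liquidity rules?")
    print(label.lower().strip())   # e.g. "relevant", fed to chain_selector downstream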
reggpt/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+# import os
+# import sys
+
+# if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
+#     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
{data → reggpt/data}/splitted_texts.jsonl RENAMED
The diff for this file is too large to render.
llm.py β†’ reggpt/llms/llm.py RENAMED
@@ -1,47 +1,47 @@
1
- """
2
- /*************************************************************************
3
- *
4
- * CONFIDENTIAL
5
- * __________________
6
- *
7
- * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
- * All Rights Reserved
9
- *
10
- * Author : Theekshana Samaradiwakara
11
- * Description :Python Backend API to chat with private data
12
- * CreatedDate : 14/11/2023
13
- * LastModifiedDate : 18/03/2024
14
- *************************************************************************/
15
- """
16
-
17
- import os
18
- # import time
19
- import logging
20
- logger = logging.getLogger(__name__)
21
- from dotenv import load_dotenv
22
-
23
- # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
24
- from langchain_openai import ChatOpenAI
25
-
26
- load_dotenv()
27
-
28
- openai_api_key = os.environ.get('OPENAI_API_KEY')
29
- # openai_api_key = "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"
30
- print(f"--- ---- ---- openai_api_key: {openai_api_key}")
31
-
32
- verbose = os.environ.get('VERBOSE')
33
-
34
- def get_model(model_type):
35
-
36
- match model_type:
37
- case "openai":
38
- llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
39
- case _default:
40
- # raise exception if model_type is not supported
41
- msg=f"Model type '{model_type}' is not supported. Please choose a valid one"
42
- logger.error(msg)
43
- return Exception(msg)
44
-
45
-
46
- logger.info(f"model_type: {model_type} loaded:")
47
  return llm
 
1
+ """
2
+ /*************************************************************************
3
+ *
4
+ * CONFIDENTIAL
5
+ * __________________
6
+ *
7
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
+ * All Rights Reserved
9
+ *
10
+ * Author : Theekshana Samaradiwakara
11
+ * Description :Python Backend API to chat with private data
12
+ * CreatedDate : 14/11/2023
13
+ * LastModifiedDate : 18/03/2024
14
+ *************************************************************************/
15
+ """
16
+
17
+ import os
18
+ # import time
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+ from dotenv import load_dotenv
22
+
23
+ # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
24
+ from langchain_openai import ChatOpenAI
25
+
26
+ load_dotenv()
27
+
28
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
29
+ # openai_api_key = "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"
30
+ print(f"--- ---- ---- openai_api_key: {openai_api_key}")
31
+
32
+ verbose = os.environ.get('VERBOSE')
33
+
34
+ def get_model(model_type):
35
+
36
+ match model_type:
37
+ case "openai":
38
+ llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
39
+ case _:
40
+ # raise exception if model_type is not supported
41
+ msg=f"Model type '{model_type}' is not supported. Please choose a valid one"
42
+ logger.error(msg)
43
+ raise Exception(msg)
44
+
45
+
46
+ logger.info(f"model_type: {model_type} loaded:")
47
  return llm
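A quick sanity check for get_model, assuming OPENAI_API_KEY is set in the environment ("openai" is the only model_type this module supports):

# Sketch only: loads the chat model and makes one call (needs network access and a valid key).
from reggpt.llms.llm import get_model

llm = get_model("openai")  # a ChatOpenAI instance (gpt-4o-mini, temperature=0)
print(llm.invoke("Say hello in one word").content)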
reggpt/memory/__init__.py ADDED
File without changes
conversationBufferWindowMemory.py β†’ reggpt/memory/conversationBufferWindowMemory.py RENAMED
@@ -1,134 +1,134 @@
1
+ """
2
+ /*************************************************************************
3
+ *
4
+ * CONFIDENTIAL
5
+ * __________________
6
+ *
7
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
+ * All Rights Reserved
9
+ *
10
+ * Author : Theekshana Samaradiwakara
11
+ * Description :Python Backend API to chat with private data
12
+ * CreatedDate : 14/11/2023
13
+ * LastModifiedDate : 18/11/2023
14
+ *************************************************************************/
15
+ """
16
+
17
+ from abc import ABC
18
+ from typing import Any, Dict, Optional, Tuple
19
+ # import json
20
+
21
+ from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
22
+ from langchain.memory.utils import get_prompt_input_key
23
+ from langchain.pydantic_v1 import Field
24
+ from langchain.schema import BaseChatMessageHistory, BaseMemory
25
+
26
+ from typing import List, Union
27
+
28
+ # from langchain.memory.chat_memory import BaseChatMemory
29
+ from langchain.schema.messages import BaseMessage, get_buffer_string
30
+
31
+
32
+ class BaseChatMemory(BaseMemory, ABC):
33
+ """Abstract base class for chat memory."""
34
+
35
+ chat_memory: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
36
+ output_key: Optional[str] = None
37
+ input_key: Optional[str] = None
38
+ return_messages: bool = False
39
+
40
+ def _get_input_output(
41
+ self, inputs: Dict[str, Any], outputs: Dict[str, str]
42
+ ) -> Tuple[str, str]:
43
+
44
+
45
+ if self.input_key is None:
46
+ prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
47
+ else:
48
+ prompt_input_key = self.input_key
49
+
50
+ if self.output_key is None:
51
+ """
52
+ output for agent with LLM chain tool = {answer}
53
+ output for agent with ConversationalRetrievalChain tool = {'question', 'chat_history', 'answer','source_documents'}
54
+ """
55
+
56
+ LLM_key = 'output'
57
+ Retrieval_key = 'answer'
58
+ if isinstance(outputs[LLM_key], dict):
59
+ Retrieval_dict = outputs[LLM_key]
60
+ if Retrieval_key in Retrieval_dict.keys():
61
+ #output keys are 'answer' , 'source_documents'
62
+ output = Retrieval_dict[Retrieval_key]
63
+ else:
64
+ raise ValueError(f"output key: {LLM_key} not a valid dictionary")
65
+
66
+ else:
67
+ #otherwise output key will be 'output'
68
+ output_key = list(outputs.keys())[0]
69
+ output = outputs[output_key]
70
+
71
+ # if len(outputs) != 1:
72
+ # raise ValueError(f"One output key expected, got {outputs.keys()}")
73
+
74
+
75
+ else:
76
+ output_key = self.output_key
77
+ output = outputs[output_key]
78
+
79
+ return inputs[prompt_input_key], output
80
+
81
+ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
82
+ """Save context from this conversation to buffer."""
83
+ input_str, output_str = self._get_input_output(inputs, outputs)
84
+ self.chat_memory.add_user_message(input_str)
85
+ self.chat_memory.add_ai_message(output_str)
86
+
87
+ def clear(self) -> None:
88
+ """Clear memory contents."""
89
+ self.chat_memory.clear()
90
+
91
+
92
+
93
+
94
+
95
+ class ConversationBufferWindowMemory(BaseChatMemory):
96
+ """Buffer for storing conversation memory inside a limited size window."""
97
+
98
+ human_prefix: str = "Human"
99
+ ai_prefix: str = "AI"
100
+ memory_key: str = "history" #: :meta private:
101
+ k: int = 5
102
+ """Number of messages to store in buffer."""
103
+
104
+ @property
105
+ def buffer(self) -> Union[str, List[BaseMessage]]:
106
+ """String buffer of memory."""
107
+ return self.buffer_as_messages if self.return_messages else self.buffer_as_str
108
+
109
+ @property
110
+ def buffer_as_str(self) -> str:
111
+ """Exposes the buffer as a string in case return_messages is True."""
112
+ messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
113
+ return get_buffer_string(
114
+ messages,
115
+ human_prefix=self.human_prefix,
116
+ ai_prefix=self.ai_prefix,
117
+ )
118
+
119
+ @property
120
+ def buffer_as_messages(self) -> List[BaseMessage]:
121
+ """Exposes the buffer as a list of messages in case return_messages is False."""
122
+ return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
123
+
124
+ @property
125
+ def memory_variables(self) -> List[str]:
126
+ """Will always return list of memory variables.
127
+
128
+ :meta private:
129
+ """
130
+ return [self.memory_key]
131
+
132
+ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
133
+ """Return history buffer."""
134
  return {self.memory_key: self.buffer}
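A minimal sketch of the windowed memory above; with k=2 only the last two exchanges survive in the buffer:

# Illustrative only; exercises save_context and load_memory_variables.
from reggpt.memory.conversationBufferWindowMemory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=2)
for i in range(3):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})
print(memory.load_memory_variables({})["history"])  # only exchanges 1 and 2 remain in the window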
reggpt/output_parsers/__init__.py ADDED
File without changes
output_parser.py β†’ reggpt/output_parsers/output_parser.py RENAMED
File without changes
reggpt/prompts/__init__.py ADDED
File without changes
reggpt/prompts/document_combine.py ADDED
@@ -0,0 +1,7 @@
1
+ from langchain.prompts import PromptTemplate
2
+
3
+ document_combine_prompt = PromptTemplate(
4
+ input_variables=["source","year", "page","page_content"],
5
+ template=
6
+ """<doc> source: {source}, year: {year}, page: {page}, page content: {page_content} </doc>"""
7
+ )
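The combine prompt renders each retrieved chunk into a tagged block before it is stuffed into the QA prompt's {context}; a sketch with made-up metadata values:

# Hypothetical values, for illustration only.
print(document_combine_prompt.format(
    source="banking_act.pdf", year=2021, page=3,
    page_content="Licensing requirements for commercial banks ..."))
# -> <doc> source: banking_act.pdf, year: 2021, page: 3, page content: Licensing requirements ... </doc>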
reggpt/prompts/general.py ADDED
@@ -0,0 +1,25 @@
1
+ from langchain.prompts import PromptTemplate
2
+
3
+ general_qa_template_Mixtral_V0= """
4
+ You are the AI assistant of the company 'boardpac', which provides services related to the banking and financial sector to company board members.
5
+ You can answer questions about the Banking and Financial Services Sector, such as banking & financial regulations, the legal framework, the governance framework, and compliance requirements as per Central Bank regulations.
6
+
7
+ Is the provided question below a greeting? First, evaluate whether the input resembles a typical greeting or not.
8
+
9
+ Greetings are used to say 'hello' and 'how are you?' and to say 'goodbye' and 'nice speaking with you.' and 'hi, I'm (user's name).'
10
+ Greetings are words used when we want to introduce ourselves to others and when we want to find out how someone is feeling.
11
+
12
+ You can only reply to the user's greetings.
13
+ If the question is a greeting, reply accordingly as the AI assistant of company boardpac.
14
+ If the question is neither a greeting nor related to the banking domain, say that it is out of your domain.
15
+ If the question is not clear enough, ask for more details and don't try to make up answers.
16
+
17
+ Answer should be polite, short, and simple.
18
+
19
+ Additionally, it's important to note that this AI assistant has access to an internal collection of Central Bank of Sri Lanka (CBSL) documents, and answers can be provided using the information available in that dataset.
20
+
21
+ Question: {question}
22
+ """
23
+
24
+ general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template_Mixtral_V0)
25
+
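The template takes a single input variable, so rendering it is trivial; a sketch:

# Render the general QA prompt for an example greeting before sending it to the LLM.
print(general_qa_chain_prompt.format(question="Hi, I'm a new board member!"))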
reggpt/prompts/multi_query.py ADDED
@@ -0,0 +1,20 @@
1
+ from langchain.prompts import PromptTemplate
2
+
3
+ MULTY_QUERY_PROMPT = PromptTemplate(
4
+ input_variables=["question"],
5
+ template="""You are an AI language model assistant. Your task is to generate three
6
+ different versions of the given user question to retrieve relevant documents from a vector
7
+ database. By generating multiple perspectives on the user question, your goal is to help
8
+ the user overcome some of the limitations of the distance-based similarity search.
9
+ Provide these alternative questions separated by newlines.
10
+
11
+ Don't add anything extra before or after the 3 questions.
12
+ Provide exactly 3 lines, each containing one question.
13
+ The answer should be in the following format.
14
+
15
+ 1. alternative question 1
16
+ 2. alternative question 2
17
+ 3. alternative question 3
18
+
19
+ Original question: {question}""",
20
+ )
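The multi-query retriever later splits the completion on newlines (LineListOutputParser), so the model must return exactly the three numbered lines; an illustrative render:

# Not a real model call; this only shows the prompt side of the exchange.
rendered = MULTY_QUERY_PROMPT.format(question="What is the current policy rate?")
# An acceptable completion would be, e.g.:
# 1. What is the Central Bank's current policy interest rate?
# 2. What policy rate has the Central Bank set at present?
# 3. What is today's benchmark policy rate?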
reggpt/prompts/retrieval.py ADDED
@@ -0,0 +1,33 @@
1
+ from langchain.prompts import PromptTemplate
2
+
3
+ retrieval_qa_template = (
4
+ """<<SYS>>
5
+ You are the AI assistant of the company 'boardpac', which provides services related to the banking and financial sector to company board members.
6
+
7
+ Please answer the question based on the chat history provided below. The answer should be as short and simple as possible and to the point.
8
+ <chat history>: {chat_history}
9
+
10
+ If the question is related to welcomes and greetings, answer accordingly.
11
+
12
+ Else, if the question is related to the Banking and Financial Services Sector (such as banking & financial regulations, the legal framework, the governance framework, or compliance requirements as per Central Bank regulations),
13
+ please answer the question based only on the information provided in the following central bank documents, published in various years.
14
+ The published year is mentioned as the metadata 'year' of each source document.
15
+ Please note that the content of a document from a past year can be updated by a newer document from a recent year.
16
+ Always try to answer with the latest information and mention the year from which the information was extracted.
17
+ If you don't know the answer, say you don't know; don't try to make up answers. Don't add any extra details that are not mentioned in the context.
18
+
19
+ <</SYS>>
20
+
21
+ [INST]
22
+ <DOCUMENTS>
23
+ {context}
24
+ </DOCUMENTS>
25
+
26
+ Question : {question}[/INST]"""
27
+ )
28
+
29
+
30
+ retrieval_qa_chain_prompt = PromptTemplate(
31
+ input_variables=["question", "context", "chat_history"],
32
+ template=retrieval_qa_template
33
+ )
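All three input variables must be supplied at format time; the QA pipeline later passes chat_history="" when no history is kept. A sketch with illustrative values:

# Illustrative values only.
text = retrieval_qa_chain_prompt.format(
    question="What are the capital adequacy requirements?",
    context="<doc> source: directive.pdf, year: 2023, page: 1, page content: ... </doc>",
    chat_history="",
)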
reggpt/prompts/router.py ADDED
@@ -0,0 +1,17 @@
1
+ from langchain.prompts import PromptTemplate
2
+
3
+ router_template_Mixtral_V0= """
4
+ You are the AI assistant of the company 'boardpac', which provides services related to the banking and financial sector to company board members.
5
+
6
+ If a user asks a question, you have to classify it into one of the following 3 types: Relevant, Greeting, Other.
7
+
8
+ "Relevant”: If the question is related to Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations.
9
+ "Greeting”: If the question is a greeting like good morning, hi my name is., thank you or General Question ask about the AI assistance of a company boardpac.
10
+ "Other”: If the question is not related to research papers.
11
+
12
+ Give the correct name of the question type. If you are not sure, return "Not Sure" instead.
13
+
14
+ Question : {question}
15
+ """
16
+ router_prompt=PromptTemplate.from_template(router_template_Mixtral_V0)
17
+
ensemble_retriever.py β†’ reggpt/retriever/ensemble_retriever.py RENAMED
@@ -1,228 +1,228 @@
1
+ """
2
+ /*************************************************************************
3
+ *
4
+ * CONFIDENTIAL
5
+ * __________________
6
+ *
7
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
+ * All Rights Reserved
9
+ *
10
+ * Author : Theekshana Samaradiwakara
11
+ * Description :Python Backend API to chat with private data
12
+ * CreatedDate : 14/11/2023
13
+ * LastModifiedDate : 18/03/2024
14
+ *************************************************************************/
15
+ """
16
+
17
+ """
18
+ Ensemble retriever that ensembles the results of
19
+ multiple retrievers by using weighted Reciprocal Rank Fusion
20
+ """
21
+
22
+ import os
23
+ import sys
24
+
25
+ from pathlib import Path
26
+ Path(__file__).resolve().parent.parent
27
+
28
+ if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
29
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
30
+
31
+
32
+ import logging
33
+ logger = logging.getLogger(__name__)
34
+ from typing import Any, Dict, List
35
+
36
+ from langchain.callbacks.manager import (
37
+ AsyncCallbackManagerForRetrieverRun,
38
+ CallbackManagerForRetrieverRun,
39
+ )
40
+ from langchain.pydantic_v1 import root_validator
41
+ from langchain.schema import BaseRetriever, Document
42
+
43
+ import numpy as np
44
+ import pandas as pd
45
+
46
+
47
+ class EnsembleRetriever(BaseRetriever):
48
+ """Retriever that ensembles the multiple retrievers.
49
+
50
+ It uses a rank fusion.
51
+
52
+ Args:
53
+ retrievers: A list of retrievers to ensemble.
54
+ weights: A list of weights corresponding to the retrievers. Defaults to equal
55
+ weighting for all retrievers.
56
+ c: A constant added to the rank, controlling the balance between the importance
57
+ of high-ranked items and the consideration given to lower-ranked items.
58
+ Default is 60.
59
+ """
60
+
61
+ retrievers: List[BaseRetriever]
62
+ weights: List[float]
63
+ c: int = 60
64
+ date_key: str = "year"
65
+ top_k: int = 4
66
+
67
+ @root_validator(pre=True,allow_reuse=True)
68
+ def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
69
+ if not values.get("weights"):
70
+ n_retrievers = len(values["retrievers"])
71
+ values["weights"] = [1 / n_retrievers] * n_retrievers
72
+ return values
73
+
74
+ def _get_relevant_documents(
75
+ self,
76
+ query: str,
77
+ *,
78
+ run_manager: CallbackManagerForRetrieverRun,
79
+ ) -> List[Document]:
80
+ """
81
+ Get the relevant documents for a given query.
82
+
83
+ Args:
84
+ query: The query to search for.
85
+
86
+ Returns:
87
+ A list of reranked documents.
88
+ """
89
+
90
+ # Get fused result of the retrievers.
91
+ fused_documents = self.rank_fusion(query, run_manager)
92
+
93
+ # check for key exists
94
+ if fused_documents and fused_documents[0].metadata.get(self.date_key) is not None:
95
+ doc_dates = pd.to_datetime(
96
+ [doc.metadata[self.date_key] for doc in fused_documents]
97
+ )
98
+ sorted_node_idxs = np.flip(doc_dates.argsort())
99
+ fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
100
+ logger.info('Ensemble Retriever Documents sorted by year')
101
+
102
+ # return fused_documents[:self.top_k]
103
+ return fused_documents
104
+
105
+ async def _aget_relevant_documents(
106
+ self,
107
+ query: str,
108
+ *,
109
+ run_manager: AsyncCallbackManagerForRetrieverRun,
110
+ ) -> List[Document]:
111
+ """
112
+ Asynchronously get the relevant documents for a given query.
113
+
114
+ Args:
115
+ query: The query to search for.
116
+
117
+ Returns:
118
+ A list of reranked documents.
119
+ """
120
+
121
+ # Get fused result of the retrievers.
122
+ fused_documents = await self.arank_fusion(query, run_manager)
123
+
124
+ return fused_documents
125
+
126
+ def rank_fusion(
127
+ self, query: str, run_manager: CallbackManagerForRetrieverRun
128
+ ) -> List[Document]:
129
+ """
130
+ Retrieve the results of the retrievers and use rank_fusion_func to get
131
+ the final result.
132
+
133
+ Args:
134
+ query: The query to search for.
135
+
136
+ Returns:
137
+ A list of reranked documents.
138
+ """
139
+
140
+ # Get the results of all retrievers.
141
+ retriever_docs = [
142
+ retriever.get_relevant_documents(
143
+ query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
144
+ )
145
+ for i, retriever in enumerate(self.retrievers)
146
+ ]
147
+
148
+ # apply rank fusion
149
+ fused_documents = self.weighted_reciprocal_rank(retriever_docs)
150
+
151
+ return fused_documents
152
+
153
+ async def arank_fusion(
154
+ self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
155
+ ) -> List[Document]:
156
+ """
157
+ Asynchronously retrieve the results of the retrievers
158
+ and use rank_fusion_func to get the final result.
159
+
160
+ Args:
161
+ query: The query to search for.
162
+
163
+ Returns:
164
+ A list of reranked documents.
165
+ """
166
+
167
+ # Get the results of all retrievers.
168
+ retriever_docs = [
169
+ await retriever.aget_relevant_documents(
170
+ query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
171
+ )
172
+ for i, retriever in enumerate(self.retrievers)
173
+ ]
174
+
175
+ # apply rank fusion
176
+ fused_documents = self.weighted_reciprocal_rank(retriever_docs)
177
+
178
+ return fused_documents
179
+
180
+ def weighted_reciprocal_rank(
181
+ self, doc_lists: List[List[Document]]
182
+ ) -> List[Document]:
183
+ """
184
+ Perform weighted Reciprocal Rank Fusion on multiple rank lists.
185
+ You can find more details about RRF here:
186
+ https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
187
+
188
+ Args:
189
+ doc_lists: A list of rank lists, where each rank list contains unique items.
190
+
191
+ Returns:
192
+ list: The final aggregated list of items sorted by their weighted RRF
193
+ scores in descending order.
194
+ """
195
+ if len(doc_lists) != len(self.weights):
196
+ raise ValueError(
197
+ "Number of rank lists must be equal to the number of weights."
198
+ )
199
+
200
+ # Create a union of all unique documents in the input doc_lists
201
+ all_documents = set()
202
+ for doc_list in doc_lists:
203
+ for doc in doc_list:
204
+ all_documents.add(doc.page_content)
205
+
206
+ # Initialize the RRF score dictionary for each document
207
+ rrf_score_dic = {doc: 0.0 for doc in all_documents}
208
+
209
+ # Calculate RRF scores for each document
210
+ for doc_list, weight in zip(doc_lists, self.weights):
211
+ for rank, doc in enumerate(doc_list, start=1):
212
+ rrf_score = weight * (1 / (rank + self.c))
213
+ rrf_score_dic[doc.page_content] += rrf_score
214
+
215
+ # Sort documents by their RRF scores in descending order
216
+ sorted_documents = sorted(
217
+ rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
218
+ )
219
+
220
+ # Map the sorted page_content back to the original document objects
221
+ page_content_to_doc_map = {
222
+ doc.page_content: doc for doc_list in doc_lists for doc in doc_list
223
+ }
224
+ sorted_docs = [
225
+ page_content_to_doc_map[page_content] for page_content in sorted_documents
226
+ ]
227
+
228
+ return sorted_docs
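A toy check of the weighted RRF scoring implemented in weighted_reciprocal_rank above (c = 60 and equal 0.5 weights as in the defaults; the ranks are made up):

# Doc X appears at rank 1 in BM25 and rank 3 in FAISS; doc Y only at rank 2 in BM25.
c, w = 60, 0.5
score_x = w * (1 / (1 + c)) + w * (1 / (3 + c))
score_y = w * (1 / (2 + c))
print(round(score_x, 5), round(score_y, 5))  # 0.01613 > 0.00806, so X is ranked first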
multi_query_retriever.py β†’ reggpt/retriever/multi_query_retriever.py RENAMED
@@ -1,254 +1,254 @@
1
+
2
+ """
3
+ /*************************************************************************
4
+ *
5
+ * CONFIDENTIAL
6
+ * __________________
7
+ *
8
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
9
+ * All Rights Reserved
10
+ *
11
+ * Author : Theekshana Samaradiwakara
12
+ * Description :Python Backend API to chat with private data
13
+ * CreatedDate : 14/11/2023
14
+ * LastModifiedDate : 21/03/2024
15
+ *************************************************************************/
16
+ """
17
+
18
+ import asyncio
19
+ import logging
20
+ from typing import List, Optional, Sequence
21
+
22
+ from langchain_core.callbacks import (
23
+ AsyncCallbackManagerForRetrieverRun,
24
+ CallbackManagerForRetrieverRun,
25
+ )
26
+ from langchain_core.documents import Document
27
+ from langchain_core.language_models import BaseLanguageModel
28
+ from langchain_core.output_parsers import BaseOutputParser
29
+ from langchain_core.prompts.prompt import PromptTemplate
30
+ from langchain_core.retrievers import BaseRetriever
31
+
32
+ from langchain.chains.llm import LLMChain
33
+
34
+ import numpy as np
35
+ import pandas as pd
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
- from prompts import MULTY_QUERY_PROMPT
+ from reggpt.prompts.multi_query import MULTY_QUERY_PROMPT
40
+
41
+ class LineListOutputParser(BaseOutputParser[List[str]]):
42
+ """Output parser for a list of lines."""
43
+
44
+ def parse(self, text: str) -> List[str]:
45
+ lines = text.strip().split("\n")
46
+ return lines
47
+
48
+
49
+ # Default prompt
50
+ # DEFAULT_QUERY_PROMPT = PromptTemplate(
51
+ # input_variables=["question"],
52
+ # template="""You are an AI language model assistant. Your task is
53
+ # to generate 3 different versions of the given user
54
+ # question to retrieve relevant documents from a vector database.
55
+ # By generating multiple perspectives on the user question,
56
+ # your goal is to help the user overcome some of the limitations
57
+ # of distance-based similarity search. Provide these alternative
58
+ # questions separated by newlines. Original question: {question}""",
59
+ # )
60
+
61
+
62
+
63
+
64
+ def _unique_documents(documents: Sequence[Document]) -> List[Document]:
65
+ return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
66
+
67
+
68
+ class MultiQueryRetriever(BaseRetriever):
69
+ """Given a query, use an LLM to write a set of queries.
70
+
71
+ Retrieve docs for each query. Return the unique union of all retrieved docs.
72
+ """
73
+
74
+ retriever: BaseRetriever
75
+ llm_chain: LLMChain
76
+ verbose: bool = True
77
+ parser_key: str = "lines"
78
+ """DEPRECATED. parser_key is no longer used and should not be specified."""
79
+ include_original: bool = False
80
+ """Whether to include the original query in the list of generated queries."""
81
+ date_key: str = "year"
82
+ top_k: int = 4
83
+
84
+ @classmethod
85
+ def from_llm(
86
+ cls,
87
+ retriever: BaseRetriever,
88
+ llm: BaseLanguageModel,
89
+ prompt: PromptTemplate = MULTY_QUERY_PROMPT,
90
+ parser_key: Optional[str] = None,
91
+ include_original: bool = False,
92
+ ) -> "MultiQueryRetriever":
93
+ """Initialize from llm using default template.
94
+
95
+ Args:
96
+ retriever: retriever to query documents from
97
+ llm: llm for query generation using DEFAULT_QUERY_PROMPT
98
+ include_original: Whether to include the original query in the list of
99
+ generated queries.
100
+
101
+ Returns:
102
+ MultiQueryRetriever
103
+ """
104
+ output_parser = LineListOutputParser()
105
+ llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
106
+ return cls(
107
+ retriever=retriever,
108
+ llm_chain=llm_chain,
109
+ include_original=include_original,
110
+ )
111
+
112
+ async def _aget_relevant_documents(
113
+ self,
114
+ query: str,
115
+ *,
116
+ run_manager: AsyncCallbackManagerForRetrieverRun,
117
+ ) -> List[Document]:
118
+ """Get relevant documents given a user query.
119
+
120
+ Args:
121
+ question: user query
122
+
123
+ Returns:
124
+ Unique union of relevant documents from all generated queries
125
+ """
126
+ queries = await self.agenerate_queries(query, run_manager)
127
+ if self.include_original:
128
+ queries.append(query)
129
+ documents = await self.aretrieve_documents(queries, run_manager)
130
+ return self.unique_union(documents)
131
+
132
+
133
+ async def agenerate_queries(
134
+ self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
135
+ ) -> List[str]:
136
+ """Generate queries based upon user input.
137
+
138
+ Args:
139
+ question: user query
140
+
141
+ Returns:
142
+ List of LLM generated queries that are similar to the user input
143
+ """
144
+ response = await self.llm_chain.ainvoke(
145
+ inputs={"question": question}, callbacks=run_manager.get_child()
146
+ )
147
+ lines = response["text"]
148
+ if self.verbose:
149
+ logger.info(f"Generated queries: {lines}")
150
+ return lines
151
+
152
+ async def aretrieve_documents(
153
+ self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
154
+ ) -> List[Document]:
155
+ """Run all LLM generated queries.
156
+
157
+ Args:
158
+ queries: query list
159
+
160
+ Returns:
161
+ List of retrieved Documents
162
+ """
163
+ document_lists = await asyncio.gather(
164
+ *(
165
+ self.retriever.aget_relevant_documents(
166
+ query, callbacks=run_manager.get_child()
167
+ )
168
+ for query in queries
169
+ )
170
+ )
171
+ return [doc for docs in document_lists for doc in docs]
172
+
173
+ def _get_relevant_documents(
174
+ self,
175
+ query: str,
176
+ *,
177
+ run_manager: CallbackManagerForRetrieverRun,
178
+ ) -> List[Document]:
179
+ """Get relevant documents given a user query.
180
+
181
+ Args:
182
+ question: user query
183
+
184
+ Returns:
185
+ Unique union of relevant documents from all generated queries
186
+ """
187
+ queries = self.generate_queries(query, run_manager)
188
+ if self.include_original:
189
+ queries.append(query)
190
+ documents = self.retrieve_documents(queries, run_manager)
191
+ fused_documents= self.unique_union(documents)
192
+ # check for key exists
193
+ if fused_documents and fused_documents[0].metadata.get(self.date_key) is not None:
194
+ doc_dates = pd.to_datetime(
195
+ [doc.metadata[self.date_key] for doc in fused_documents]
196
+ )
197
+ sorted_node_idxs = np.flip(doc_dates.argsort())
198
+ fused_documents = [fused_documents[idx] for idx in sorted_node_idxs]
199
+ logger.info('Documents sorted by year')
200
+
201
+ return fused_documents[:self.top_k]
202
+
203
+
204
+
205
+ def generate_queries(
206
+ self, question: str, run_manager: CallbackManagerForRetrieverRun
207
+ ) -> List[str]:
208
+ """Generate queries based upon user input.
209
+
210
+ Args:
211
+ question: user query
212
+
213
+ Returns:
214
+ List of LLM generated queries that are similar to the user input
215
+ """
216
+ response = self.llm_chain.invoke(
217
+ {"question": question}, callbacks=run_manager.get_child()
218
+ )
219
+ lines = response["text"]
220
+ if self.verbose:
221
+ logger.info(f"Generated queries: {lines}")
222
+ return lines
223
+
224
+ def retrieve_documents(
225
+ self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
226
+ ) -> List[Document]:
227
+ """Run all LLM generated queries.
228
+
229
+ Args:
230
+ queries: query list
231
+
232
+ Returns:
233
+ List of retrieved Documents
234
+ """
235
+ documents = []
236
+ for query in queries:
237
+ logger.info(f"MQ Retriever question: {query}")
238
+ docs = self.retriever.get_relevant_documents(
239
+ query, callbacks=run_manager.get_child()
240
+ )
241
+ documents.extend(docs)
242
+ return documents
243
+
244
+ def unique_union(self, documents: List[Document]) -> List[Document]:
245
+ """Get unique Documents.
246
+
247
+ Args:
248
+ documents: List of retrieved Documents
249
+
250
+ Returns:
251
+ List of unique retrieved Documents
252
+ """
253
+ return _unique_documents(documents)
254
 
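Wiring the multi-query retriever mirrors load_multi_query_retriever in reggpt/utils/retriever.py; a hedged sketch:

# Assumes the reggpt package is importable and an OpenAI key is configured.
from reggpt.llms.llm import get_model
from reggpt.retriever.multi_query_retriever import MultiQueryRetriever
from reggpt.utils.retriever import load_ensemble_retriever

mq_retriever = MultiQueryRetriever.from_llm(retriever=load_ensemble_retriever(), llm=get_model("openai"))
docs = mq_retriever.get_relevant_documents("How are bank licences granted?")  # newest-first, capped at top_k=4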
reggpt/routers/__init__.py ADDED
File without changes
controller.py β†’ reggpt/routers/controller.py RENAMED
@@ -16,13 +16,13 @@
16
 
17
  import logging
18
  logger = logging.getLogger(__name__)
19
- from config import AVALIABLE_MODELS , MEMORY_WINDOW_K
+ from reggpt.configs.config import AVALIABLE_MODELS , MEMORY_WINDOW_K
20
 
21
  # from qaPipeline import QAPipeline
22
  # from qaPipeline_retriever_only import QAPipeline
23
  # qaPipeline = QAPipeline()
24
 
25
- from qaPipeline import run_agent
+ from reggpt.controller.agent import run_agent
26
 
27
  def get_QA_Answers(userQuery):
28
  # model=userQuery.model
reggpt/routers/general.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+ import os
3
+ import time
4
+ import logging
5
+ logger = logging.getLogger(__name__)
6
+ from dotenv import load_dotenv
7
+ from fastapi import HTTPException
8
+ from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
9
+ from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
10
+
11
+ from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
12
+ from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
13
+ load_dotenv()
14
+
15
+ verbose = os.environ.get('VERBOSE')
16
+
17
+ qa_model_type=QA_MODEL_TYPE
18
+ general_qa_model_type=GENERAL_QA_MODEL_TYPE
19
+ router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
20
+ multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
21
+ # model_type="tiiuae/falcon-7b-instruct"
22
+
23
+ # retriever=load_faiss_retriever()
24
+ retriever=load_ensemble_retriever()
25
+ # retriever=load_multi_query_retriever(multi_query_model_type)
26
+ logger.info("retriever loaded:")
27
+
28
+ qa_chain= get_qa_chain(qa_model_type,retriever)
29
+ general_qa_chain= get_general_qa_chain(general_qa_model_type)
30
+ router_chain= get_router_chain(router_model_type)
31
+
32
+ def run_general_qa_chain(query):
33
+ try:
34
+ logger.info(f"run_general_qa_chain : Question: {query}")
35
+
36
+ # Get the answer from the chain
37
+ start = time.time()
38
+ res = general_qa_chain.invoke(query)
39
+ end = time.time()
40
+
41
+ # log the result
42
+
43
+ logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
44
+
45
+ return general_qa_chain_output_parser(res)
46
+
47
+ except Exception as e:
48
+ logger.exception(e)
49
+ raise e
reggpt/routers/out_of_domain.py ADDED
@@ -0,0 +1,31 @@
1
+ import os
2
+ import time
3
+ import logging
4
+ logger = logging.getLogger(__name__)
5
+ from dotenv import load_dotenv
6
+ from fastapi import HTTPException
7
+ from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
8
+ from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
9
+
10
+ from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
11
+ from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
12
+ load_dotenv()
13
+
14
+ verbose = os.environ.get('VERBOSE')
15
+
16
+ qa_model_type=QA_MODEL_TYPE
17
+ general_qa_model_type=GENERAL_QA_MODEL_TYPE
18
+ router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
19
+ multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
20
+ # model_type="tiiuae/falcon-7b-instruct"
21
+
22
+ # retriever=load_faiss_retriever()
23
+ retriever=load_ensemble_retriever()
24
+ # retriever=load_multi_query_retriever(multi_query_model_type)
25
+ logger.info("retriever loaded:")
26
+
27
+ qa_chain= get_qa_chain(qa_model_type,retriever)
28
+ general_qa_chain= get_general_qa_chain(general_qa_model_type)
29
+ router_chain= get_router_chain(router_model_type)
30
+ def run_out_of_domain_chain(query):
31
+ return out_of_domain_chain_parser(query)
reggpt/routers/qa.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ /*************************************************************************
3
+ *
4
+ * CONFIDENTIAL
5
+ * __________________
6
+ *
7
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
+ * All Rights Reserved
9
+ *
10
+ * Author : Theekshana Samaradiwakara
11
+ * Description :Python Backend API to chat with private data
12
+ * CreatedDate : 14/11/2023
13
+ * LastModifiedDate : 18/03/2024
14
+ *************************************************************************/
15
+ """
16
+
17
+ import os
18
+ import time
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+ from dotenv import load_dotenv
22
+ from fastapi import HTTPException
23
+ from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
24
+ from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
25
+
26
+ from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
27
+ from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
28
+ load_dotenv()
29
+
30
+ verbose = os.environ.get('VERBOSE')
31
+
32
+ qa_model_type=QA_MODEL_TYPE
33
+ general_qa_model_type=GENERAL_QA_MODEL_TYPE
34
+ router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
35
+ multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
36
+ # model_type="tiiuae/falcon-7b-instruct"
37
+
38
+ # retriever=load_faiss_retriever()
39
+ retriever=load_ensemble_retriever()
40
+ # retriever=load_multi_query_retriever(multi_query_model_type)
41
+ logger.info("retriever loaded:")
42
+
43
+ qa_chain= get_qa_chain(qa_model_type,retriever)
44
+ general_qa_chain= get_general_qa_chain(general_qa_model_type)
45
+ router_chain= get_router_chain(router_model_type)
46
+
47
+ def run_qa_chain(query):
48
+ try:
49
+ logger.info(f"run_qa_chain : Question: {query}")
50
+ # Get the answer from the chain
51
+ start = time.time()
52
+ # res = qa_chain(query)
53
+ res = qa_chain.invoke({"question": query, "chat_history":""})
54
+ # res = response
55
+ # answer, docs = res['result'],res['source_documents']
56
+ end = time.time()
57
+
58
+ # log the result
59
+ logger.info(f"Answer (took {round(end - start, 2)} s.) \n: {res}")
60
+
61
+ return qa_chain_output_parser(res)
62
+
63
+ except Exception as e:
64
+ logger.exception(e)
65
+ raise e
66
+
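Taken together, the run_* helpers are dispatched on the router label; the real dispatch lives in the controller/agent modules, but a hypothetical sketch of the shape is:

# Illustrative only; names are assumed from the modules in this refactor.
from reggpt.routers.qa import run_qa_chain
from reggpt.routers.general import run_general_qa_chain
from reggpt.routers.out_of_domain import run_out_of_domain_chain

def answer(query: str, label: str):
    if label == "Relevant":
        return run_qa_chain(query)
    if label == "Greeting":
        return run_general_qa_chain(query)
    return run_out_of_domain_chain(query)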
reggpt/routers/qaPipeline.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ /*************************************************************************
3
+ *
4
+ * CONFIDENTIAL
5
+ * __________________
6
+ *
7
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
8
+ * All Rights Reserved
9
+ *
10
+ * Author : Theekshana Samaradiwakara
11
+ * Description :Python Backend API to chat with private data
12
+ * CreatedDate : 14/11/2023
13
+ * LastModifiedDate : 18/03/2024
14
+ *************************************************************************/
15
+ """
16
+
17
+ import os
18
+ import time
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+ from dotenv import load_dotenv
22
+ from fastapi import HTTPException
23
+ from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
24
+ from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
25
+
26
+ from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
27
+ from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
28
+ load_dotenv()
29
+
30
+ verbose = os.environ.get('VERBOSE')
31
+
32
+ qa_model_type=QA_MODEL_TYPE
33
+ general_qa_model_type=GENERAL_QA_MODEL_TYPE
34
+ router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
35
+ multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
36
+ # model_type="tiiuae/falcon-7b-instruct"
37
+
38
+ # retriever=load_faiss_retriever()
39
+ retriever=load_ensemble_retriever()
40
+ # retriever=load_multi_query_retriever(multi_query_model_type)
41
+ logger.info("retriever loaded:")
42
+
43
+ qa_chain= get_qa_chain(qa_model_type,retriever)
44
+ general_qa_chain= get_general_qa_chain(general_qa_model_type)
45
+ router_chain= get_router_chain(router_model_type)
reggpt/schemas/__init__.py ADDED
File without changes
schema.py β†’ reggpt/schemas/schema.py RENAMED
File without changes
reggpt/utils/__init__.py ADDED
File without changes
retriever.py β†’ reggpt/utils/retriever.py RENAMED
@@ -1,137 +1,137 @@
+
+ """
+ /*************************************************************************
+ *
+ * CONFIDENTIAL
+ * __________________
+ *
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
+ * All Rights Reserved
+ *
+ * Author : Theekshana Samaradiwakara
+ * Description : Python Backend API to chat with private data
+ * CreatedDate : 19/03/2023
+ * LastModifiedDate : 19/03/2024
+ *************************************************************************/
+ """
+
+ """
+ Ensemble retriever that combines the results of multiple retrievers
+ using weighted Reciprocal Rank Fusion.
+ """
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+ from reggpt.vectorstores.faissDb import load_FAISS_store
+
+ from langchain_community.retrievers import BM25Retriever
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.document_loaders import DirectoryLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+ from langchain.schema import Document
+ from typing import Iterable
+ import json
+
+ def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
+     with open(file_path, 'w') as jsonl_file:
+         for doc in array:
+             jsonl_file.write(doc.json() + '\n')
+
+ def load_docs_from_jsonl(file_path) -> Iterable[Document]:
+     array = []
+     with open(file_path, 'r') as jsonl_file:
+         for line in jsonl_file:
+             data = json.loads(line)
+             obj = Document(**data)
+             array.append(obj)
+     return array
+
+ def split_documents():
+     chunk_size = 2000
+     chunk_overlap = 100
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+     years = [2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]
+     docs_list = []
+     splits_list = []
+
+     for year in years:
+         data_path = f"data/CBSL/{year}"
+         logger.info(f"Loading year : {data_path}")
+
+         documents = DirectoryLoader(data_path, loader_cls=PyPDFLoader).load()
+
+         for doc in documents:
+             doc.metadata['year'] = year
+             logger.info(f"{doc.metadata['year']} : {doc.metadata['source']}")
+             docs_list.append(doc)
+
+         texts = text_splitter.split_documents(documents)
+         for text in texts:
+             splits_list.append(text)
+
+     # JSONL of pre-split documents, moved with the package to reggpt/data/.
+     splitted_texts_file = 'reggpt/data/splitted_texts.jsonl'
+     save_docs_to_jsonl(splits_list, splitted_texts_file)
+
+ from reggpt.retriever.ensemble_retriever import EnsembleRetriever
+ from reggpt.retriever.multi_query_retriever import MultiQueryRetriever
+
+ def load_faiss_retriever():
+     try:
+         vectorstore = load_FAISS_store()
+         retriever = vectorstore.as_retriever(
+             # search_type="mmr",
+             search_kwargs={'k': 5, 'fetch_k': 10}
+         )
+         logger.info("FAISS retriever loaded")
+         return retriever
+
+     except Exception as e:
+         logger.exception(e)
+         raise e
+
+ def load_ensemble_retriever():
+     try:
+         splitted_texts_file = './reggpt/data/splitted_texts.jsonl'
+         semantic_k = 4
+         bm25_k = 2
+         splits_list = load_docs_from_jsonl(splitted_texts_file)
+
+         bm25_retriever = BM25Retriever.from_documents(splits_list)
+         bm25_retriever.k = bm25_k
+
+         faiss_vectorstore = load_FAISS_store()
+         faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={'k': semantic_k})
+
+         ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
+         ensemble_retriever.top_k = 4
+
+         logger.info("EnsembleRetriever loaded")
+         return ensemble_retriever
+
+     except Exception as e:
+         logger.exception(e)
+         raise e
+
+ from reggpt.llms.llm import get_model
+
+ def load_multi_query_retriever(multi_query_model_type):
+     try:
+         llm = get_model(multi_query_model_type)
+         ensemble_retriever = load_ensemble_retriever()
+         retriever = MultiQueryRetriever.from_llm(
+             retriever=ensemble_retriever,
+             llm=llm
+         )
+         logger.info("MultiQueryRetriever loaded")
+         return retriever
+
+     except Exception as e:
+         logger.exception(e)
+         raise e
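load_ensemble_retriever above fuses BM25 (lexical) and FAISS (semantic) rankings with weighted Reciprocal Rank Fusion. As a standalone illustration of that fusion step (a sketch of the scoring rule, not the EnsembleRetriever source), each document accumulates sum_i w_i / (k + rank_i) across the input rankings:

# Weighted RRF sketch; k=60 is the smoothing constant from the original RRF paper.
def weighted_rrf(ranked_lists, weights, k=60):
    scores = {}
    for docs, weight in zip(ranked_lists, weights):
        for rank, doc_id in enumerate(docs, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# BM25 and FAISS rankings fused with the equal 0.5/0.5 weights configured above:
print(weighted_rrf([["d1", "d2", "d3"], ["d2", "d3", "d1"]], [0.5, 0.5]))  # ['d2', 'd1', 'd3']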
{utils β†’ reggpt/utils}/utils.py RENAMED
@@ -1,40 +1,40 @@
- """
- Python Backend API to chat with private data
- 15/11/2023
- Theekshana Samaradiwakara
- """
- """
- /*************************************************************************
- *
- * CONFIDENTIAL
- * __________________
- *
- * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
- * All Rights Reserved
- *
- * Author : Theekshana Samaradiwakara
- * Description :Python Backend API to chat with private data
- * CreatedDate : 15/11/2023
- * LastModifiedDate : 10/12/2020
- *************************************************************************/
- """
-
- # from passlib.context import CryptContext
- # pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-
-
- # def hash(password: str):
- #     return pwd_context.hash(password)
-
-
- # def verify(plain_password, hashed_password):
- #     return pwd_context.verify(plain_password, hashed_password)
-
-
-
- import re
- def is_valid_open_ai_api_key(secretKey):
-     if re.search("^sk-[a-zA-Z0-9]{32,}$", secretKey ):
-         return True
-     else: return False
-
 
+ """
+ Python Backend API to chat with private data
+ 15/11/2023
+ Theekshana Samaradiwakara
+ """
+ """
+ /*************************************************************************
+ *
+ * CONFIDENTIAL
+ * __________________
+ *
+ * Copyright (2023-2024) AI Labs, IronOne Technologies, LLC
+ * All Rights Reserved
+ *
+ * Author : Theekshana Samaradiwakara
+ * Description : Python Backend API to chat with private data
+ * CreatedDate : 15/11/2023
+ * LastModifiedDate : 10/12/2020
+ *************************************************************************/
+ """
+
+ # from passlib.context import CryptContext
+ # pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+ # def hash(password: str):
+ #     return pwd_context.hash(password)
+
+ # def verify(plain_password, hashed_password):
+ #     return pwd_context.verify(plain_password, hashed_password)
+
+ import re
+
+ def is_valid_open_ai_api_key(secretKey):
+     # Shape check only: "sk-" prefix followed by at least 32 alphanumerics.
+     return bool(re.search(r"^sk-[a-zA-Z0-9]{32,}$", secretKey))
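The regex in is_valid_open_ai_api_key is a shape check only: a match means the string looks like a classic secret key (sk- prefix plus 32 or more alphanumerics), not that the key is active, and newer OpenAI key formats that contain extra hyphens would plausibly fail it. For example:

print(is_valid_open_ai_api_key("sk-" + "a" * 40))  # True: right prefix, 40 alphanumerics
print(is_valid_open_ai_api_key("sk-short"))        # False: fewer than 32 characters after the prefix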
reggpt/vectorstores/__init__.py ADDED
File without changes