Spaces:
Running
Running
File size: 7,338 Bytes
cd36062 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
"""
This module is responsible for the VectorDB API. It currently supports:
* DELETE api/v1/clear
- Clears the whole DB.
* POST api/v1/add
- Add some corpus to the DB. You can also specify metadata to be added alongside it.
* POST api/v1/delete
- Delete specific records with given metadata.
* POST api/v1/get
- Get results from chromaDB.
"""
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
from threading import Thread
from modules import shared
from modules.logging_colors import logger
from .chromadb import ChromaCollector
from .data_processor import process_and_add_to_collector
import extensions.superboogav2.parameters as parameters
class CustomThreadingHTTPServer(ThreadingHTTPServer):
    """Threading HTTP server that hands one shared collector to every request handler."""

    def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
        """Remember ``collector``, then initialise the underlying HTTP server.

        Args:
            server_address: ``(host, port)`` pair forwarded to the base server.
            RequestHandlerClass: Handler class; it is called with the collector
                as an extra fourth argument (see ``finish_request``).
            collector: Object passed to every handler instance.
            bind_and_activate: Forwarded unchanged to the base server.
        """
        # The collector is assigned before the base initialiser runs.
        self.collector = collector
        super().__init__(server_address, RequestHandlerClass, bind_and_activate=bind_and_activate)

    def finish_request(self, request, client_address):
        """Build the handler for one request, appending the collector to the usual arguments."""
        make_handler = self.RequestHandlerClass
        make_handler(request, client_address, self, self.collector)
class Handler(BaseHTTPRequestHandler):
    """Request handler exposing the collector through a small JSON HTTP API.

    Served routes (each is also reachable without the ``v1`` segment):

    * ``POST /api/v1/add``    - add a corpus, optionally with metadata.
    * ``POST /api/v1/delete`` - delete the records matching some metadata.
    * ``POST /api/v1/get``    - query the collector (``?sort=`` picks the ordering).
    * ``DELETE /api/v1/clear`` - clear the collector.

    Every response carries permissive CORS headers and disables caching
    (see ``end_headers``).
    """

    def __init__(self, request, client_address, server, collector: "ChromaCollector"):
        # The base initialiser processes the request immediately, so the
        # collector has to be in place before it is called.
        self.collector = collector
        super().__init__(request, client_address, server)

    def _send_json(self, status: int, payload):
        """Write a complete JSON response with the given HTTP status code."""
        # Serialise first: if the payload cannot be encoded, nothing has been
        # written yet and the caller can still send a clean error response.
        encoded = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(encoded)

    def _send_412_error(self, message):
        """Reply 412 with ``{"error": message}``; used for missing request parameters."""
        self._send_json(412, {"error": message})

    def _send_404_error(self):
        """Reply 404 with a fixed "Resource not found" error."""
        self._send_json(404, {"error": "Resource not found"})

    def _send_400_error(self, error_message: str):
        """Reply 400 with ``{"error": error_message}``."""
        self._send_json(400, {"error": error_message})

    def _send_200_response(self, message):
        """Reply 200.

        A ``str`` is wrapped as ``{"message": message}``; any other value is
        serialised to JSON as-is.
        """
        payload = {"message": message} if isinstance(message, str) else message
        self._send_json(200, payload)

    def _read_json_body(self) -> dict:
        """Read the request body and decode it as a JSON object.

        Raises:
            ValueError: If ``Content-Length`` is missing or not an integer, if
                the body is not valid JSON, or if it is not a JSON object.
        """
        raw_length = self.headers.get('Content-Length')
        if raw_length is None:
            raise ValueError("Missing 'Content-Length' header")

        body = json.loads(self.rfile.read(int(raw_length)).decode('utf-8'))
        if not isinstance(body, dict):
            raise ValueError("Request body must be a JSON object")
        return body

    def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
        """Query the collector and return ``{"results": ...}``.

        ``sort_param`` equal to ``parameters.SORT_ID`` sorts by id; every other
        value (including unknown ones) sorts by distance.
        """
        if sort_param == parameters.SORT_ID:
            fetch = self.collector.get_sorted_by_id
        else:
            fetch = self.collector.get_sorted_by_dist

        return {
            "results": fetch(search_strings, n_results, max_token_count)
        }

    def do_GET(self):
        """No GET routes exist; always reply 404."""
        self._send_404_error()

    def do_POST(self):
        """Dispatch the add / delete / get routes.

        Replies 412 when a required parameter is missing, 404 for an unknown
        path, and 400 (with the exception text) for any other failure.
        """
        try:
            body = self._read_json_body()

            parsed_path = urlparse(self.path)
            path = parsed_path.path
            query_params = parse_qs(parsed_path.query)

            if path in ['/api/v1/add', '/api/add']:
                corpus = body.get('corpus')
                if corpus is None:
                    self._send_412_error("Missing parameter 'corpus'")
                    return

                clear_before_adding = body.get('clear_before_adding', False)
                metadata = body.get('metadata')
                process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata)
                self._send_200_response("Data successfully added")

            elif path in ['/api/v1/delete', '/api/delete']:
                metadata = body.get('metadata')
                # This check must look at `metadata`; `corpus` is never
                # assigned on this route.
                if metadata is None:
                    self._send_412_error("Missing parameter 'metadata'")
                    return

                self.collector.delete(ids_to_delete=None, where=metadata)
                self._send_200_response("Data successfully deleted")

            elif path in ['/api/v1/get', '/api/get']:
                search_strings = body.get('search_strings')
                if search_strings is None:
                    self._send_412_error("Missing parameter 'search_strings'")
                    return

                # Optional parameters fall back to the configured values.
                n_results = body.get('n_results')
                if n_results is None:
                    n_results = parameters.get_chunk_count()

                max_token_count = body.get('max_token_count')
                if max_token_count is None:
                    max_token_count = parameters.get_max_token_count()

                sort_param = query_params.get('sort', ['distance'])[0]

                results = self._handle_get(search_strings, n_results, max_token_count, sort_param)
                self._send_200_response(results)

            else:
                self._send_404_error()
        except Exception as e:
            # Top-level boundary of the request: report the failure to the client.
            self._send_400_error(str(e))

    def do_DELETE(self):
        """Clear the collector on the clear route; reply 404 for any other path."""
        try:
            path = urlparse(self.path).path

            if path in ['/api/v1/clear', '/api/clear']:
                self.collector.clear()
                self._send_200_response("Data successfully cleared")
            else:
                self._send_404_error()
        except Exception as e:
            self._send_400_error(str(e))

    def do_OPTIONS(self):
        """Answer OPTIONS requests with an empty 200; the CORS headers come from ``end_headers``."""
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        """Append the CORS and cache-control headers to every response before finishing the header block."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()
class APIManager:
    """Starts and stops the chromaDB HTTP API server on a background thread."""

    def __init__(self, collector: "ChromaCollector"):
        """Create a manager with no server running.

        Args:
            collector: Object handed to the server (and from there to every
                request handler) when the server is started.
        """
        self.server = None
        self.collector = collector
        self.is_running = False

    def start_server(self, port: int):
        """Start serving on ``port`` in a daemon thread.

        Binds to ``0.0.0.0`` when ``shared.args.listen`` is truthy and to
        ``127.0.0.1`` otherwise. Does nothing, apart from logging a warning,
        when a server has already been started.
        """
        if self.server is not None:
            logger.warning("Server already running.")
            return

        address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
        self.server = CustomThreadingHTTPServer((address, port), Handler, self.collector)

        logger.info(f'Starting chromaDB API at http://{address}:{port}/api')

        # Daemon thread, so the serving loop does not keep the process alive.
        Thread(target=self.server.serve_forever, daemon=True).start()

        self.is_running = True

    def stop_server(self):
        """Shut the server down and release its socket; no-op when none is running."""
        if self.server is None:
            return

        logger.info('Stopping chromaDB API.')
        self.server.shutdown()
        self.server.server_close()
        self.server = None
        self.is_running = False

    def is_server_running(self):
        """Return ``True`` between a successful ``start_server`` and the next ``stop_server``."""
        return self.is_running