Spaces:
Runtime error
Runtime error
File size: 2,713 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
"""SQL wrapper around SQLDatabase in langchain."""
from typing import Any, Dict, List, Tuple
from langchain.sql_database import SQLDatabase as LangchainSQLDatabase
from sqlalchemy import MetaData, create_engine, insert
from sqlalchemy.engine import Engine
class SQLDatabase(LangchainSQLDatabase):
    """SQL Database.

    Wrapper around SQLDatabase object from langchain. Offers
    some helper utilities for insertion and querying.
    See `langchain documentation <https://tinyurl.com/4we5ku8j>`_ for more details.

    Args:
        *args: Arguments to pass to langchain SQLDatabase.
        **kwargs: Keyword arguments to pass to langchain SQLDatabase.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise the langchain base class and reflect the schema.

        After the base initialiser runs, the tables reachable through the
        engine are reflected into ``self.metadata_obj`` so they can be looked
        up by name (see ``insert_into_table``).
        """
        super().__init__(*args, **kwargs)
        # NOTE: ``self._engine`` is presumably set by the langchain base
        # ``__init__`` -- it is not assigned anywhere in this class.
        # ``MetaData(bind=...)`` ("bound metadata") was removed in
        # SQLAlchemy 2.0; passing the engine to ``reflect()`` instead works
        # on both 1.x and 2.0.
        self.metadata_obj = MetaData()
        self.metadata_obj.reflect(bind=self._engine)

    @property
    def engine(self) -> Engine:
        """Return SQL Alchemy engine."""
        return self._engine

    @classmethod
    def from_uri(cls, database_uri: str, **kwargs: Any) -> "SQLDatabase":
        """Construct an instance from a database URI.

        Args:
            database_uri: URI handed to ``sqlalchemy.create_engine``.
            **kwargs: Extra keyword arguments forwarded to the class
                constructor (not to ``create_engine``).

        Returns:
            A new instance wrapping the freshly created engine.
        """
        return cls(create_engine(database_uri), **kwargs)

    def get_table_columns(self, table_name: str) -> List[dict]:
        """Return the inspector's column descriptions for ``table_name``."""
        return self._inspector.get_columns(table_name)

    def get_single_table_info(self, table_name: str) -> str:
        """Describe one table as a sentence listing its columns.

        Args:
            table_name: Name of the table to describe.

        Returns:
            A string of the form
            ``"Table '<name>' has columns: <col> (<type>), ..."`` built from
            the inspector's column list for that table.
        """
        # same logic as table_info, but with specific table names
        template = "Table '{table_name}' has columns: {columns}."
        column_str = ", ".join(
            f"{column['name']} ({column['type']!s})"
            for column in self._inspector.get_columns(table_name)
        )
        return template.format(table_name=table_name, columns=column_str)

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert a single row into a reflected table.

        Args:
            table_name: Name of a table present in ``self.metadata_obj``.
            data: Mapping of column name to value for the new row.

        Raises:
            KeyError: If ``table_name`` is not among the reflected tables.
        """
        table = self.metadata_obj.tables[table_name]
        stmt = insert(table).values(**data)
        # ``Engine.execute`` was removed in SQLAlchemy 2.0. ``engine.begin()``
        # opens a connection and a transaction, commits on success and rolls
        # back on error, and behaves the same on 1.x and 2.0.
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a raw SQL statement and return its results.

        Args:
            command: SQL text passed straight to the DBAPI driver.

        Returns:
            A ``(text, extra)`` tuple. If the statement returns rows,
            ``text`` is ``str()`` of the fetched row list and ``extra`` is
            ``{"result": rows}``. Otherwise ``text`` is an empty string and
            ``extra`` is an empty dict.
        """
        # Run inside a transaction so that data-modifying statements are
        # committed: SQLAlchemy 2.0 has no autocommit, and a plain
        # ``connect()`` block would roll the change back on exit.
        with self._engine.begin() as connection:
            cursor = connection.exec_driver_sql(command)
            if cursor.returns_rows:
                result = cursor.fetchall()
                return str(result), {"result": result}
        return "", {}
|