Spaces:
Sleeping
Sleeping
# src/relationships.py
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from itertools import combinations
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class Entity:
    """Entity data structure.

    Field names parallel the columns of the ``entities`` table
    (``text``/``type`` correspond to ``entity_text``/``entity_type``).
    Declared as a dataclass so instances can be constructed from field
    values and get ``__repr__``/``__eq__`` for free.
    """

    id: Optional[int]  # database row id; the annotation allows None
    text: str
    type: str
    first_seen: str  # timestamp string
    last_seen: str  # timestamp string
    frequency: int
    confidence: float
@dataclass
class Relationship:
    """Relationship data structure.

    Field names parallel the columns of the ``entity_relationships`` table
    (``source_id``/``target_id``/``type`` correspond to
    ``source_entity_id``/``target_entity_id``/``relationship_type``).
    Declared as a dataclass so instances can be constructed from field
    values and get ``__repr__``/``__eq__`` for free.
    """

    id: Optional[int]  # database row id; the annotation allows None
    source_id: int
    target_id: int
    type: str
    confidence: float
    first_seen: str  # timestamp string
    last_seen: str  # timestamp string
class RelationshipEngine:
    """Engine for managing entity and event relationships.

    Wraps one SQLite connection holding four tables: ``events``,
    ``entities``, ``event_entities`` (event-to-entity links) and
    ``entity_relationships`` (entity-to-entity links).

    The connection is opened with ``check_same_thread=False`` and this class
    does no locking of its own, so callers that share an instance between
    threads must serialize access themselves.

    Instances can be used as context managers; the connection is closed on
    exit.
    """

    def __init__(self, db_path: str = ':memory:') -> None:
        """Initialize the relationship engine with database connection.

        Args:
            db_path: Path passed to ``sqlite3.connect``; defaults to an
                in-memory database.
        """
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.setup_database()

    def __enter__(self) -> 'RelationshipEngine':
        """Return the engine itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def setup_database(self) -> None:
        """Initialize database schema.

        Every statement uses ``IF NOT EXISTS``, so calling this against an
        already-initialized database is a no-op.
        """
        self.conn.executescript('''
            CREATE TABLE IF NOT EXISTS events (
                id INTEGER PRIMARY KEY,
                text TEXT,
                timestamp DATETIME,
                confidence REAL
            );
            CREATE TABLE IF NOT EXISTS entities (
                id INTEGER PRIMARY KEY,
                entity_text TEXT,
                entity_type TEXT,
                first_seen DATETIME,
                last_seen DATETIME,
                frequency INTEGER DEFAULT 1,
                confidence REAL
            );
            CREATE TABLE IF NOT EXISTS event_entities (
                event_id INTEGER,
                entity_id INTEGER,
                FOREIGN KEY (event_id) REFERENCES events(id),
                FOREIGN KEY (entity_id) REFERENCES entities(id),
                PRIMARY KEY (event_id, entity_id)
            );
            CREATE TABLE IF NOT EXISTS entity_relationships (
                id INTEGER PRIMARY KEY,
                source_entity_id INTEGER,
                target_entity_id INTEGER,
                relationship_type TEXT,
                confidence REAL,
                first_seen DATETIME,
                last_seen DATETIME,
                FOREIGN KEY (source_entity_id) REFERENCES entities(id),
                FOREIGN KEY (target_entity_id) REFERENCES entities(id)
            );
            CREATE INDEX IF NOT EXISTS idx_entity_text
            ON entities(entity_text, entity_type);
            CREATE INDEX IF NOT EXISTS idx_event_entities
            ON event_entities(event_id, entity_id);
            CREATE INDEX IF NOT EXISTS idx_entity_relationships
            ON entity_relationships(source_entity_id, target_entity_id);
        ''')
        self.conn.commit()

    def store_entities(self, event_id: int, entities_dict: Dict[str, List[str]]) -> None:
        """Store or update entities and their relationships to events.

        For each ``(entity_type, entity_text)`` pair: bump ``frequency`` and
        ``last_seen`` if the entity already exists, otherwise insert it with
        confidence 1.0. Either way, link it to ``event_id``.

        Args:
            event_id: Id of the event the entities belong to. The event row
                itself is not created or checked here.
            entities_dict: Mapping of entity type to a list of entity texts.
                Values that are not lists are skipped.

        Raises:
            sqlite3.Error: If a statement fails; in that case none of this
                call's changes are kept.
        """
        now = datetime.now().isoformat()
        # The connection context manager commits on success and rolls back on
        # error, so a failure part-way through never leaves a half-stored event.
        with self.conn:
            for entity_type, entities in entities_dict.items():
                if not isinstance(entities, list):
                    continue
                for entity_text in entities:
                    existing = self.conn.execute(
                        'SELECT id FROM entities WHERE entity_text = ? AND entity_type = ?',
                        (entity_text, entity_type),
                    ).fetchone()
                    if existing:
                        entity_id = existing[0]
                        # Increment inside SQL rather than read-modify-write.
                        self.conn.execute('''
                            UPDATE entities
                            SET frequency = COALESCE(frequency, 0) + 1, last_seen = ?
                            WHERE id = ?
                        ''', (now, entity_id))
                    else:
                        cursor = self.conn.execute('''
                            INSERT INTO entities
                            (entity_text, entity_type, first_seen, last_seen, confidence)
                            VALUES (?, ?, ?, ?, ?)
                        ''', (entity_text, entity_type, now, now, 1.0))
                        entity_id = cursor.lastrowid
                    # OR IGNORE: the link may already exist from an earlier call.
                    self.conn.execute('''
                        INSERT OR IGNORE INTO event_entities (event_id, entity_id)
                        VALUES (?, ?)
                    ''', (event_id, entity_id))

    def find_related_events(self, event_data: Dict[str, Any], limit: int = 5) -> List[Tuple]:
        """Find events related through shared entities.

        Args:
            event_data: Dict whose optional ``'entities'`` key maps entity
                type to a list of entity texts. A missing or ``None`` value,
                and any non-list values inside it, are ignored.
            limit: Maximum number of events to return (default 5).

        Returns:
            Rows of ``events`` columns followed by a ``shared_entities``
            count, ordered by that count and then by timestamp, both
            descending. Empty list if no entity texts were supplied.
        """
        entity_texts: List[str] = []
        for entities in (event_data.get('entities') or {}).values():
            if isinstance(entities, list):
                entity_texts.extend(entities)
        if not entity_texts:
            return []
        # Only "?" markers are interpolated into the SQL; the values themselves
        # are always bound as parameters.
        placeholders = ','.join('?' * len(entity_texts))
        query = f'''
            SELECT DISTINCT e.*, COUNT(ee.entity_id) as shared_entities
            FROM events e
            JOIN event_entities ee ON e.id = ee.event_id
            JOIN entities ent ON ee.entity_id = ent.id
            WHERE ent.entity_text IN ({placeholders})
            GROUP BY e.id
            ORDER BY shared_entities DESC, e.timestamp DESC
            LIMIT ?
        '''
        return self.conn.execute(query, [*entity_texts, limit]).fetchall()

    def update_entity_relationships(self, event_id: int) -> None:
        """Update relationships between entities in an event.

        Every pair of entities linked to ``event_id`` that have *different*
        types gets its relationship created or reinforced; same-type pairs
        are skipped.

        Args:
            event_id: Id of the event whose entities should be related.
        """
        # ORDER BY keeps the pairing (and hence the source/target direction
        # and the "<type>_to_<type>" label of new rows) deterministic.
        entities = self.conn.execute('''
            SELECT e.id, e.entity_text, e.entity_type
            FROM entities e
            JOIN event_entities ee ON e.id = ee.entity_id
            WHERE ee.event_id = ?
            ORDER BY e.id
        ''', (event_id,)).fetchall()
        now = datetime.now().isoformat()
        # Commit all pair updates together, or roll them all back on error.
        with self.conn:
            for (first_id, _, first_type), (second_id, _, second_type) in combinations(entities, 2):
                if first_type == second_type:
                    continue
                relationship_type = f"{first_type}_to_{second_type}"
                self._update_relationship(first_id, second_id, relationship_type, now)

    def _update_relationship(self, source_id: int, target_id: int, rel_type: str, timestamp: str) -> None:
        """Update or create a relationship between entities.

        An existing row is matched in either direction and regardless of its
        relationship type; it gets ``last_seen`` refreshed and its confidence
        raised by 0.1, capped at 1.0. Otherwise a new row is inserted with
        confidence 0.5. Does not commit; the caller owns the transaction.

        Args:
            source_id: Id of the first entity.
            target_id: Id of the second entity.
            rel_type: Relationship label stored on newly created rows.
            timestamp: Value written to ``last_seen`` (and ``first_seen`` on
                insert).
        """
        existing = self.conn.execute('''
            SELECT id FROM entity_relationships
            WHERE (source_entity_id = ? AND target_entity_id = ?)
            OR (source_entity_id = ? AND target_entity_id = ?)
        ''', (source_id, target_id, target_id, source_id)).fetchone()
        if existing:
            # Two-argument MIN() is SQLite's scalar minimum; it stops repeated
            # sightings from pushing confidence above 1.0.
            self.conn.execute('''
                UPDATE entity_relationships
                SET last_seen = ?, confidence = MIN(1.0, confidence + 0.1)
                WHERE id = ?
            ''', (timestamp, existing[0]))
        else:
            self.conn.execute('''
                INSERT INTO entity_relationships
                (source_entity_id, target_entity_id, relationship_type, confidence, first_seen, last_seen)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (source_id, target_id, rel_type, 0.5, timestamp, timestamp))

    def get_entity_relationships(self, event_id: int) -> List[Dict[str, Any]]:
        """Get all relationships for entities in an event.

        Args:
            event_id: Id of the event whose entities are looked up.

        Returns:
            One dict per distinct relationship that has an entity of the
            event as source or target. Keys are the ``entity_relationships``
            column names plus ``source_text``, ``source_type``,
            ``target_text`` and ``target_type``.
        """
        query = '''
            SELECT DISTINCT er.*,
            e1.entity_text as source_text, e1.entity_type as source_type,
            e2.entity_text as target_text, e2.entity_type as target_type
            FROM event_entities ee
            JOIN entity_relationships er ON ee.entity_id IN (er.source_entity_id, er.target_entity_id)
            JOIN entities e1 ON er.source_entity_id = e1.id
            JOIN entities e2 ON er.target_entity_id = e2.id
            WHERE ee.event_id = ?
        '''
        cursor = self.conn.execute(query, (event_id,))
        # Rows are plain tuples (no row_factory is configured), so pair each
        # value with its column name from the cursor metadata.
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]