# Spaces:
# Sleeping
# Sleeping
# src/analyzer.py
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from transformers import pipeline

from .ontology import OntologyRegistry
from .relationships import RelationshipEngine
class EventAnalyzer:
    """Analyze free-text event reports.

    Pipeline: NER + ontology pattern extraction -> confidence scoring against
    the relationship store -> persistence of high-confidence events.
    """

    # Events scoring below this threshold are flagged for manual verification
    # instead of being stored (was a magic 0.6 repeated inside analyze_event).
    CONFIDENCE_THRESHOLD = 0.6

    def __init__(self) -> None:
        """Initialize ontology, relationship engine, worker pool and NLP pipelines."""
        self.ontology = OntologyRegistry()
        self.relationship_engine = RelationshipEngine()
        # Small pool: transformers inference is heavyweight; three workers
        # keep it off the event loop without oversubscribing the machine.
        self.executor = ThreadPoolExecutor(max_workers=3)
        # Models are loaded once here; individual calls are dispatched to the
        # executor from the async extraction methods.
        self.ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")
        self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

    async def extract_entities(self, text: str) -> Dict[str, List[str]]:
        """Extract named entities and hashtags from *text*.

        Returns a dict with keys "people", "organizations", "locations"
        (from the NER model) and "hashtags" (simple '#'-prefix token match).
        """
        # get_running_loop() is the supported accessor inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        ner_results = await asyncio.get_running_loop().run_in_executor(
            self.executor, self.ner_pipeline, text
        )
        entities: Dict[str, List[str]] = {
            "people": [],
            "organizations": [],
            "locations": [],
            "hashtags": [word for word in text.split() if word.startswith('#')],
        }
        # Model tags look like B-PER / I-ORG / ...; bucket by tag suffix.
        bucket_for_suffix = (
            ("PER", "people"),
            ("ORG", "organizations"),
            ("LOC", "locations"),
        )
        for item in ner_results:
            for suffix, bucket in bucket_for_suffix:
                if item["entity"].endswith(suffix):
                    entities[bucket].append(item["word"])
                    break
        return entities

    def extract_temporal(self, text: str) -> List[str]:
        """Extract temporal expressions via the ontology's pattern matcher."""
        return self.ontology.validate_pattern(text, 'temporal')

    async def extract_locations(self, text: str) -> List[str]:
        """Return the union of NER locations and ontology pattern matches.

        Note: this runs the NER pipeline; analyze_event() merges pattern
        locations itself to avoid a second model pass over the same text.
        """
        entities = await self.extract_entities(text)
        combined = entities.get('locations', []) + self.ontology.validate_pattern(text, 'location')
        # Deduplicate while keeping first-seen order (deterministic output,
        # unlike list(set(...))).
        return list(dict.fromkeys(combined))

    def calculate_confidence(self,
                             entities: Dict[str, List[str]],
                             temporal_data: List[str],
                             related_events: List[Any]) -> float:
        """Score extraction quality in [0, 1].

        Base score from which categories are present, plus a boost for
        entities frequently seen before and for overlap with related events.
        """
        # .get() keeps this robust if a category key is absent.
        base_confidence = min(1.0, (
            0.2 * bool(entities.get("people")) +
            0.2 * bool(entities.get("organizations")) +
            0.3 * bool(entities.get("locations")) +
            0.3 * bool(temporal_data)
        ))

        entity_params = [
            *entities.get("people", []),
            *entities.get("organizations", []),
            *entities.get("locations", []),
        ]
        if not entity_params:
            return base_confidence

        # Frequency boost: entities the store has seen often lend credibility.
        # Parameterized IN-list, never string-interpolated values.
        placeholders = ','.join(['?'] * len(entity_params))
        cursor = self.relationship_engine.conn.execute(
            f'''
            SELECT AVG(frequency) as avg_freq
            FROM entities
            WHERE entity_text IN ({placeholders})
            ''',
            entity_params,
        )
        # AVG over zero matching rows is NULL -> treat as neutral frequency 1.
        avg_frequency = cursor.fetchone()[0] or 1
        frequency_boost = min(0.2, (avg_frequency - 1) * 0.05)

        # Relationship boost: strongest overlap score across related events.
        relationship_confidence = 0.0
        for event in related_events:
            # NOTE(review): both placeholders receive event[0], so this counts
            # an event's overlap with itself (i.e. its own entity count). The
            # current event has no id yet at this point, so the intended
            # comparison is unclear — confirm against the schema before
            # changing, since a fix alters stored confidence scores.
            cursor = self.relationship_engine.conn.execute('''
                SELECT COUNT(*) as shared_entities
                FROM event_entities ee1
                JOIN event_entities ee2 ON ee1.entity_id = ee2.entity_id
                WHERE ee1.event_id = ? AND ee2.event_id = ?
            ''', (event[0], event[0]))
            shared_count = cursor.fetchone()[0]
            relationship_confidence = max(relationship_confidence,
                                          min(0.3, shared_count * 0.1))

        return min(1.0, base_confidence + frequency_boost + relationship_confidence)

    async def analyze_event(self, text: str) -> Dict[str, Any]:
        """Analyze *text* and return a structured result dict.

        Result keys: text, entities, confidence, verification_needed,
        related_events, entity_relationships — or {"error": ...} on failure.
        Events at/above CONFIDENCE_THRESHOLD are persisted as a side effect.
        """
        try:
            # One NER pass only: the previous version also awaited
            # extract_locations(), which reran the model on the same text.
            entities = await self.extract_entities(text)
            temporal_data = self.extract_temporal(text)
            pattern_locations = self.ontology.validate_pattern(text, 'location')
            entities['locations'] = list(
                dict.fromkeys(entities['locations'] + pattern_locations)
            )
            entities['temporal'] = temporal_data

            related_events = self.relationship_engine.find_related_events({
                'text': text,
                'entities': entities,
            })
            confidence = self.calculate_confidence(entities, temporal_data, related_events)

            # Persist only sufficiently confident events.
            event_id = None
            if confidence >= self.CONFIDENCE_THRESHOLD:
                event_id = self._store_event(text, entities, temporal_data, confidence)

            entity_relationships = (
                self.relationship_engine.get_entity_relationships(event_id)
                if event_id is not None else []
            )
            return {
                "text": text,
                "entities": entities,
                "confidence": confidence,
                "verification_needed": confidence < self.CONFIDENCE_THRESHOLD,
                "related_events": [
                    {
                        "text": event[1],
                        "timestamp": event[2],
                        "confidence": event[3],
                        # Some result rows may lack a shared-entity count.
                        "shared_entities": event[4] if len(event) > 4 else None,
                    }
                    for event in related_events
                ],
                "entity_relationships": entity_relationships,
            }
        except Exception as e:
            # Best-effort API: failures are reported to the caller rather than
            # raised; keep the traceback visible to operators instead of
            # silently discarding it.
            logging.getLogger(__name__).exception("event analysis failed")
            return {"error": str(e)}

    def _store_event(self, text: str, entities: Dict[str, List[str]],
                     temporal_data: List[str], confidence: float) -> int:
        """Insert the event plus its entities/relationships; return the new id."""
        conn = self.relationship_engine.conn
        cursor = conn.execute(
            'INSERT INTO events (text, timestamp, confidence) VALUES (?, ?, ?)',
            # NOTE(review): datetime.now() is naive local time — presumably
            # UTC was intended; confirm against the readers of this column.
            (text, datetime.now().isoformat(), confidence)
        )
        event_id = cursor.lastrowid
        self.relationship_engine.store_entities(event_id, {
            'person': entities['people'],
            'organization': entities['organizations'],
            'location': entities['locations'],
            'temporal': temporal_data,
            'hashtag': entities['hashtags'],
        })
        self.relationship_engine.update_entity_relationships(event_id)
        conn.commit()
        return event_id

    def get_entity_statistics(self) -> Dict[str, List[tuple]]:
        """Return aggregate statistics over stored entities and relationships."""
        queries = {
            # Entity counts and mean frequency per type.
            'entity_counts': '''
                SELECT entity_type, COUNT(*) as count, AVG(frequency) as avg_frequency
                FROM entities
                GROUP BY entity_type
            ''',
            # Ten most frequently seen entities.
            'frequent_entities': '''
                SELECT entity_text, entity_type, frequency
                FROM entities
                ORDER BY frequency DESC
                LIMIT 10
            ''',
            # Relationship counts and mean confidence per type.
            'relationship_stats': '''
                SELECT relationship_type, COUNT(*) as count, AVG(confidence) as avg_confidence
                FROM entity_relationships
                GROUP BY relationship_type
            ''',
        }
        conn = self.relationship_engine.conn
        return {key: conn.execute(sql).fetchall() for key, sql in queries.items()}