from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import json

import requests
from bs4 import BeautifulSoup
from openai import OpenAI

""" |
|
EXAMPLE OUTPUT: |
|
|
|
What is the current population for the city where Einstein was born? |
|
|
|
Step 1 |
|
---------------------------------------- |
|
|
|
Executing: fetch_wiki_content |
|
Arguments: {'title': 'Albert Einstein'} |
|
|
|
Step 2 |
|
---------------------------------------- |
|
|
|
Executing: deliver_answer |
|
Arguments: {'fields': ['Ulm, German Empire']} |
|
ANSWER FROM THE ASSISTANT: ['Ulm, German Empire'] |
|
|
|
Step 3 |
|
---------------------------------------- |
|
|
|
Executing: fetch_wiki_content |
|
Arguments: {'title': 'Ulm'} |
|
|
|
Step 4 |
|
---------------------------------------- |
|
|
|
Executing: deliver_answer |
|
Arguments: {'fields': ['128,928']} |
|
ANSWER FROM THE ASSISTANT: ['128,928'] |
|
|
|
Step 5 |
|
---------------------------------------- |
|
Extraction Complete |
|
|
|
|
|
Why was Einstein famous? |
|
|
|
Step 1 |
|
---------------------------------------- |
|
|
|
Executing: fetch_wiki_content |
|
Arguments: {'title': 'Albert Einstein'} |
|
|
|
Step 2 |
|
---------------------------------------- |
|
|
|
Executing: deliver_answer |
|
Arguments: {'fields': ['Best known for developing the theory of relativity, Einstein also made important contributions to quantum mechanics.', 'His mass–energy equivalence formula E = mc2, which arises from special relativity, has been called "the world\'s most famous equation."', 'He received the 1921 Nobel Prize in Physics.']} |
|
ANSWER FROM THE ASSISTANT: ['Best known for developing the theory of relativity, Einstein also made important contributions to quantum mechanics.', 'His mass–energy equivalence formula E = mc2, which arises from special relativity, has been called "the world\'s most famous equation."', 'He received the 1921 Nobel Prize in Physics.'] |
|
|
|
Step 3 |
|
---------------------------------------- |
|
Extraction Complete |
|
""" |
|
|
|
@dataclass
class WikiConfig:
    """Configuration for OpenAI and Wikipedia settings"""

    api_key: str = "sk-123"
    # "{info}" is a templating placeholder left in the original; replace it with the
    # base URL of an OpenAI-compatible server before running.
    api_base: str = "{info}/v1"
    model: Optional[str] = None
    max_steps: int = 5
    wikipedia_base_url: str = "https://en.wikipedia.org/wiki/"
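
# A minimal configuration sketch (assumed values, not from the original): any
# OpenAI-compatible endpoint works here; the URL and key below are illustrative only.
#
#     config = WikiConfig(
#         api_key="EMPTY",                      # many self-hosted servers ignore the key
#         api_base="http://localhost:8000/v1",  # assumed local endpoint
#         model=None,                           # None -> first model advertised by the server
#     )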


class WikiTools:
    """Collection of Wikipedia and extraction tools"""

    def __init__(self, base_url: str):
        self.base_url = base_url

    def fetch_wiki_content(self, title: str, section: Optional[str] = None) -> str:
        """Fetch and clean Wikipedia article content, optionally from a specific section"""
        url = f"{self.base_url}{title.replace(' ', '_')}"
        response = requests.get(url, timeout=30)
        soup = BeautifulSoup(response.content, 'html.parser')

        # Remove non-content elements before extracting text
        for unwanted in soup.find_all(['script', 'style', 'footer', 'header']):
            unwanted.decompose()

        if section:
            # Locate the section heading by its span id and collect the content that follows it
            section_tag = soup.find('span', {'id': section})
            if section_tag:
                content = section_tag.parent.find_next_siblings()
                text = ' '.join(tag.get_text() for tag in content)
            else:
                return "Section not found"
        else:
            # Default to the main article body
            content = soup.find(id='mw-content-text')
            if content:
                text = content.get_text()
            else:
                return "Content not found"

        # Collapse whitespace and truncate so the tool response stays a manageable size
        text = ' '.join(text.split())
        return text[:8000]

    @staticmethod
    def deliver_answer(fields: List[str]) -> Dict[str, Any]:
        """Extract specific information from text spans"""
        print(f"ANSWER FROM THE ASSISTANT: {fields}")
        return {
            "extracted_fields": "Provided fields were delivered to the user successfully."
        }
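
# Quick standalone check of the tools, outside the agent loop (a sketch; the article
# title is just an example):
#
#     tools = WikiTools(base_url="https://en.wikipedia.org/wiki/")
#     print(tools.fetch_wiki_content("Albert Einstein")[:300])
#     tools.deliver_answer(["Ulm, German Empire"])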


class ToolRegistry:
    """Registry of available tools and their schemas"""

    def __init__(self, wiki_tools: WikiTools):
        self.wiki_tools = wiki_tools

    @property
    def available_functions(self) -> Dict[str, callable]:
        return {
            "fetch_wiki_content": self.wiki_tools.fetch_wiki_content,
            "deliver_answer": self.wiki_tools.deliver_answer
        }

    @property
    def tool_schemas(self) -> List[Dict[str, Any]]:
        return [
            {
                "type": "function",
                "function": {
                    "name": "fetch_wiki_content",
                    "description": "Fetch content from a Wikipedia article",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "title": {
                                "type": "string",
                                "description": "The title of the Wikipedia article"
                            },
                            "section": {
                                "type": "string",
                                "description": "Optional: Specific section ID to fetch",
                                "optional": True
                            }
                        },
                        "required": ["title"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "deliver_answer",
                    "description": "Extract specific information from the fetched text",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "fields": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "List of text spans from the article that are relevant to the query"
                            }
                        },
                        "required": ["fields"]
                    }
                }
            }
        ]
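
# How the registry is consumed (mirrors process_tool_calls below): when the model emits
# a tool call such as name="fetch_wiki_content" with arguments '{"title": "Ulm"}', the
# agent parses the JSON arguments and dispatches via
# available_functions["fetch_wiki_content"](title="Ulm"), then returns the result to the
# model as a "tool" role message.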


class WikiExtractionAgent:
    """Main agent class that handles the extraction process"""

    def __init__(self, config: WikiConfig):
        self.config = config
        self.client = OpenAI(api_key=config.api_key, base_url=config.api_base)
        self.wiki_tools = WikiTools(config.wikipedia_base_url)
        self.tools = ToolRegistry(self.wiki_tools)
        self.system_message = {
            "role": "system",
            "content": (
                "1. First fetch any wikipedia pages you might need to answer the user query. "
                "Do not answer from parametric knowledge.\n\n"
                "2. Then, provide the answer to the user using the deliver_answer from the "
                "retrieved wikipedia page.\n\n"
                "3. You may need to issue multiple calls to wikipedia after extracting answers "
                "if there are nested dependencies for information."
            ),
        }
        self.messages = [self.system_message]

        # If no model is specified, use the first model the server reports
        if not config.model:
            models = self.client.models.list()
            self.config.model = models.data[0].id

    def _serialize_tool_call(self, tool_call) -> Dict[str, Any]:
        """Convert tool call to serializable format"""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments
            }
        }

    def process_tool_calls(self, message) -> List[Dict[str, Any]]:
        """Process and execute tool calls from assistant"""
        results = []

        for tool_call in message.tool_calls:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)

            print(f"\nExecuting: {function_name}")
            print(f"Arguments: {function_args}")

            function_response = self.tools.available_functions[function_name](**function_args)
            results.append({
                "tool": function_name,
                "args": function_args,
                "response": function_response
            })

            # Feed the tool result back to the model as a "tool" role message
            self.messages.append({
                "role": "tool",
                "content": json.dumps(function_response),
                "tool_call_id": tool_call.id,
                "name": function_name
            })

        return results

    def extract_information(self, query: str) -> List[Dict[str, Any]]:
        """Main method to handle the extraction process"""
        # Start a fresh conversation for each query, keeping the system instructions
        self.messages = [
            self.system_message,
            {
                "role": "user",
                "content": f"""Extract information from Wikipedia to answer this query: {query}

You can use these tools:
1. fetch_wiki_content: Get article content
2. deliver_answer: deliver relevant information

Please fetch content first, and iterate as needed to get to the webpage with the correct answer and then deliver the relevant information.""",
            },
        ]

        all_results = []

        for step in range(self.config.max_steps):
            print(f"\nStep {step + 1}")
            print("-" * 40)

            response = self.client.chat.completions.create(
                messages=self.messages,
                model=self.config.model,
                tools=self.tools.tool_schemas,
                temperature=0.0,
            )

            message = response.choices[0].message

            # No tool calls means the model has finished answering
            if not message.tool_calls:
                print("Extraction Complete")
                break

            self.messages.append({
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": [self._serialize_tool_call(tc) for tc in message.tool_calls]
            })

            results = self.process_tool_calls(message)
            all_results.extend(results)

        return all_results
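
# Illustrative shape of the list returned by extract_information (values here mirror the
# first example in the module docstring; actual responses depend on the live Wikipedia
# content and the model's tool choices):
#
#     [
#         {"tool": "fetch_wiki_content", "args": {"title": "Albert Einstein"}, "response": "Albert Einstein ..."},
#         {"tool": "deliver_answer", "args": {"fields": ["Ulm, German Empire"]}, "response": {"extracted_fields": "..."}},
#         ...
#     ]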


def main():
    config = WikiConfig()
    agent = WikiExtractionAgent(config)

    agent.extract_information(
        query="What is the current population for the city where Einstein was born?"
    )

    agent.extract_information(
        query="Why was Einstein famous?"
    )


if __name__ == "__main__":
    main()