import logging import time from pathlib import Path from typing import AsyncIterator, Optional from agent_protocol_client import ( AgentApi, ApiClient, Configuration, Step, TaskRequestBody, ) from agbenchmark.agent_interface import get_list_of_file_paths from agbenchmark.config import AgentBenchmarkConfig logger = logging.getLogger(__name__) async def run_api_agent( task: str, config: AgentBenchmarkConfig, timeout: int, artifacts_location: Optional[Path] = None, *, mock: bool = False, ) -> AsyncIterator[Step]: configuration = Configuration(host=config.host) async with ApiClient(configuration) as api_client: api_instance = AgentApi(api_client) task_request_body = TaskRequestBody(input=task, additional_input=None) start_time = time.time() response = await api_instance.create_agent_task( task_request_body=task_request_body ) task_id = response.task_id if artifacts_location: logger.debug("Uploading task input artifacts to agent...") await upload_artifacts( api_instance, artifacts_location, task_id, "artifacts_in" ) logger.debug("Running agent until finished or timeout...") while True: step = await api_instance.execute_agent_task_step(task_id=task_id) yield step if time.time() - start_time > timeout: raise TimeoutError("Time limit exceeded") if step and mock: step.is_last = True if not step or step.is_last: break if artifacts_location: # In "mock" mode, we cheat by giving the correct artifacts to pass the test if mock: logger.debug("Uploading mock artifacts to agent...") await upload_artifacts( api_instance, artifacts_location, task_id, "artifacts_out" ) logger.debug("Downloading agent artifacts...") await download_agent_artifacts_into_folder( api_instance, task_id, config.temp_folder ) async def download_agent_artifacts_into_folder( api_instance: AgentApi, task_id: str, folder: Path ): artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id) for artifact in artifacts.artifacts: # current absolute path of the directory of the file if artifact.relative_path: path: str = ( artifact.relative_path if not artifact.relative_path.startswith("/") else artifact.relative_path[1:] ) folder = (folder / path).parent if not folder.exists(): folder.mkdir(parents=True) file_path = folder / artifact.file_name logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}") with open(file_path, "wb") as f: content = await api_instance.download_agent_task_artifact( task_id=task_id, artifact_id=artifact.artifact_id ) f.write(content) async def upload_artifacts( api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str ) -> None: for file_path in get_list_of_file_paths(artifacts_location, type): relative_path: Optional[str] = "/".join( str(file_path).split(f"{type}/", 1)[-1].split("/")[:-1] ) if not relative_path: relative_path = None await api_instance.upload_agent_task_artifacts( task_id=task_id, file=str(file_path), relative_path=relative_path )