File size: 5,193 Bytes
e67043b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import time
import types
from typing import Any, Dict, Generator, List, Tuple, Union

from langchain.agents import AgentExecutor
from langchain.input import get_color_mapping
from langchain.schema import AgentAction, AgentFinish

from .translator import Translator


class AgentExecutorWithTranslation(AgentExecutor):
    """``AgentExecutor`` whose prepared outputs are passed through a ``Translator``."""

    # Translator applied to prepared outputs in ``prep_outputs``; the default
    # instance is constructed once, when this class body is executed.
    translator: Translator = Translator()

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Prepare outputs via the parent class, then translate them.

        Args:
            inputs: The chain inputs.
            outputs: The raw outputs produced by the chain.
            return_only_outputs: Forwarded unchanged to the parent
                ``prep_outputs``.

        Returns:
            The parent's prepared outputs, passed through ``self.translator``
            when they contain an ``"input"`` key. If the parent raises
            ``ValueError``, the raw ``outputs`` are returned unchanged and
            untranslated.
        """
        try:
            prepared = super().prep_outputs(inputs, outputs, return_only_outputs)
        except ValueError:
            # Deliberate best-effort fallback: hand back the raw outputs
            # instead of failing. NOTE(review): presumably this covers
            # partial/streamed dicts that fail the parent's output
            # validation — confirm.
            return outputs
        # Only translate when the "input" key is present in the prepared
        # dict. NOTE(review): presumably that key appears only when inputs
        # are merged back in (return_only_outputs=False) — confirm intent.
        if "input" in prepared:
            prepared = self.translator(prepared)
        return prepared


class Executor(AgentExecutorWithTranslation):
    """Agent executor that streams its progress instead of returning once.

    Both ``_call`` and ``__call__`` are generators: callers iterate over them
    to receive each tool step as it happens, followed by the final output.
    """

    def _call(self, inputs: Dict[str, str]) -> Generator[Any, None, None]:
        """Run the agent loop, yielding progress items as they are produced.

        Yields, in order of occurrence:
          * for a tool whose observation is a generator: a notice tuple
            ``(AgentAction("", tool_input, log), message)`` followed by every
            item that generator produces;
          * for any other observation of a known tool: a tuple
            ``(AgentAction(tool_logo_md, tool_input, log), observation)``
            (nothing is yielded for a tool name not in ``self.tools``);
          * exactly once, the value of ``self._return(...)`` when the agent
            finishes, a tool returns directly, or the loop limits are hit.

        Args:
            inputs: Prepared chain inputs, forwarded to the agent.
        """
        # Construct a mapping of tool name to tool for easy lookup.
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Track iterations and wall-clock time for the stopping criteria.
        iterations = 0
        time_elapsed = 0.0
        start_time = time.time()
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations, time_elapsed):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                yield self._return(next_step_output, intermediate_steps)
                return

            for i, (agent_action, observation) in enumerate(next_step_output):
                tool = name_to_tool_map.get(agent_action.tool)
                tool_logo = tool.tool_logo_md if tool is not None else None
                if isinstance(observation, types.GeneratorType):
                    logo = f"{tool_logo}" if tool_logo is not None else ""
                    yield (
                        AgentAction("", agent_action.tool_input, agent_action.log),
                        f"Further use other tool {logo} to answer the question.",
                    )
                    # Stream every chunk and keep the last one as this step's
                    # observation. Starting from "" means an empty generator
                    # neither raises UnboundLocalError nor silently reuses a
                    # chunk left over from an earlier step.
                    last_chunk = ""
                    for chunk in observation:
                        yield chunk
                        last_chunk = chunk
                    # Replace the spent generator so later steps (and the
                    # agent's scratchpad) see a concrete observation.
                    next_step_output[i] = (agent_action, last_chunk)
                elif tool is not None:
                    # Surface the tool's logo in place of its name.
                    yield (
                        AgentAction(
                            tool_logo, agent_action.tool_input, agent_action.log
                        ),
                        observation,
                    )

            intermediate_steps.extend(next_step_output)
            if len(next_step_output) == 1:
                next_step_action = next_step_output[0]
                # See if tool should return directly.
                tool_return = self._get_tool_return(next_step_action)
                if tool_return is not None:
                    yield self._return(tool_return, intermediate_steps)
                    return
            iterations += 1
            time_elapsed = time.time() - start_time
        # Iteration/time limits reached: let the agent build its stopped reply.
        stopped_output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        yield self._return(stopped_output, intermediate_steps)

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
    ) -> Generator[Any, None, Any]:
        """Run the logic of this chain, streaming every item it produces.

        This is a generator: each item from ``_call`` is yielded as it
        arrives. Items that are dicts (e.g. the final result) are first run
        through ``prep_outputs``. The generator's return value is the last
        item yielded.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.

        Raises:
            Re-raises any ``Exception`` or ``KeyboardInterrupt`` from the
            agent loop after reporting it to the callback manager.
        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        try:
            for output in self._call(inputs):
                if isinstance(output, dict):
                    output = self.prep_outputs(inputs, output, return_only_outputs)
                yield output
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            # Bare raise re-raises the active exception with its traceback.
            raise
        # ``_call`` always yields at least its final result, so ``output`` is bound.
        self.callback_manager.on_chain_end(output, verbose=self.verbose)
        return output