|
import json |
|
import warnings |
|
from copy import deepcopy |
|
from typing import Callable, Dict, List, Union |
|
|
|
from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive |
|
from lagent.agents.agent import Agent, AsyncAgent |
|
from lagent.agents.aggregator import InternLMToolAggregator |
|
from lagent.hooks import InternLMActionProcessor |
|
from lagent.llms import BaseLLM |
|
from lagent.memory import Memory |
|
from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode |
|
from lagent.schema import AgentMessage |
|
from lagent.utils import create_object |
|
|
|
# Prefix applied by ``get_plugin_prompt`` to every tool / API description.
# Placeholders: ``{tool_name}`` and ``{description}``.
API_PREFIX = ("This is the subfunction for tool '{tool_name}', "
              "you can use this tool. "
              "The description of this function is: \n{description}")

# Meta prompt for the mixed tool parser. English gist: "When tools and code
# are enabled, choose the appropriate tool to call according to the need."
META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用'

# Interpreter prompt. English gist: the model can run Python code in a
# stateful Jupyter notebook environment; code sent to ``python`` is executed
# there. It then lists typical uses: data analysis/processing, maths and
# physics computation, programming examples, text processing, machine
# learning / data science, and file handling (CSV, JSON, ...).
INTERPRETER_CN = (
    '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
    '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
    '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
    '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
    '文本处理和分析(比如文本解析和自然语言处理),'
    '机器学习和数据科学(用于展示模型训练和数据可视化),'
    '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。'
)

# Plugin prompt with a ``{prompt}`` placeholder for the tool list. English
# gist: "You can use the following tools: {prompt}. If you already have
# enough information, answer directly; avoid unnecessary tool calls, and pay
# attention to the tools available - do not fabricate them."
PLUGIN_CN = (
    '你可以使用如下工具:\n{prompt}\n'
    '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
    '同时注意你可以使用的工具,不要随意捏造!'
)
|
|
|
|
|
def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
    """Render the JSON tool catalogue for a set of actions.

    Args:
        actions: A single action or a ``list`` of actions. Every item is
            passed through ``create_object``, so config dicts are accepted
            as well as already-built action instances.
        api_desc_template (str): Format string applied to each description.
            It receives the ``tool_name`` and ``description`` keywords.

    Returns:
        str: A JSON array (``indent=4``, non-ASCII characters kept verbatim)
        with one entry per API. A toolkit contributes one entry per item of
        its ``api_list``, renamed to ``<toolkit>.<api>``. A plain action
        contributes a single entry. In every entry only the parameters whose
        name appears in ``required`` are kept.
    """

    def _rewrite(desc, tool_name):
        # Prefix the description and drop non-required parameters (in place).
        desc['description'] = api_desc_template.format(
            tool_name=tool_name, description=desc['description'])
        desc['parameters'] = [
            p for p in desc['parameters'] if p['name'] in desc['required']
        ]
        return desc

    items = actions if isinstance(actions, list) else [actions]
    entries = []
    for item in items:
        tool = create_object(item)
        # Work on a copy so the action's own description is never mutated.
        desc = deepcopy(tool.description)
        if not tool.is_toolkit:
            entries.append(_rewrite(desc, tool.name))
            continue
        for api in desc['api_list']:
            api['name'] = f"{tool.name}.{api['name']}"
            entries.append(_rewrite(api, tool.name))
    return json.dumps(entries, ensure_ascii=False, indent=4)
|
|
|
|
|
class AgentForInternLM(Agent):
    """Synchronous tool-calling loop for InternLM-style agents.

    An inner agent of type ``_INTERNAL_AGENT_CLS`` produces a response each
    turn. If ``finish_condition`` holds for that response it is returned.
    Otherwise the ``tool_type`` in its ``formatted`` dict selects the
    ``<tool_type>_executor`` attribute that runs the requested call.
    """

    _INTERNAL_AGENT_CLS = Agent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        """Build the inner agent and the tool executors.

        Args:
            llm: LLM instance or config, forwarded to the inner agent.
            plugins: Plugin action config(s). When truthy they are wrapped
                in an ``ActionExecutor``; otherwise the value is stored as-is.
            interpreter: Interpreter action config, handled like ``plugins``.
            template, memory, output_format, aggregator: Forwarded to the
                inner agent's config.
            action_hooks: Hooks passed to every executor built here.
            finish_condition: Called on each inner-agent response; a truthy
                result ends ``forward``. The default tests whether
                ``formatted['status']`` is ``ToolStatusCode.NO_TOOL``.
            max_turn: Upper bound on inner-agent calls per ``forward``.
            **kwargs: ``hooks`` is popped and given to the inner agent. The
                rest is passed to the base-class constructor.
        """
        # ``hooks`` belongs to the inner agent; remove it before the rest of
        # kwargs reaches ``super().__init__`` below.
        inner_hooks = kwargs.pop('hooks', None)
        self.agent = create_object(
            dict(
                type=self._INTERNAL_AGENT_CLS,
                llm=llm,
                template=template,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=inner_hooks,
            ))
        # Falsy configs are kept verbatim (e.g. None), so ``forward`` raises
        # if the model later requests that tool type.
        self.plugin_executor = (
            ActionExecutor(plugins, hooks=action_hooks)
            if plugins else plugins)
        self.interpreter_executor = (
            ActionExecutor(interpreter, hooks=action_hooks)
            if interpreter else interpreter)
        if not self.plugin_executor and not self.interpreter_executor:
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent call a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Alternate between the inner agent and tool executors.

        Args:
            message: An ``AgentMessage``, or a plain ``str`` that is wrapped
                as a message from ``'user'``.
            session_id: Session key forwarded to the agent and executors.
            **kwargs: Extra arguments for the inner agent call.

        Returns:
            The first inner-agent response satisfying ``finish_condition``.
            If ``max_turn`` turns pass without that, the latest message is
            returned; this may be an executor result.

        Raises:
            RuntimeError: A tool type was requested but no executor for it
                is available.
        """
        if isinstance(message, str):
            message = AgentMessage(sender='user', content=message)
        for _turn in range(self.max_turn):
            message = self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
            tool_type = message.formatted['tool_type']
            if not tool_type:
                # No tool requested: the response itself is fed back in.
                continue
            executor = getattr(self, f'{tool_type}_executor', None)
            if not executor:
                raise RuntimeError(f'No available {tool_type} executor')
            message = executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        """Summarise the session's memory as a list of step dicts.

        Each inner-agent response yields a ``'thought'`` entry. It is
        followed by a ``'tool'`` entry when the response names a
        ``tool_type``. Messages from other non-user senders become
        ``'environment'`` entries, tagged with the most recently requested
        tool type when one exists. User messages are skipped.
        """
        steps = []
        last_tool = None
        for record in self.agent.memory.get_memory(session_id):
            if record.sender != self.agent.name:
                if record.sender == 'user':
                    continue
                feedback = dict(role='environment', content=record.content)
                if last_tool:
                    feedback['name'] = last_tool
                steps.append(feedback)
                continue
            parsed = record.formatted
            steps.append(dict(role='thought', content=parsed['thought']))
            if parsed['tool_type']:
                last_tool = parsed['tool_type']
                steps.append(
                    dict(role='tool', content=parsed['action'],
                         name=last_tool))
        return steps
|
|
|
|
|
class MathCoder(AgentForInternLM):
    """Interpreter-only agent preset for step-by-step maths problem solving.

    Compared with ``AgentForInternLM`` the defaults change. The interpreter
    is an ``IPythonInteractive`` (20 s timeout, 8192-char output cap), the
    output parser is an ``InterpreterParser`` with a maths-oriented prompt,
    and ``max_turn`` is 6.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(
            type=IPythonInteractive, timeout=20, max_out_len=8192),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=
            ('Integrate step-by-step reasoning and Python code to solve math problems '
             'using the following guidelines:\n'
             '- Analyze the question and write jupyter code to solve the problem;\n'
             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
             'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        """Configure the parent agent with maths defaults.

        Any ``plugins`` keyword is silently dropped, so this agent never
        builds a plugin executor. Every other argument has the same meaning
        as in ``AgentForInternLM.__init__``.
        """
        kwargs.pop('plugins', None)
        # Named parameters cannot appear in **kwargs, so merging never
        # overwrites one of the explicit settings below.
        settings = dict(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
        )
        settings.update(kwargs)
        super().__init__(**settings)
|
|
|
|
|
class AsyncAgentForInternLM(AsyncAgent):
    """Asynchronous tool-calling loop for InternLM-style agents.

    This mirrors ``AgentForInternLM``, but the inner agent and the executors
    are awaited. The executors are ``AsyncActionExecutor`` instances.
    """

    _INTERNAL_AGENT_CLS = AsyncAgent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        """Build the inner async agent and the async tool executors.

        Args:
            llm: LLM instance or config, forwarded to the inner agent.
            plugins: Plugin action config(s). When truthy they are wrapped
                in an ``AsyncActionExecutor``; otherwise stored as-is.
            interpreter: Interpreter action config, handled like ``plugins``.
            template, memory, output_format, aggregator: Forwarded to the
                inner agent's config.
            action_hooks: Hooks passed to every executor built here.
            finish_condition: Called on each inner-agent response; a truthy
                result ends ``forward``.
            max_turn: Upper bound on inner-agent calls per ``forward``.
            **kwargs: ``hooks`` is popped and given to the inner agent. The
                rest is passed to the base-class constructor.
        """
        # ``hooks`` belongs to the inner agent; remove it before the rest of
        # kwargs reaches ``super().__init__`` below.
        inner_hooks = kwargs.pop('hooks', None)
        self.agent = create_object(
            dict(
                type=self._INTERNAL_AGENT_CLS,
                llm=llm,
                template=template,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=inner_hooks,
            ))
        # Falsy configs are kept verbatim (e.g. None), so ``forward`` raises
        # if the model later requests that tool type.
        self.plugin_executor = (
            AsyncActionExecutor(plugins, hooks=action_hooks)
            if plugins else plugins)
        self.interpreter_executor = (
            AsyncActionExecutor(interpreter, hooks=action_hooks)
            if interpreter else interpreter)
        if not self.plugin_executor and not self.interpreter_executor:
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent call a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Alternate between the inner agent and tool executors (async).

        Args:
            message: An ``AgentMessage``, or a plain ``str`` that is wrapped
                as a message from ``'user'``.
            session_id: Session key forwarded to the agent and executors.
            **kwargs: Extra arguments for the inner agent call.

        Returns:
            The first inner-agent response satisfying ``finish_condition``.
            If ``max_turn`` turns pass without that, the latest message is
            returned; this may be an executor result.

        Raises:
            RuntimeError: A tool type was requested but no executor for it
                is available.
        """
        if isinstance(message, str):
            message = AgentMessage(sender='user', content=message)
        for _turn in range(self.max_turn):
            message = await self.agent(
                message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
            tool_type = message.formatted['tool_type']
            if not tool_type:
                # No tool requested: the response itself is fed back in.
                continue
            executor = getattr(self, f'{tool_type}_executor', None)
            if not executor:
                raise RuntimeError(f'No available {tool_type} executor')
            message = await executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        """Summarise the session's memory as a list of step dicts.

        Each inner-agent response yields a ``'thought'`` entry. It is
        followed by a ``'tool'`` entry when the response names a
        ``tool_type``. Messages from other non-user senders become
        ``'environment'`` entries, tagged with the most recently requested
        tool type when one exists. User messages are skipped.
        """
        steps = []
        last_tool = None
        for record in self.agent.memory.get_memory(session_id):
            if record.sender != self.agent.name:
                if record.sender == 'user':
                    continue
                feedback = dict(role='environment', content=record.content)
                if last_tool:
                    feedback['name'] = last_tool
                steps.append(feedback)
                continue
            parsed = record.formatted
            steps.append(dict(role='thought', content=parsed['thought']))
            if parsed['tool_type']:
                last_tool = parsed['tool_type']
                steps.append(
                    dict(role='tool', content=parsed['action'],
                         name=last_tool))
        return steps
|
|
|
|
|
class AsyncMathCoder(AsyncAgentForInternLM):
    """Async, interpreter-only agent preset for maths problem solving.

    It defaults to an ``AsyncIPythonInterpreter``, an ``InterpreterParser``
    with a maths-oriented prompt, and ``max_turn=6``. After every
    ``forward`` it closes the interpreter session it used.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(type=AsyncIPythonInterpreter),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=
            ('Integrate step-by-step reasoning and Python code to solve math problems '
             'using the following guidelines:\n'
             '- Analyze the question and write jupyter code to solve the problem;\n'
             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
             'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        """Configure the parent agent with maths defaults.

        Any ``plugins`` keyword is silently dropped, so no plugin executor is
        built. Every other argument has the same meaning as in
        ``AsyncAgentForInternLM.__init__``.
        """
        kwargs.pop('plugins', None)
        super().__init__(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
            **kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Run the parent tool loop, then release the interpreter session.

        Cleanup always runs, whether the loop returns or raises. It only
        acts when an interpreter executor exists and its first action is
        named ``'AsyncIPythonInterpreter'``. In that case that action's
        ``close_session(session_id)`` is awaited.
        """
        try:
            return await super().forward(message, session_id, **kwargs)
        finally:
            # ``interpreter_executor`` is None (or empty) when the agent was
            # built without an interpreter. Dereferencing it unguarded would
            # raise inside ``finally`` and replace the loop's real result or
            # exception (e.g. "No available interpreter executor").
            executor = self.interpreter_executor
            interpreter = (
                next(iter(executor.actions.values()), None)
                if executor else None)
            if (interpreter is not None
                    and interpreter.name == 'AsyncIPythonInterpreter'):
                await interpreter.close_session(session_id)
|
|