Merge pull request #121 from alipay/dev_weizj

add sql tool
This commit is contained in:
Jerry Z H
2024-07-11 12:24:05 +08:00
committed by GitHub
14 changed files with 171 additions and 47 deletions

View File

@@ -51,8 +51,8 @@ class ExpressingPlanner(Planner):
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:

View File

@@ -42,7 +42,7 @@ class PeerPlanner(Planner):
"""
planner_config = agent_model.plan.get('planner')
sub_agents = self.generate_sub_agents(planner_config)
return self.agents_run(sub_agents, planner_config, planner_input, input_object)
return self.agents_run(agent_model, sub_agents, planner_config, planner_input, input_object)
@staticmethod
def generate_sub_agents(planner_config: dict) -> dict:
@@ -79,7 +79,8 @@ class PeerPlanner(Planner):
elif context:
input_object.add_data('expert_framework', context)
def agents_run(self, agents: dict, planner_config: dict, agent_input: dict, input_object: InputObject) -> dict:
def agents_run(self, agent_mode: AgentModel, agents: dict, planner_config: dict, agent_input: dict,
input_object: InputObject) -> dict:
"""Planner agents run.
Args:
@@ -125,7 +126,10 @@ class PeerPlanner(Planner):
for index, one_framework in enumerate(planning_result.get_data('framework')):
logger_info += f"[{index + 1}] {one_framework} \n"
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": planning_result.to_dict(), "type": "planning"})
self.stream_output(input_object, {"data": {
'output': planning_result.to_dict(),
"agent_info": agent_mode.info
}, "type": "planning"})
if not executing_result or jump_step in ["planning", "executing"]:
if not executingAgent:
@@ -144,7 +148,10 @@ class PeerPlanner(Planner):
one_exec_log_info += f"[{index + 1}] output: {one_exec_res['output']}\n"
logger_info += one_exec_log_info
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": executing_result.to_dict(), "type": "executing"})
self.stream_output(input_object, {"data": {
'output': executing_result.to_dict(),
"agent_info": agent_mode.info
}, "type": "executing"})
if not expressing_result or jump_step in ["planning", "executing", "expressing"]:
if not expressingAgent:
@@ -159,7 +166,10 @@ class PeerPlanner(Planner):
logger_info = f"\nExpressing agent execution result is :\n"
logger_info += f"{expressing_result.get_data('output')}"
LOGGER.info(logger_info)
self.stream_output(input_object, {"data": executing_result.to_dict(), "type": "expressing"})
self.stream_output(input_object, {"data": {
'output': expressing_result.get_data('output'),
"agent_info": agent_mode.info
}, "type": "expressing"})
if not reviewing_result or jump_step in ["planning", "executing", "expressing", "reviewing"]:
if not reviewingAgent:

View File

@@ -6,16 +6,18 @@
# @FileName: planner.py
"""Base class for Planner."""
from abc import abstractmethod
import copy
import logging
from queue import Queue
from typing import Optional, List
from typing import Optional, List, Any
from langchain_core.runnables import RunnableSerializable
from agentuniverse.agent.action.knowledge.knowledge import Knowledge
from agentuniverse.agent.action.knowledge.knowledge_manager import KnowledgeManager
from agentuniverse.agent.action.knowledge.store.document import Document
from agentuniverse.agent.action.knowledge.store.query import Query
from agentuniverse.agent.action.tool.tool_manager import ToolManager
from agentuniverse.agent.agent_manager import AgentManager
from agentuniverse.agent.agent_model import AgentModel
from agentuniverse.agent.input_object import InputObject
from agentuniverse.agent.memory.chat_memory import ChatMemory
@@ -28,7 +30,7 @@ from agentuniverse.base.config.component_configer.configers.planner_configer imp
from agentuniverse.llm.llm import LLM
from agentuniverse.llm.llm_manager import LLMManager
from agentuniverse.prompt.prompt import Prompt
from agentuniverse.base.util.memory_util import generate_messages
from agentuniverse.base.util.memory_util import generate_messages, generate_memories
logging.getLogger().setLevel(logging.ERROR)
@@ -101,6 +103,7 @@ class Planner(ComponentBase):
action: dict = agent_model.action or dict()
tools: list = action.get('tool') or list()
knowledge: list = action.get('knowledge') or list()
agents: list = action.get('agent') or list()
action_result: list = list()
@@ -120,6 +123,16 @@ class Planner(ComponentBase):
for document in knowledge_res:
action_result.append(document.text)
for agent_name in agents:
agent = AgentManager().get_instance_obj(agent_name)
if agent is None:
continue
agent_input = {key: input_object.get_data(key) for key in agent.input_keys()}
output_object = agent.run(**agent_input)
action_result.append("\n".join([output_object.get_data(key)
for key in agent.output_keys()
if output_object.get_data(key) is not None]))
planner_input['background'] = planner_input['background'] or '' + "\n".join(action_result)
def handle_prompt(self, agent_model: AgentModel, planner_input: dict):
@@ -167,7 +180,26 @@ class Planner(ComponentBase):
input_object (InputObject): Agent input object.
data (dict): The data to be streamed.
"""
output_stream:Queue = input_object.get_data('output_stream', None)
output_stream: Queue = input_object.get_data('output_stream', None)
if output_stream is None:
return
output_stream.put_nowait(data)
def invoke_chain(self, agent_model: AgentModel, chain: RunnableSerializable[Any, str], planner_input: dict,
                 chat_history, input_object: InputObject):
    """Run the LLM chain, streaming tokens when an output stream is attached.

    When `input_object` carries no 'output_stream', the chain is invoked
    synchronously and its result returned as-is. Otherwise each streamed
    chunk is forwarded to the output queue (tagged with the agent's info)
    and the concatenated chunks are returned.

    Args:
        agent_model (AgentModel): Agent model; its `info` is attached to
            every streamed token event.
        chain (RunnableSerializable[Any, str]): The langchain runnable to execute.
        planner_input (dict): Input payload for the chain.
        chat_history: Chat history (accepted for interface symmetry with
            subclasses; not used directly here).
        input_object (InputObject): Agent input object, possibly holding
            an 'output_stream' queue.

    Returns:
        str: The chain's full output.
    """
    run_config = {"configurable": {"session_id": "unused"}}
    if not input_object.get_data('output_stream'):
        # No consumer for incremental tokens: plain synchronous invoke.
        return chain.invoke(input=planner_input, config=run_config)
    tokens = []
    for chunk in chain.stream(input=planner_input, config=run_config):
        self.stream_output(input_object, {
            'type': 'token',
            'data': {
                'chunk': chunk,
                'agent_info': agent_model.info
            }
        })
        tokens.append(chunk)
    return "".join(tokens)

View File

@@ -52,8 +52,7 @@ class PlanningPlanner(Planner):
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:

View File

@@ -5,11 +5,9 @@
# @Email : lc299034@antgroup.com
# @FileName: rag_planner.py
"""Rag planner module."""
from typing import Any
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableSerializable
from langchain_core.runnables.history import RunnableWithMessageHistory
from agentuniverse.agent.agent_model import AgentModel
@@ -56,26 +54,8 @@ class RagPlanner(Planner):
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
if not input_object.get_data('output_stream'):
res = chain_with_history.invoke(input=planner_input, config={"configurable": {"session_id": "unused"}})
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
else:
return self.stream(agent_model, chain_with_history, planner_input, chat_history, input_object)
def stream(self, agent_model: AgentModel, chain: RunnableSerializable[Any, str], planner_input: dict, chat_history,
input_object: InputObject):
result = []
for token in chain.stream(input=planner_input, config={"configurable": {"session_id": "unused"}}):
self.stream_output(input_object, {
'type': 'token',
'data': {
'token': token,
'agent_info': agent_model.info
}
})
result.append(token)
return {**planner_input, self.output_key: ''.join(result), 'chat_history': generate_memories(chat_history)}
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> ChatPrompt:
"""Prompt module processing.

View File

@@ -64,14 +64,14 @@ class ReActPlanner(Planner):
max_iterations=agent_model.plan.get('planner').get("max_iterations", 15))
return agent_executor.invoke(input=planner_input, memory=memory.as_langchain() if memory else None,
chat_history=chat_history, config=self.get_run_config(input_object))
chat_history=chat_history, config=self.get_run_config(agent_model, input_object))
@staticmethod
def get_run_config(input_object: InputObject) -> RunnableConfig:
def get_run_config(agent_model: AgentModel, input_object: InputObject) -> RunnableConfig:
config = RunnableConfig()
callbacks = []
output_stream = input_object.get_data('output_stream')
callbacks.append(StreamOutPutCallbackHandler(output_stream))
callbacks.append(StreamOutPutCallbackHandler(output_stream, agent_info=agent_model.info))
config.setdefault("callbacks", callbacks)
return config

View File

@@ -16,10 +16,13 @@ from langchain_core.callbacks import BaseCallbackHandler
class StreamOutPutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, queue_stream: asyncio.Queue, color: Optional[str] = None) -> None:
def __init__(self, queue_stream: asyncio.Queue, color: Optional[str] = None, agent_info: dict = None) -> None:
    """Initialize callback handler.

    Args:
        queue_stream (asyncio.Queue): Queue that receives streamed output events.
        color (Optional[str]): Display color; stored but not otherwise used here
            (kept for langchain callback-handler compatibility).
        agent_info (dict): Metadata describing the emitting agent, attached to
            every event this handler enqueues. Defaults to an empty dict.
    """
    self.queueStream = queue_stream
    self.color = color
    # Normalize None to a fresh dict so a shared mutable default is never used.
    if agent_info is None:
        agent_info = {}
    self.agent_info = agent_info
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
@@ -32,7 +35,13 @@ class StreamOutPutCallbackHandler(BaseCallbackHandler):
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
self.queueStream.put_nowait("Thought:"+action.log)
self.queueStream.put_nowait({
"type": "ReAct",
"data": {
"output": "\nThought:" + action.log,
"agent_info": self.agent_info
}
})
def on_tool_end(
self,
@@ -44,9 +53,21 @@ class StreamOutPutCallbackHandler(BaseCallbackHandler):
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
self.queueStream.put_nowait(observation_prefix + output)
self.queueStream.put_nowait({
"type": "ReAct",
"data": {
"output": '\n' + observation_prefix + output,
"agent_info": self.agent_info
}
})
else:
self.queueStream.put_nowait('Observation:'+output)
self.queueStream.put_nowait({
"type": "ReAct",
"data": {
"output": '\n Observation:' + output,
"agent_info": self.agent_info
}
})
def on_text(
self,
@@ -61,4 +82,10 @@ class StreamOutPutCallbackHandler(BaseCallbackHandler):
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
self.queueStream.put_nowait({
    "type": "ReAct",
    "data": {
        # AgentFinish has no `output` attribute — it exposes `return_values`
        # and `log`; the pre-refactor code streamed `finish.log`, so keep that.
        "output": '\nThought:' + finish.log,
        "agent_info": self.agent_info
    }
})

View File

@@ -51,8 +51,7 @@ class ReviewingPlanner(Planner):
history_messages_key="chat_history",
input_messages_key=self.input_key,
) | StrOutputParser()
res = asyncio.run(
chain_with_history.ainvoke(input=planner_input, config={"configurable": {"session_id": "unused"}}))
res = self.invoke_chain(agent_model, chain_with_history, planner_input, chat_history, input_object)
return {**planner_input, self.output_key: res, 'chat_history': generate_memories(chat_history)}
def handle_prompt(self, agent_model: AgentModel, planner_input: dict) -> Prompt:

View File

@@ -207,4 +207,3 @@ class OpenAIStyleLLM(LLM):
"""Return the maximum length of the context."""
if super().max_context_length():
return super().max_context_length()
return 4000

View File

@@ -0,0 +1,13 @@
# Tool config wrapping langchain's InfoSQLDatabaseTool: returns schema and
# sample rows for the tables named in the input.
name: 'info_sql_database_tool'
# Deliberately empty: SqlLangchainTool copies the description from the wrapped
# langchain tool when it is instantiated.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: InfoSQLDatabaseTool
  init_params:
    # Name of the SQLDBWrapper instance registered with SQLDBWrapperManager.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'

View File

@@ -29,7 +29,7 @@ class LangChainTool(Tool):
def initialize_by_component_configer(self, component_configer: ToolConfiger) -> 'Tool':
super().initialize_by_component_configer(component_configer)
self.tool = self.init_langchain_tool(component_configer)
if not component_configer.description:
if not component_configer.description and self.tool is not None:
self.description = self.tool.description
return self

View File

@@ -0,0 +1,13 @@
# Tool config wrapping langchain's ListSQLDatabaseTool: lists the table names
# available in the configured database.
name: 'list_sql_database_tool'
# Deliberately empty: SqlLangchainTool copies the description from the wrapped
# langchain tool when it is instantiated.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: ListSQLDatabaseTool
  init_params:
    # Name of the SQLDBWrapper instance registered with SQLDBWrapperManager.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'

View File

@@ -0,0 +1,13 @@
# Tool config wrapping langchain's QuerySQLDataBaseTool: executes a SQL query
# against the configured database and returns the result.
name: 'query_sql_database_tool'
# Deliberately empty: SqlLangchainTool copies the description from the wrapped
# langchain tool when it is instantiated.
description: ''
tool_type: 'api'
input_keys: ['input']
langchain:
  module: langchain_community.tools
  class_name: QuerySQLDataBaseTool
  init_params:
    # Name of the SQLDBWrapper instance registered with SQLDBWrapperManager.
    db_wrapper: demo_sqldb_wrapper
metadata:
  type: 'TOOL'
  module: 'sample_standard_app.app.core.tool.langchain_tool.sql_langchain_tool'
  class: 'SqlLangchainTool'

View File

@@ -0,0 +1,39 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2024/7/9 20:04
# @Author : weizjajj
# @Email : weizhongjie.wzj@antgroup.com
# @FileName: sql_langchain_tool.py
from typing import Type, Optional
from langchain_core.tools import BaseTool, Tool as LangchainTool
from agentuniverse.agent.action.tool.tool import ToolInput
from agentuniverse.database.sqldb_wrapper_manager import SQLDBWrapperManager
from sample_standard_app.app.core.tool.langchain_tool.langchain_tool import LangChainTool
class SqlLangchainTool(LangChainTool):
    """Langchain SQL tool wrapper that lazily binds a SQL database.

    Instead of instantiating the langchain tool at config time, this subclass
    records the target tool class and the name of a SQLDBWrapper, then builds
    the tool on first use — so the database connection is only resolved when
    the tool actually runs.
    """

    # Name of the SQLDBWrapper instance registered with SQLDBWrapperManager
    # (populated from the YAML `init_params.db_wrapper`).
    db_wrapper_name: Optional[str] = ""
    # Concrete langchain tool class to instantiate (e.g. QuerySQLDataBaseTool).
    clz: Type[BaseTool] = BaseTool

    def execute(self, tool_input: ToolInput):
        """Execute the tool, building the underlying langchain tool if needed."""
        if self.tool is None:
            self.get_sql_database()
        return super().execute(tool_input)

    def get_sql_database(self):
        """Instantiate the langchain tool bound to the configured database.

        Also refreshes `description` from the instantiated tool.

        Raises:
            ValueError: If no SQLDBWrapper is registered under
                `db_wrapper_name` — previously this surfaced as an opaque
                AttributeError on None.
        """
        db_wrapper = SQLDBWrapperManager().get_instance_obj(self.db_wrapper_name)
        if db_wrapper is None:
            raise ValueError(
                f"SQLDBWrapper instance '{self.db_wrapper_name}' was not found; "
                "check the `db_wrapper` value in the tool's YAML init_params.")
        self.tool = self.clz(db=db_wrapper.sql_database)
        self.description = self.tool.description

    def as_langchain(self) -> LangchainTool:
        """Return the langchain-native tool, building it lazily if needed."""
        if self.tool is None:
            self.get_sql_database()
        return super().as_langchain()

    def get_langchain_tool(self, init_params: dict, clz: Type[BaseTool]):
        """Defer tool construction: record the wrapper name and tool class.

        Overrides the parent hook so no tool instance is created at config
        time; `get_sql_database` performs the actual construction later.
        """
        self.db_wrapper_name = init_params.get("db_wrapper")
        self.clz = clz