Fix: MCPToolkit can't pass right args to MCPTool

Fix: RequestTask can't close correctly if task thread end without put an eof into queue
Fix: LLM Instrument can't trace token usage correctly in streaming mode
Fix: function 'get_invocation_chain' returns None in some case
Fix: Wrong logic while adding raw token usage dict
Fix: Sls sink use async func to process log
ADD: MCPManager add an async func 'safe_close_stack_async' to close all session in asyncio env
ADD: add sls base sink
This commit is contained in:
AniviaTn
2025-11-10 16:59:56 +08:00
parent fcebe7c9a3
commit 440e3c7cec
8 changed files with 156 additions and 34 deletions

View File

@@ -99,24 +99,14 @@ class MCPToolkit(Toolkit):
server_name=self.server_name,
origin_tool_name=tool.name,
args_model_schema=tool.inputSchema,
input_keys=tool.inputSchema['required'],
**self.get_mcp_server_connect_args()
input_keys=tool.inputSchema.get('required', []),
transport=self.transport,
url=self.url,
command=self.command,
args=self.args,
env=self.env,
connection_kwargs=self.connection_kwargs
)
# TODO The following revision still has issues that need to be fixed.
# tool_instance = MCPTool(
# name=tool_name,
# description=f'{tool.description}\n{str(tool.inputSchema)}',
# server_name=self.server_name,
# origin_tool_name=tool.name,
# args_model_schema=tool.inputSchema,
# input_keys=tool.inputSchema['required'],
# transport=self.transport,
# url=self.url,
# command=self.command,
# args=self.args,
# env=self.env,
# connection_kwargs=self.connection_kwargs
# )
ToolManager().register(tool_instance.get_instance_code(), tool_instance)
def _initialize_by_component_configer(self, component_configer: ComponentConfiger) -> 'MCPToolkit':

View File

@@ -90,7 +90,12 @@ class RequestTask:
first_chunk = True
start_time = time.time()
while True:
output: str = self.queue.get()
try:
output = self.queue.get(timeout=1)
except queue.Empty:
if not self.thread.is_alive():
break
continue
if output is None:
break
if output == EOF_SIGNAL:
@@ -134,7 +139,12 @@ class RequestTask:
first_chunk = True
start_time = time.time()
while True:
output: str = self.queue.get()
try:
output = self.queue.get(timeout=1)
except queue.Empty:
if not self.thread.is_alive():
break
continue
if output is None:
break
if output == EOF_SIGNAL:
@@ -179,6 +189,9 @@ class RequestTask:
except asyncio.TimeoutError:
await asyncio.sleep(1)
print("Waiting for data timed out. Retrying...")
if self.async_task and self.async_task.done():
LOGGER.error("Task finished without EOF")
break
continue
if output is None:
break
@@ -220,7 +233,12 @@ class RequestTask:
try:
self.next_state(TaskStateEnum.RUNNING)
while True:
output: str = self.queue.get()
try:
output = self.queue.get(timeout=1)
except queue.Empty:
if not self.thread.is_alive():
break
continue
if output is None:
break
if output == EOF_SIGNAL:

View File

@@ -760,5 +760,13 @@ class MCPSessionManager:
self.__exit_stack.set(None)
self.__mcp_session_dict.set(None)
async def safe_close_stack_async(self) -> None:
if isinstance(self.exit_stack, AsyncExitStack):
await self.exit_stack.aclose()
elif isinstance(self.exit_stack, SyncAsyncExitStack):
self.exit_stack.close()
self.__exit_stack.set(None)
self.__mcp_session_dict.set(None)
def run_async(self, func, *args, **kwargs):
return self.exit_stack.run_async(func, *args, **kwargs)

View File

@@ -321,6 +321,8 @@ class StreamingResultProcessor:
if pseudo_result.usage:
self.metrics_recorder.record_token_usage(pseudo_result.usage,
self.labels)
add_current_token_usage(pseudo_result.usage,
self.span.context.span_id)
LLMSpanAttributesSetter.set_success_attributes(self.span, duration,
pseudo_result)

View File

@@ -200,14 +200,28 @@ class Monitor(BaseModel):
def get_invocation_chain():
"""Get the invocation chain in the framework context."""
trace_id = AuTraceManager().get_trace_id()
return FrameworkContextManager().get_context(trace_id + '_invocation_chain', []) if trace_id is not None else []
current_chain = FrameworkContextManager().get_context(trace_id + '_invocation_chain')
if isinstance(current_chain, list):
return current_chain
else:
current_chain = []
FrameworkContextManager().set_context(
trace_id + '_invocation_chain', current_chain)
return current_chain
@staticmethod
def get_invocation_chain_bak():
"""Get the invocation chain bak version in the framework context."""
trace_id = AuTraceManager().get_trace_id()
return FrameworkContextManager().get_context(trace_id + '_invocation_chain_bak',
[]) if trace_id is not None else []
current_chain = FrameworkContextManager().get_context(
trace_id + '_invocation_chain_bak')
if isinstance(current_chain, list):
return current_chain
else:
current_chain = []
FrameworkContextManager().set_context(
trace_id + '_invocation_chain_bak', current_chain)
return current_chain
@staticmethod
def init_token_usage():
@@ -225,9 +239,14 @@ class Monitor(BaseModel):
if trace_id is not None:
old_token_usage: dict = FrameworkContextManager().get_context(trace_id + '_token_usage')
if old_token_usage is not None:
result_usage = {}
for key, value in cur_token_usage.items():
old_token_usage[key] = old_token_usage[key] + value if key in old_token_usage else value
FrameworkContextManager().set_context(trace_id + '_token_usage', old_token_usage)
try:
result_usage[key] = old_token_usage[key] + value if key in old_token_usage else value
except:
# not addable value
pass
FrameworkContextManager().set_context(trace_id + '_token_usage', result_usage)
@staticmethod
def clear_token_usage():

View File

@@ -0,0 +1,75 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
import asyncio
from agentuniverse.base.util.logging.log_sink.log_sink import LogSink
from agentuniverse.base.util.logging.logging_config import LoggingConfig
from agentuniverse.base.util.logging.logging_util import \
is_in_coroutine_context
from loguru import logger
# @Time : 2025/11/5 10:51
# @Author : fanen.lhy
# @Email : fanen.lhy@antgroup.com
# @FileName: base_sls_log_sink.py
class BaseSLSLogSink(LogSink):
def process_record(self, record):
raise NotImplementedError("Subclasses must implement process_record.")
def filter(self, record):
if not record['extra'].get('log_type') == self.log_type:
return False
self.process_record(record)
return True
def register_sink(self):
if LoggingConfig.log_extend_module_switch["sls_log"]:
print(
f"biz_logger_is_in_coroutine_context={is_in_coroutine_context()}")
if is_in_coroutine_context():
from agentuniverse_extension.logger.sls_sink import \
AsyncSlsSender, AsyncSlsSink
sls_sender = AsyncSlsSender(LoggingConfig.sls_project,
LoggingConfig.sls_log_store,
LoggingConfig.sls_endpoint,
LoggingConfig.access_key_id,
LoggingConfig.access_key_secret,
LoggingConfig.sls_log_queue_max_size,
LoggingConfig.sls_log_send_interval)
loop = asyncio.get_event_loop_policy().get_event_loop()
loop.create_task(sls_sender.start())
if self.sink_id == -1:
self.sink_id = logger.add(
sink=AsyncSlsSink(sls_sender),
format=LoggingConfig.log_format,
filter=self.filter,
level=LoggingConfig.log_level,
enqueue=False
)
else:
from agentuniverse_extension.logger.sls_sink import SlsSink, \
SlsSender
sls_sender = SlsSender(LoggingConfig.sls_project,
LoggingConfig.sls_log_store,
LoggingConfig.sls_endpoint,
LoggingConfig.access_key_id,
LoggingConfig.access_key_secret,
LoggingConfig.sls_log_queue_max_size,
LoggingConfig.sls_log_send_interval)
sls_sender.start_batch_send_thread()
if self.sink_id == -1:
self.sink_id = logger.add(
sink=SlsSink(sls_sender),
format=LoggingConfig.log_format,
filter=self.filter,
level=LoggingConfig.log_level,
enqueue=self.enqueue
)

View File

@@ -55,12 +55,22 @@ class AsyncSlsSender:
if self._bg_task is None or self._bg_task.done():
self._bg_task = self._loop.create_task(self._worker())
async def put(self, item: LogItem, /) -> None:
"""异步放入队列;满了直接丢弃(不阻塞业务协程)"""
def put(self, item: LogItem, /) -> None:
def _safe_put():
try:
self._queue.put_nowait(item)
except asyncio.QueueFull:
pass
try:
self._queue.put_nowait(item)
except asyncio.QueueFull:
logger.error("SLS log queue full drop a log item")
running = asyncio.get_running_loop()
except RuntimeError:
running = None
if running is self._loop:
_safe_put()
else:
self._loop.call_soon_threadsafe(_safe_put)
async def aclose(self, timeout: float | None = 5.0) -> None:
"""
@@ -295,13 +305,13 @@ class AsyncSlsSink:
def __init__(self, sender: AsyncSlsSender):
self._sender = sender
async def __call__(self, message):
def __call__(self, message):
record = message.record
item = LogItem(
contents=[("content", message)],
timestamp=int(record["time"].timestamp())
)
await self._sender.put(item)
self._sender.put(item)
class SlsSink:

View File

@@ -68,12 +68,12 @@ jieba = "^0.42.1"
networkx = "^3.3"
httpx = ">=0.27.2"
tomli = "^2.2"
mcp = "~=1.9.0"
mcp = "<1.22.0"
opentracing = ">=2.4.0,<3.0.0"
jsonlines = "^4.0.0"
EbookLib = "^0.18"
beautifulsoup4 = "^4.12.0"
qdrant-client = "^1.15.1"
# qdrant-client = "^1.15.1" unsupportable version due to numpy version limit
[tool.poetry.extras]
log_ext = ["aliyun-log-python-sdk"]