Fix channel issue

weizjajj
2025-05-09 13:33:08 +08:00
parent bfa01d0e41
commit 87b64b3681
15 changed files with 394 additions and 471 deletions

View File

@@ -105,7 +105,7 @@ class Agent(ComponentBase, ABC):
scene_code = input_object.get_data("scene_code")
if scene_code:
AuTraceManager().trace_context.set_scene_code(scene_code)
FrameworkContextManager().set_context("scene_code", scene_code)
@trace_agent
def run(self, **kwargs) -> OutputObject:

View File

@@ -8,9 +8,6 @@
from typing import Optional
from pydantic import BaseModel
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.util.billing_center import BillingCenterInfo
class AgentModel(BaseModel):
"""The parent class of all agent models, containing only attributes."""
@@ -22,10 +19,6 @@ class AgentModel(BaseModel):
action: Optional[dict] = dict()
work_pattern: Optional[dict] = dict()
def billing_center_params(self):
billing_center_info = BillingCenterInfo(agent_id=self.info.get("name"))
return billing_center_info
def llm_params(self) -> dict:
"""
Returns:
@@ -39,6 +32,4 @@ class AgentModel(BaseModel):
params['model'] = value
else:
params[key] = value
if ApplicationConfigManager().app_configer.use_billing_center:
params["billing_center_params"] = self.billing_center_params()
return params

View File

@@ -14,18 +14,24 @@ import uuid
from functools import wraps
from agentuniverse.agent.memory.conversation_memory.conversation_memory_module import ConversationMemoryModule
from agentuniverse.base.util.billing_center import trace_billing
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.util.monitor.monitor import Monitor
from agentuniverse.llm.llm_output import LLMOutput
def trace_llm(func):
"""Annotation: @trace_llm
Decorator to trace the LLM invocation, add llm input and output to the monitor.
"""
def plugins(func):
llm_plugins = ApplicationConfigManager().app_configer.llm_plugins
wrap_func = func
for item in llm_plugins:
wrap_func = item(wrap_func)
return wrap_func
@wraps(func)
@trace_billing
async def wrapper_async(*args, **kwargs):
# get llm input from arguments
llm_input = _get_input(func, *args, **kwargs)
@@ -42,7 +48,7 @@ def trace_llm(func):
if self and hasattr(self, 'tracing'):
if self.tracing is False:
return await func(*args, **kwargs)
return await plugins(func)(*args, **kwargs)
# add invocation chain to the monitor module.
Monitor.add_invocation_chain({'source': source, 'type': 'llm'})
@@ -51,7 +57,7 @@ def trace_llm(func):
Monitor().trace_llm_input(source=source, llm_input=llm_input)
# invoke function
result = await func(*args, **kwargs)
result = await plugins(func)(*args, **kwargs)
# not streaming
if isinstance(result, LLMOutput):
# add llm invocation info to monitor
@@ -80,7 +86,6 @@ def trace_llm(func):
return gen_iterator()
@functools.wraps(func)
@trace_billing
def wrapper_sync(*args, **kwargs):
# get llm input from arguments
llm_input = _get_input(func, *args, **kwargs)
@@ -97,7 +102,7 @@ def trace_llm(func):
if self and hasattr(self, 'tracing'):
if self.tracing is False:
return func(*args, **kwargs)
return plugins(func)(*args, **kwargs)
# add invocation chain to the monitor module.
Monitor.add_invocation_chain({'source': source, 'type': 'llm'})
@@ -105,7 +110,7 @@ def trace_llm(func):
start_time = time.time()
Monitor().trace_llm_input(source=source, llm_input=llm_input)
# invoke function
result = func(*args, **kwargs)
result = plugins(func)(*args, **kwargs)
# not streaming
if isinstance(result, LLMOutput):
# add llm invocation info to monitor
@@ -135,12 +140,12 @@ def trace_llm(func):
return gen_iterator()
if asyncio.iscoroutinefunction(func):
# async function
return wrapper_async
else:
# sync function
return wrapper_sync
def get_caller_info(instance: object = None):
source_list = Monitor.get_invocation_chain()
if len(source_list) > 0:
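For orientation: each entry configured under PLUGINS.llm_plugins is treated as a plain decorator, and plugins() folds the configured list over the traced function, so every plugin receives the callable produced by the previous one. A minimal sketch of a custom plugin (the log_calls name and its module path are hypothetical):

    import functools

    def log_calls(func):
        """Hypothetical llm_plugins entry: logs every traced LLM invocation."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print(f"llm call: {func.__qualname__} kwargs={sorted(kwargs)}")
            return func(*args, **kwargs)  # passes coroutines through untouched, so async callees work too
        return wrapper

It would be registered via its dotted path, e.g. llm_plugins = ["my_pkg.plugins.log_calls"], under the [PLUGINS] section added at the end of this commit.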

View File

@@ -1,11 +1,11 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
import importlib
# @Time : 2024/3/12 16:17
# @Author : jerry.zzw
# @Email : jerry.zzw@antgroup.com
# @FileName: app_configer.py
from typing import Optional, Dict
from typing import Optional, Dict, Any
from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
from agentuniverse.base.config.component_configer.configers.tool_configer import ToolConfiger
@@ -51,9 +51,7 @@ class AppConfiger(object):
self.__llm_configer_map: Dict[str, LLMConfiger] = {}
self.__agent_llm_set: Optional[set[str]] = set()
self.__agent_tool_set: Optional[set[str]] = set()
self.__billing_center: Optional[dict] = {}
self.__use_billing_center: Optional[bool] = False
self.__billing_center_url: Optional[bool] = None
self.__llm_plugins: Optional[list] = []
@property
def base_info_appname(self) -> Optional[str]:
@@ -229,16 +227,17 @@ class AppConfiger(object):
self.__agent_tool_set = value
@property
def billing_center(self):
return self.__billing_center
def llm_plugins(self):
return self.__llm_plugins
@property
def billing_center_url(self):
return self.__billing_center_url
@property
def use_billing_center(self):
return self.__use_billing_center
@classmethod
def load_llm_plugins(cls, plugin_modules):
funcs = []
for item in plugin_modules:
module_name, func_name = item.rsplit('.', 1)
module = importlib.import_module(module_name)
funcs.append(getattr(module, func_name))
return funcs
def load_by_configer(self, configer: Configer) -> 'AppConfiger':
"""Load the AppConfiger by the given Configer.
@@ -275,7 +274,5 @@ class AppConfiger(object):
self.__core_log_sink_package_list = configer.value.get('CORE_PACKAGE', {}).get('log_sink')
self.__core_llm_channel_package_list = configer.value.get('CORE_PACKAGE', {}).get('llm_channel')
self.__conversation_memory_configer = configer.value.get('CONVERSATION_MEMORY', {})
self.__billing_center = configer.value.get('BILLING_CENTER', {})
self.__use_billing_center = self.billing_center.get("use_billing_center")
self.__billing_center_url = self.billing_center.get("billing_center_url")
self.__llm_plugins = self.load_llm_plugins(configer.value.get("PLUGINS", {}).get("llm_plugins", []))
return self
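load_llm_plugins resolves each configured dotted path into the function object it names, roughly equivalent to the following for a single entry (the path shown is a hypothetical placeholder):

    import importlib

    # "my_pkg.plugins.log_calls" -> the log_calls function object
    module_name, func_name = "my_pkg.plugins.log_calls".rsplit('.', 1)
    plugin = getattr(importlib.import_module(module_name), func_name)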

View File

@@ -1,184 +0,0 @@
import asyncio
import functools
import json
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Any
import httpx
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.llm.llm_output import LLMOutput
from pydantic import BaseModel
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.util.monitor.monitor import Monitor
from agentuniverse.base.util.tracing.au_trace_manager import AuTraceManager
client = httpx.Client(timeout=60)
thread_pool = ThreadPoolExecutor(max_workers=5)
class BillingCenterInfo(BaseModel):
app_id: Optional[str] = None
agent_id: Optional[str] = None
session_id: Optional[str] = None
trace_id: Optional[str] = None
scene_code: Optional[str] = None
base_url: Optional[str] = ""
model: Optional[str] = None
usage: Optional[dict] = None
input: Optional[dict] = None
output: Optional[dict] = None
def __init__(self, **kwargs):
params = kwargs
app_id = ApplicationConfigManager().app_configer.base_info_appname
params['app_id'] = app_id
agent_id = kwargs.get("agent_id", None)
if not agent_id:
agent_id = self._get_caller_agent()
params['agent_id'] = agent_id
if not kwargs.get("trace_id"):
trace_id = AuTraceManager().get_trace_id()
params['trace_id'] = trace_id
if not kwargs.get("session_id"):
session_id = AuTraceManager().get_session_id()
params['session_id'] = session_id
if not kwargs.get("scene_code"):
scene_code = AuTraceManager().get_scene_code()
params["scene_code"] = scene_code
super().__init__(**params)
@staticmethod
def _get_caller_agent():
source_list = Monitor.get_invocation_chain()
if len(source_list) > 0:
# iterate in reverse order
for item in reversed(source_list):
if item.get("type") == "agent":
return item.get("source", None)
return "unknown"
def push_billing_center_info(self):
params = self.model_dump()
LOGGER.info(json.dumps(params, ensure_ascii=False, indent=4))
endpoint = ApplicationConfigManager().app_configer.billing_center_url + "/billing/token/usage"
response = client.post(
endpoint, json=params, headers={
"content-type": "application/json"
}
)
if response.status_code != 200:
LOGGER.error(f"Push billing center info error {params}")
else:
LOGGER.info(
f"Push billing center success {json.dumps(params, ensure_ascii=False, indent=4)}")
def trace_billing(func):
@functools.wraps(func)
async def wrapper_async(*args, **kwargs):
billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
result = await func(*args, **kwargs)
if not ApplicationConfigManager().app_configer.use_billing_center:
return result
if billing_type == "proxy":
return result
billing_center_params = kwargs.pop("billing_center_params", None)
if not billing_center_params:
billing_center_params = BillingCenterInfo()
llm = args[0]
llm.update_billing_center_params(billing_center_params, kwargs)
if isinstance(result, LLMOutput):
billing_center_params.usage = llm.get_billing_tokens(kwargs, result)
billing_center_params.output = result.model_dump()
LOGGER.info(f"Billing info {billing_center_params.usage}")
thread_pool.submit(billing_center_params.push_billing_center_info)
else:
async def generator():
async for item in llm.async_billing_tokens_from_stream(result, billing_center_params):
yield item
LOGGER.info(f"Billing info {billing_center_params.usage}")
thread_pool.submit(billing_center_params.push_billing_center_info)
return generator()
return result
@functools.wraps(func)
def wrapper_sync(*args, **kwargs):
billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
result = func(*args, **kwargs)
if not ApplicationConfigManager().app_configer.use_billing_center:
return result
if billing_type == "proxy":
return result
billing_center_params = kwargs.pop("billing_center_params", None)
if not billing_center_params:
billing_center_params = BillingCenterInfo()
llm = args[0]
llm.update_billing_center_params(billing_center_params, kwargs)
if isinstance(result, LLMOutput):
billing_center_params.usage = llm.get_billing_tokens(kwargs, result)
billing_center_params.output = result.model_dump()
LOGGER.info(f"Billing info {billing_center_params.usage}")
thread_pool.submit(billing_center_params.push_billing_center_info)
return result
else:
def generator():
for item in llm.billing_tokens_from_stream(result, billing_center_params):
yield item
LOGGER.info(f"Billing info {billing_center_params.usage}")
thread_pool.submit(billing_center_params.push_billing_center_info)
return generator()
if asyncio.iscoroutinefunction(func):
# async function
return wrapper_async
else:
# sync function
return wrapper_sync
class BillingCenter(BaseModel):
@classmethod
def get_base_url(cls, base_url: str):
billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
if ApplicationConfigManager().app_configer.use_billing_center and billing_type == "proxy":
return ApplicationConfigManager().app_configer.billing_center_url
return base_url
@classmethod
def billing_center_openai_channel_headers(cls, llm, params_map: dict[str, Any]):
billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
if not ApplicationConfigManager().app_configer.use_billing_center:
return llm.channel_ext_headers, params_map
if billing_type == "push":
params_map.pop("billing_center_params")
return llm.channel_ext_headers, params_map
billing_center_params: BillingCenterInfo = params_map.pop("billing_center_params")
extra_headers = billing_center_params.model_dump(
include={'agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id'}
)
extra_headers["OriginalUrl"] = llm.channel_api_base
if llm.channel_proxy:
extra_headers["proxy"] = llm.channel_proxy if llm.channel_proxy else ""
return {**llm.channel_ext_headers, **extra_headers}, params_map
@classmethod
def billing_center_openai_headers(cls, llm, params_map: dict[str, Any]):
billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
if not ApplicationConfigManager().app_configer.use_billing_center:
return llm.ext_headers, params_map
if billing_type == "push":
params_map.pop("billing_center_params")
return llm.ext_headers, params_map
billing_center_params: BillingCenterInfo = params_map.pop("billing_center_params")
extra_headers = billing_center_params.model_dump(
include={'agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id'}
)
extra_headers["OriginalUrl"] = llm.api_base
if llm.proxy:
extra_headers["proxy"] = llm.proxy
return {**llm.ext_headers, **extra_headers}, params_map

View File

@@ -0,0 +1,79 @@
import asyncio
import functools
from concurrent.futures.thread import ThreadPoolExecutor
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.base.util.manager.billing_center_manager import BillingCenterManager, BillingCenter, \
BillingCenterInfo
from agentuniverse.llm.llm_output import LLMOutput
thread_pool = ThreadPoolExecutor(max_workers=5)
def trace_billing(func):
@functools.wraps(func)
async def wrapper_async(*args, **kwargs):
billing_type = BillingCenterManager().billing_center_type
if not BillingCenterManager().use_billing_center:
result = await func(*args, **kwargs)
return result
billing_center_info = BillingCenterInfo()
if billing_type == "proxy":
extra_headers = BillingCenter.llm_headers(args[0], billing_center_info)
api_base = BillingCenterManager().billing_center_url
kwargs['extra_headers'] = extra_headers
kwargs['api_base'] = api_base
return await func(*args, **kwargs)
llm = args[0]
BillingCenter.update_billing_center_params(llm, billing_center_info, kwargs)
result = await func(*args, **kwargs)
if isinstance(result, LLMOutput):
billing_center_info.usage = BillingCenter.get_billing_tokens(llm, kwargs, result)
billing_center_info.output = result.model_dump()
LOGGER.info(f"Billing info {billing_center_info.usage}")
thread_pool.submit(billing_center_info.push_billing_center_info)
else:
async def generator():
async for item in BillingCenter.async_billing_tokens_from_stream(llm, result, billing_center_info):
yield item
LOGGER.info(f"Billing info {billing_center_info.usage}")
thread_pool.submit(billing_center_info.push_billing_center_info)
return generator()
return result
@functools.wraps(func)
def wrapper_sync(*args, **kwargs):
billing_type = BillingCenterManager().billing_center_type
if not BillingCenterManager().use_billing_center:
result = func(*args, **kwargs)
return result
billing_center_info = BillingCenterInfo()
if billing_type == "proxy":
extra_headers = BillingCenter.llm_headers(args[0], billing_center_info)
api_base = BillingCenterManager().billing_center_url
kwargs['extra_headers'] = extra_headers
kwargs['api_base'] = api_base
return func(*args, **kwargs)
result = func(*args, **kwargs)
llm = args[0]
BillingCenter.update_billing_center_params(llm, billing_center_info, kwargs)
if isinstance(result, LLMOutput):
billing_center_info.usage = BillingCenter.get_billing_tokens(llm, kwargs, result)
billing_center_info.output = result.model_dump()
LOGGER.info(f"Billing info {billing_center_info.usage}")
thread_pool.submit(billing_center_info.push_billing_center_info)
return result
else:
def generator():
for item in BillingCenter.billing_tokens_from_stream(llm, result, billing_center_info):
yield item
LOGGER.info(f"Billing info {billing_center_info.usage}")
thread_pool.submit(billing_center_info.push_billing_center_info)
return generator()
if asyncio.iscoroutinefunction(func):
# async function
return wrapper_async
else:
# sync function
return wrapper_sync
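Once listed under PLUGINS.llm_plugins, trace_billing is applied through the plugins() chain inside trace_llm rather than being hard-coded as a decorator. A sketch of the effective call order under the sample configuration later in this commit (call stands in for the underlying LLM method):

    # traced = trace_llm(call)            # applied where the LLM defines call()
    # traced(**kwargs)                    # at invocation, trace_llm evaluates
    #     plugins(call)                   #   -> trace_billing(call)
    #     trace_billing(call)(**kwargs)   #   so billing wraps the raw call
    #
    # "proxy" mode rewrites extra_headers/api_base so the request flows through
    # the billing center; "push" mode runs the call, derives token usage, and
    # POSTs it from a worker thread.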

View File

@@ -0,0 +1,228 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2025/5/9 09:32
# @Author : weizjajj
# @Email : weizhongjie.wzj@antgroup.com
# @FileName: billing_center_manager.py
import json
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Optional
import httpx
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.context.framework_context_manager import FrameworkContextManager
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.base.util.monitor.monitor import Monitor
from agentuniverse.base.util.tracing.au_trace_manager import AuTraceManager
from agentuniverse.llm.llm_channel.llm_channel import LLMChannel
from agentuniverse.llm.openai_style_llm import OpenAIStyleLLM
from agentuniverse.llm.llm_output import LLMOutput
from pydantic import BaseModel
from agentuniverse.base.annotation.singleton import singleton
client = httpx.Client(timeout=60)
class BillingCenterInfo(BaseModel):
app_id: Optional[str] = None
agent_id: Optional[str] = None
session_id: Optional[str] = None
trace_id: Optional[str] = None
scene_code: Optional[str] = None
base_url: Optional[str] = ""
model: Optional[str] = None
usage: Optional[dict] = None
input: Optional[dict] = None
output: Optional[dict] = None
def __init__(self, **kwargs):
params = kwargs
app_id = ApplicationConfigManager().app_configer.base_info_appname
params['app_id'] = app_id
agent_id = kwargs.get("agent_id", None)
if not agent_id:
agent_id = self._get_caller_agent()
params['agent_id'] = agent_id
if not kwargs.get("trace_id"):
trace_id = AuTraceManager().get_trace_id()
params['trace_id'] = trace_id
if not kwargs.get("session_id"):
session_id = AuTraceManager().get_session_id()
params['session_id'] = session_id
if not kwargs.get("scene_code"):
scene_code = FrameworkContextManager().get_context("scene_code")
params["scene_code"] = scene_code
super().__init__(**params)
@staticmethod
def _get_caller_agent():
source_list = Monitor.get_invocation_chain()
if len(source_list) > 0:
# iterate in reverse order
for item in reversed(source_list):
if item.get("type") == "agent":
return item.get("source", None)
return "unknown"
def push_billing_center_info(self):
params = self.model_dump()
LOGGER.info(json.dumps(params, ensure_ascii=False, indent=4))
endpoint = BillingCenterManager().billing_center_url + "/billing/token/usage"
response = client.post(
endpoint, json=params, headers={
"content-type": "application/json"
}
)
if response.status_code != 200:
LOGGER.error(f"Push billing center info error {params}")
else:
LOGGER.info(
f"Push billing center success {json.dumps(params, ensure_ascii=False, indent=4)}")
@singleton
class BillingCenterManager:
def __init__(self, configer=None):
billing_center_info = configer.value.get("BILLING_CENTER", {}) if configer else {}
self.use_billing_center = billing_center_info.get("use_billing_center", False)
self.billing_center_url = billing_center_info.get("billing_center_url")
self.billing_center_type = billing_center_info.get("billing_type")
class BillingCenter(BaseModel):
@classmethod
def llm_headers(cls, llm, billing_center_info: BillingCenterInfo):
extra_headers = None
if hasattr(llm, "ext_headers"):
extra_headers = llm.ext_headers
elif hasattr(llm, "headers"):
extra_headers = llm.headers
if not extra_headers:
extra_headers = {
"content-type": "application/json"
}
extra_headers = extra_headers.copy()
keys = ['agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id']
billing_center_info_dict = billing_center_info.model_dump()
for key in keys:
extra_headers[key] = billing_center_info_dict.get(key) or ""
if isinstance(llm, OpenAIStyleLLM):
extra_headers["base_url"] = llm.api_base
if llm.proxy:
extra_headers["proxy"] = llm.proxy
return extra_headers
elif isinstance(llm, LLMChannel):
extra_headers["base_url"] = llm.channel_api_base
if llm.channel_proxy:
extra_headers["proxy"] = llm.channel_proxy
else:
extra_headers["base_url"] = llm.model_name
return extra_headers
@classmethod
def get_base_url(cls, base_url: str):
billing_type = BillingCenterManager().billing_center_type
if BillingCenterManager().use_billing_center and billing_type == "proxy":
return BillingCenterManager().billing_center_url
return base_url
@classmethod
def billing_tokens_from_stream(cls, llm, generator: Iterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
if usage:
billing_center_params.usage = usage
return
billing_center_params.usage = cls.get_billing_tokens(llm, billing_center_params.input,
llm_output)
@classmethod
async def async_billing_tokens_from_stream(cls, llm, generator: AsyncIterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
async for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
if usage:
billing_center_params.usage = usage
return
billing_center_params.usage = cls.get_billing_tokens(llm, billing_center_params.input,
llm_output)
@classmethod
def update_billing_center_params(cls, llm, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
if "model" in input:
params.model = input.get("model")
else:
params.model = llm.model_name
if isinstance(llm, LLMChannel):
params.base_url = llm.channel_api_base
elif isinstance(llm, OpenAIStyleLLM):
params.base_url = llm.api_base
elif getattr(llm, "service_id", None):
params.base_url = llm.service_id
elif getattr(llm, "serviceId", None):
params.base_url = llm.serviceId
elif getattr(llm, "endpoint", None):
params.base_url = llm.endpoint
else:
params.base_url = llm.model_name
params.input = input
return params
@classmethod
def _get_billing_tokens(cls, llm, input: dict, output: LLMOutput) -> dict:
text = ""
if "messages" in input:
messages = input.get("messages")
for message in messages:
content = message.get("content")
if isinstance(content, str):
text += content
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
text += item.get("content")
elif "prompt" in input:
text = str(input.get("prompt"))
prompt_tokens = llm.get_num_tokens(text)
output = output.text
completion_tokens = llm.get_num_tokens(output)
total_tokens = prompt_tokens + completion_tokens
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
@classmethod
def get_billing_tokens(cls, llm, input: dict, output: LLMOutput):
output_json = output.raw
if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
return output_json['usage']
return cls._get_billing_tokens(llm, input, output)
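BillingCenterInfo fills app_id, trace_id, session_id and scene_code from the running context, so callers normally pass only the fields they want to override. A minimal usage sketch (the agent_id value and token counts are illustrative):

    info = BillingCenterInfo(agent_id="demo_agent")
    info.usage = {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
    info.push_billing_center_info()  # POSTs to <billing_center_url>/billing/token/usage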

View File

@@ -65,6 +65,7 @@ class AuTraceContext:
self._span_id_counter += 1
return child_span_id
def to_dict(self) -> dict:
return {
"session_id": self.session_id,

View File

@@ -18,7 +18,6 @@ from agentuniverse.base.component.component_base import ComponentBase
from agentuniverse.base.component.component_enum import ComponentEnum
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
from agentuniverse.base.util.billing_center import BillingCenterInfo
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.llm.llm_channel.llm_channel import LLMChannel
from agentuniverse.llm.llm_channel.llm_channel_manager import LLMChannelManager
@@ -186,15 +185,10 @@ class LLM(ComponentBase):
Returns:
The integer number of tokens in the text.
"""
def get_num_tokens_from_model(self, text: str, model=None):
if model:
return self.get_num_tokens(text)
encoding = tiktoken.get_encoding("cl100k_base")
try:
encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError:
LOGGER.error("get_num_tokens_from_model error")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
def as_langchain_runnable(self, params=None) -> Runnable:
@@ -236,63 +230,3 @@ class LLM(ComponentBase):
copied.async_client = self.async_client
copied.langchain_instance = self.langchain_instance
return copied
def get_billing_tokens(self, input: dict, output: LLMOutput) -> dict:
text = ""
if "messages" in input:
messages = input.get("messages")
for message in messages:
content = message.get("content")
if isinstance(content, str):
text += content
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
content += item.get("content")
elif "prompt" in input:
text = str(input.get("prompt"))
prompt_tokens = self.get_num_tokens_from_model(text, input.get("model"))
output = output.text
completion_tokens = self.get_num_tokens_from_model(output, input.get("model"))
total_tokens = prompt_tokens + completion_tokens
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
def update_billing_center_params(self, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
if "model" in input:
params.model = input.get("model")
else:
params.model = self.model_name
params.input = input
return params
def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
for llm_output in generator:
content += llm_output.text
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)
async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
async for llm_output in generator:
content += llm_output.text
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)

View File

@@ -19,7 +19,6 @@ from agentuniverse.base.component.component_base import ComponentBase
from agentuniverse.base.component.component_enum import ComponentEnum
from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
from agentuniverse.base.util.billing_center import BillingCenter, BillingCenterInfo
from agentuniverse.llm.llm_channel.langchain_instance.default_channel_langchain_instance import \
DefaultChannelLangchainInstance
from agentuniverse.llm.llm_output import LLMOutput
@@ -33,8 +32,8 @@ class LLMChannel(ComponentBase):
channel_proxy: Optional[str] = None
channel_model_name: Optional[str] = None
channel_ext_info: Optional[dict] = None
channel_ext_headers: Optional[dict] = {}
channel_ext_params: Optional[dict] = {}
ext_headers: Optional[dict] = {}
ext_params: Optional[dict] = {}
model_support_stream: Optional[bool] = None
model_support_max_context_length: Optional[int] = None
@@ -71,15 +70,15 @@ class LLMChannel(ComponentBase):
self.model_support_max_tokens = component_configer.model_support_max_tokens
if hasattr(component_configer, "model_is_openai_protocol_compatible"):
self.model_is_openai_protocol_compatible = component_configer.model_is_openai_protocol_compatible
if component_configer.configer.value.get("extra_headers"):
self.channel_ext_headers = component_configer.configer.value.get("extra_headers", {})
if component_configer.configer.value.get("extra_params"):
self.channel_ext_params = component_configer.configer.value.get("extra_params", {})
self.channel_ext_params["stream_options"] = {
if component_configer.configer.value.get("ext_headers"):
self.ext_headers = component_configer.configer.value.get("ext_headers", {})
if component_configer.configer.value.get("ext_params"):
self.ext_params = component_configer.configer.value.get("ext_params", {})
self.ext_params["stream_options"] = {
"include_usage": True
}
else:
self.channel_ext_params = {
self.ext_params = {
"stream_options": {
"include_usage": True
}
@@ -129,25 +128,24 @@ class LLMChannel(ComponentBase):
streaming = kwargs.pop('stream')
if self.model_support_stream is False and streaming is True:
streaming = False
extra_headers, kwargs = BillingCenter.billing_center_openai_channel_headers(self, kwargs)
support_max_tokens = self.model_support_max_tokens
max_tokens = kwargs.pop('max_tokens', None) or self.channel_model_config.get('max_tokens',
None) or support_max_tokens
if support_max_tokens:
max_tokens = min(support_max_tokens, max_tokens)
ext_params = self.channel_ext_params.copy()
ext_params = self.ext_params.copy()
if not streaming:
ext_params.pop("stream_options")
self.client = self._new_client()
ext_params.pop("stream_options", "")
self.client = self._new_client(kwargs.pop("api_base", None))
chat_completion = self.client.chat.completions.create(
messages=messages,
model=kwargs.pop('model', self.channel_model_name),
temperature=kwargs.pop('temperature', self.channel_model_config.get('temperature')),
stream=kwargs.pop('stream', streaming),
max_tokens=max_tokens,
extra_headers=extra_headers,
extra_body=ext_params,
extra_headers=kwargs.pop("extra_headers", self.ext_headers),
**kwargs,
)
if not streaming:
@@ -172,16 +170,14 @@ class LLMChannel(ComponentBase):
ext_params = self.ext_params.copy()
if not streaming:
ext_params.pop("stream_options", "")
extra_headers, kwargs = BillingCenter.billing_center_openai_channel_headers(self, kwargs)
self.async_client = self._new_async_client()
self.async_client = self._new_async_client(kwargs.pop("api_base", None))
chat_completion = await self.async_client.chat.completions.create(
messages=messages,
model=kwargs.pop('model', self.channel_model_name),
temperature=kwargs.pop('temperature', self.channel_model_config.get('temperature')),
stream=kwargs.pop('stream', streaming),
max_tokens=max_tokens,
extra_headers=extra_headers,
extra_headers=kwargs.pop("extra_headers", self.ext_headers),
extra_body=ext_params,
**kwargs,
)
@@ -214,28 +210,28 @@ class LLMChannel(ComponentBase):
def max_context_length(self) -> int:
return self.channel_model_config.get('max_context_length')
def _new_client(self):
def _new_client(self, api_base: str = None):
"""Initialize the openai client."""
if self.client is not None:
return self.client
return OpenAI(
api_key=self.channel_api_key,
organization=self.channel_organization,
base_url=BillingCenter.get_base_url(self.channel_api_base),
base_url=api_base if api_base else self.channel_api_base,
timeout=self.channel_model_config.get('request_timeout'),
max_retries=self.channel_model_config.get('max_retries'),
http_client=httpx.Client(proxy=self.channel_proxy) if self.channel_proxy else None,
**(self.channel_model_config.get('client_args') or {}),
)
def _new_async_client(self):
def _new_async_client(self, api_base=None):
"""Initialize the openai async client."""
if self.async_client is not None:
return self.async_client
return AsyncOpenAI(
api_key=self.channel_api_key,
organization=self.channel_organization,
base_url=BillingCenter.get_base_url(self.channel_api_base),
base_url=api_base if api_base else self.channel_api_base,
timeout=self.channel_model_config.get('request_timeout'),
max_retries=self.channel_model_config.get('max_retries'),
http_client=httpx.AsyncClient(proxy=self.channel_proxy) if self.channel_proxy else None,
@@ -276,77 +272,3 @@ class LLMChannel(ComponentBase):
"""Return the full name of the component."""
appname = ApplicationConfigManager().app_configer.base_info_appname
return f'{appname}.{self.component_type.value.lower()}.{self.channel_name}'
def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
if usage:
billing_center_params.usage = usage
return
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)
async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
async for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
if usage:
billing_center_params.usage = usage
return
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)
def update_billing_center_params(self, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
if "model" in input:
params.model = input.get("model")
else:
params.model = self.model_name
params.base_url = self.channel_api_base
params.input = input
return params
def get_billing_tokens(self, input: dict, output: LLMOutput):
output_json = output.raw
if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
return output_json['usage']
messages = input.get("messages")
text = ""
for message in messages:
content = message.get("content")
if isinstance(content, str):
text += content
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
content += item.get("content")
prompt_tokens = self.get_num_tokens(text)
output = output.text
completion_tokens = self.get_num_tokens(output)
total_tokens = prompt_tokens + completion_tokens
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}

View File

@@ -15,7 +15,6 @@ from langchain_core.language_models.base import BaseLanguageModel
from openai import OpenAI, AsyncOpenAI
from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
from agentuniverse.base.util.billing_center import BillingCenter, BillingCenterInfo
from agentuniverse.base.util.env_util import get_from_env
from agentuniverse.base.util.system_util import process_yaml_func
from agentuniverse.llm.llm import LLM, LLMOutput
@@ -45,14 +44,14 @@ class OpenAIStyleLLM(LLM):
ext_params: Optional[dict] = {}
ext_headers: Optional[dict] = {}
def _new_client(self):
def _new_client(self, api_base=None):
"""Initialize the openai client."""
if self.client is not None:
return self.client
return OpenAI(
api_key=self.api_key,
organization=self.organization,
base_url=BillingCenter.get_base_url(self.api_base),
base_url=api_base if api_base else self.api_base,
timeout=self.request_timeout,
max_retries=self.max_retries,
http_client=httpx.Client(proxy=self.proxy) if self.proxy else None,
@@ -60,14 +59,14 @@ class OpenAIStyleLLM(LLM):
**(self.client_args or {}),
)
def _new_async_client(self):
def _new_async_client(self, api_base: str = None):
"""Initialize the openai async client."""
if self.async_client is not None:
return self.async_client
return AsyncOpenAI(
api_key=self.api_key,
organization=self.organization,
base_url=BillingCenter.get_base_url(self.api_base),
base_url=api_base if api_base else self.api_base,
timeout=self.request_timeout,
max_retries=self.max_retries,
http_client=httpx.AsyncClient(proxy=self.proxy) if self.proxy else None,
@@ -84,17 +83,21 @@ class OpenAIStyleLLM(LLM):
streaming = kwargs.pop("streaming") if "streaming" in kwargs else self.streaming
if 'stream' in kwargs:
streaming = kwargs.pop('stream')
self.client = self._new_client()
self.client = self._new_client(kwargs.pop("api_base", None))
ext_params = self.ext_params.copy()
if streaming and "stream_options" not in ext_params:
ext_params["stream_options"] = {
"include_usage": True
}
client = self.client
extra_headers, kwargs = BillingCenter.billing_center_openai_headers(self, kwargs)
chat_completion = client.chat.completions.create(
messages=messages,
model=kwargs.pop('model', self.model_name),
temperature=kwargs.pop('temperature', self.temperature),
stream=kwargs.pop('stream', streaming),
max_tokens=kwargs.pop('max_tokens', self.max_tokens),
extra_headers=extra_headers,
extra_body=self.ext_params,
extra_headers=kwargs.pop("extra_headers", self.ext_headers),
extra_body=ext_params,
**kwargs,
)
if not streaming:
@@ -112,17 +115,21 @@ class OpenAIStyleLLM(LLM):
streaming = kwargs.pop("streaming") if "streaming" in kwargs else self.streaming
if 'stream' in kwargs:
streaming = kwargs.pop('stream')
self.async_client = self._new_async_client()
self.async_client = self._new_async_client(kwargs.pop("api_base", None))
async_client = self.async_client
extra_headers, kwargs = BillingCenter.billing_center_openai_headers(self, kwargs)
ext_params = self.ext_params.copy()
if streaming and "stream_options" not in ext_params:
ext_params["stream_options"] = {
"include_usage": True
}
chat_completion = await async_client.chat.completions.create(
messages=messages,
model=kwargs.pop('model', self.model_name),
temperature=kwargs.pop('temperature', self.temperature),
stream=kwargs.pop('stream', streaming),
max_tokens=kwargs.pop('max_tokens', self.max_tokens),
extra_headers=extra_headers,
extra_body=self.ext_params,
extra_headers=kwargs.pop("extra_headers", self.ext_headers),
extra_body=ext_params,
**kwargs,
)
if not streaming:
@@ -196,10 +203,10 @@ class OpenAIStyleLLM(LLM):
if 'proxy' in component_configer.configer.value:
proxy = component_configer.configer.value.get('proxy')
self.proxy = process_yaml_func(proxy, component_configer.yaml_func_instance)
if component_configer.configer.value.get("extra_headers"):
self.ext_headers = component_configer.configer.value.get("extra_headers")
if component_configer.configer.value.get("extra_params"):
self.ext_params = component_configer.configer.value.get("extra_params")
if component_configer.configer.value.get("ext_headers"):
self.ext_headers = component_configer.configer.value.get("ext_headers")
if component_configer.configer.value.get("ext_params"):
self.ext_params = component_configer.configer.value.get("ext_params")
return super().initialize_by_component_configer(component_configer)
@@ -224,71 +231,3 @@ class OpenAIStyleLLM(LLM):
"""Return the maximum length of the context."""
if super().max_context_length():
return super().max_context_length()
def get_billing_tokens(self, input: dict, output: LLMOutput):
output_json = output.raw
if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
return output_json['usage']
messages = input.get("messages")
text = ""
for message in messages:
content = message.get("content")
if isinstance(content, str):
text += content
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
content += item.get("content")
prompt_tokens = self.get_num_tokens_from_model(text, model=input.get("model"))
output = output.text
completion_tokens = self.get_num_tokens_from_model(output, model=input.get("model"))
total_tokens = prompt_tokens + completion_tokens
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens
}
def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
if usage:
billing_center_params.usage = usage
return
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)
async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
billing_center_params: BillingCenterInfo):
content = ""
usage = None
async for llm_output in generator:
content += llm_output.text
if "usage" in llm_output.raw and llm_output.raw.get("usage"):
usage = llm_output.raw.get("usage")
yield llm_output
if usage:
billing_center_params.usage = usage
return
llm_output = LLMOutput(
text=content,
raw={}
)
billing_center_params.output = llm_output.model_dump()
billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
llm_output)
def update_billing_center_params(self, params: BillingCenterInfo, input: dict):
params.base_url = self.api_base
return super().update_billing_center_params(params, input)

View File

@@ -76,5 +76,16 @@ dir = './monitor'
[EXTENSION_MODULES]
class_list = [
'${ROOT_PACKAGE}.config.config_extension.ConfigExtension',
'${ROOT_PACKAGE}.config.yaml_func_extension.YamlFuncExtension'
'${ROOT_PACKAGE}.config.yaml_func_extension.YamlFuncExtension',
'agentuniverse_ant_ext.manager.billing_center.billing_center_manager.BillingCenterManager'
]
[BILLING_CENTER]
use_billing_center = true
billing_center_url = "http://localhost:8888/v1"
billing_type = "proxy"
[PLUGINS]
llm_plugins = [
"agentuniverse_ant_ext.manager.billing_center.billing_center.trace_billing"
]

View File

@@ -35,7 +35,7 @@ profile:
#
# Note: The current configuration uses Way 1, as shown below
llm_model:
name: 'deepseek-chat'
name: 'qwen2.5-72b-instruct'
action:
# Please select the tools and knowledge base.
tool:

View File

@@ -36,7 +36,7 @@ class DemoAgentTest(unittest.TestCase):
output_stream = queue.Queue(10)
instance: Agent = AgentManager().get_instance_obj('demo_agent')
Thread(target=self.read_output, args=(output_stream,)).start()
result = instance.run(input='你来自哪里,名字是什么,请详细介绍一下数据库', output_stream=output_stream)
result = instance.run(input='你来自哪里,名字是什么,请详细介绍一下数据库', output_stream=output_stream, scene_code="billing_center_test")
print(result)