Mirror of https://github.com/agentuniverse-ai/agentUniverse.git
Synced 2026-02-09 01:59:19 +08:00
Fix channel issue
@@ -105,7 +105,7 @@ class Agent(ComponentBase, ABC):
         scene_code = input_object.get_data("scene_code")
         if scene_code:
-            AuTraceManager().trace_context.set_scene_code(scene_code)
+            FrameworkContextManager().set_context("scene_code", scene_code)

     @trace_agent
     def run(self, **kwargs) -> OutputObject:
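The run entry point now mirrors scene_code into FrameworkContextManager so billing code outside the trace context can read it back. A minimal sketch of that round trip, assuming only the two context-manager calls shown in this diff (the literal value is illustrative):

# Sketch: scene_code round trip via FrameworkContextManager, as wired in this diff.
from agentuniverse.base.context.framework_context_manager import FrameworkContextManager

# Agent.run stores the caller-supplied scene_code in the framework context ...
FrameworkContextManager().set_context("scene_code", "billing_center_test")

# ... and BillingCenterInfo.__init__ (billing_center_manager.py below) reads it
# back when assembling the usage record.
scene_code = FrameworkContextManager().get_context("scene_code")
assert scene_code == "billing_center_test"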
@@ -8,9 +8,6 @@
 from typing import Optional
 from pydantic import BaseModel

-from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
-from agentuniverse.base.util.billing_center import BillingCenterInfo
-

 class AgentModel(BaseModel):
     """The parent class of all agent models, containing only attributes."""
@@ -22,10 +19,6 @@ class AgentModel(BaseModel):
     action: Optional[dict] = dict()
     work_pattern: Optional[dict] = dict()

-    def billing_center_params(self):
-        billing_center_info = BillingCenterInfo(agent_id=self.info.get("name"))
-        return billing_center_info
-
     def llm_params(self) -> dict:
         """
         Returns:
@@ -39,6 +32,4 @@ class AgentModel(BaseModel):
                 params['model'] = value
             else:
                 params[key] = value
-        if ApplicationConfigManager().app_configer.use_billing_center:
-            params["billing_center_params"] = self.billing_center_params()
         return params
@@ -14,18 +14,24 @@ import uuid
 from functools import wraps

 from agentuniverse.agent.memory.conversation_memory.conversation_memory_module import ConversationMemoryModule
-from agentuniverse.base.util.billing_center import trace_billing
+from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
 from agentuniverse.base.util.monitor.monitor import Monitor
 from agentuniverse.llm.llm_output import LLMOutput


 def trace_llm(func):
     """Annotation: @trace_llm

     Decorator to trace the LLM invocation, add llm input and output to the monitor.
     """
+    def plugins(func):
+        llm_plugins = ApplicationConfigManager().app_configer.llm_plugins
+        warp_func = func
+        for item in llm_plugins:
+            warp_func = item(func)
+        return warp_func
+
     @wraps(func)
-    @trace_billing
     async def wrapper_async(*args, **kwargs):
         # get llm input from arguments
         llm_input = _get_input(func, *args, **kwargs)
@@ -42,7 +48,7 @@ def trace_llm(func):

         if self and hasattr(self, 'tracing'):
             if self.tracing is False:
-                return await func(*args, **kwargs)
+                return await plugins(func)(*args, **kwargs)

         # add invocation chain to the monitor module.
         Monitor.add_invocation_chain({'source': source, 'type': 'llm'})
@@ -51,7 +57,7 @@ def trace_llm(func):
         Monitor().trace_llm_input(source=source, llm_input=llm_input)

         # invoke function
-        result = await func(*args, **kwargs)
+        result = await plugins(func)(*args, **kwargs)
         # not streaming
         if isinstance(result, LLMOutput):
             # add llm invocation info to monitor
@@ -80,7 +86,6 @@ def trace_llm(func):
         return gen_iterator()

     @functools.wraps(func)
-    @trace_billing
     def wrapper_sync(*args, **kwargs):
         # get llm input from arguments
         llm_input = _get_input(func, *args, **kwargs)
@@ -97,7 +102,7 @@ def trace_llm(func):

         if self and hasattr(self, 'tracing'):
             if self.tracing is False:
-                return func(*args, **kwargs)
+                return plugins(func)(*args, **kwargs)

         # add invocation chain to the monitor module.
         Monitor.add_invocation_chain({'source': source, 'type': 'llm'})
@@ -105,7 +110,7 @@ def trace_llm(func):
         start_time = time.time()
         Monitor().trace_llm_input(source=source, llm_input=llm_input)
         # invoke function
-        result = func(*args, **kwargs)
+        result = plugins(func)(*args, **kwargs)
         # not streaming
         if isinstance(result, LLMOutput):
             # add llm invocation info to monitor
@@ -135,12 +140,12 @@ def trace_llm(func):
         return gen_iterator()

     if asyncio.iscoroutinefunction(func):
         # async function
         return wrapper_async
     else:
         # sync function
         return wrapper_sync


 def get_caller_info(instance: object = None):
     source_list = Monitor.get_invocation_chain()
     if len(source_list) > 0:
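The new plugins helper turns the [PLUGINS].llm_plugins entries into decorators applied around the traced LLM call. Note that as committed, each loop iteration wraps the original func (warp_func = item(func)), so when several plugins are configured only the last one takes effect; accumulating composition would wrap warp_func instead. A standalone sketch of that accumulating form, with two hypothetical plugins:

# Sketch of accumulating decorator composition; log_plugin and upper_plugin are
# hypothetical stand-ins for callables listed under [PLUGINS].llm_plugins.
def log_plugin(func):
    def wrapper(*args, **kwargs):
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

def upper_plugin(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

def compose(func, llm_plugins):
    warp_func = func
    for item in llm_plugins:
        warp_func = item(warp_func)  # wrap the accumulated wrapper, not func
    return warp_func

def call_llm(prompt):
    return f"echo: {prompt}"

print(compose(call_llm, [log_plugin, upper_plugin])("hi"))
# prints "calling call_llm" then "ECHO: HI"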
@@ -1,11 +1,11 @@
 # !/usr/bin/env python3
 # -*- coding:utf-8 -*-

+import importlib
 # @Time    : 2024/3/12 16:17
 # @Author  : jerry.zzw
 # @Email   : jerry.zzw@antgroup.com
 # @FileName: app_configer.py
-from typing import Optional, Dict
+from typing import Optional, Dict, Any

 from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
 from agentuniverse.base.config.component_configer.configers.tool_configer import ToolConfiger
@@ -51,9 +51,7 @@ class AppConfiger(object):
         self.__llm_configer_map: Dict[str, LLMConfiger] = {}
         self.__agent_llm_set: Optional[set[str]] = set()
         self.__agent_tool_set: Optional[set[str]] = set()
-        self.__billing_center: Optional[dict] = {}
-        self.__use_billing_center: Optional[bool] = False
-        self.__billing_center_url: Optional[bool] = None
+        self.__llm_plugins: Optional[Any] = set()

     @property
     def base_info_appname(self) -> Optional[str]:
@@ -229,16 +227,17 @@ class AppConfiger(object):
         self.__agent_tool_set = value

     @property
-    def billing_center(self):
-        return self.__billing_center
+    def llm_plugins(self):
+        return self.__llm_plugins

-    @property
-    def billing_center_url(self):
-        return self.__billing_center_url
-
-    @property
-    def use_billing_center(self):
-        return self.__use_billing_center
+    @classmethod
+    def load_llm_plugins(cls, plugin_modules):
+        funcs = []
+        for item in plugin_modules:
+            module_name, func_name = item.rsplit('.', 1)
+            module = importlib.import_module(module_name)
+            funcs.append(getattr(module, func_name))
+        return funcs

     def load_by_configer(self, configer: Configer) -> 'AppConfiger':
         """Load the AppConfiger by the given Configer.
@@ -275,7 +274,5 @@ class AppConfiger(object):
         self.__core_log_sink_package_list = configer.value.get('CORE_PACKAGE', {}).get('log_sink')
         self.__core_llm_channel_package_list = configer.value.get('CORE_PACKAGE', {}).get('llm_channel')
         self.__conversation_memory_configer = configer.value.get('CONVERSATION_MEMORY', {})
-        self.__billing_center = configer.value.get('BILLING_CENTER', {})
-        self.__use_billing_center = self.billing_center.get("use_billing_center")
-        self.__billing_center_url = self.billing_center.get("billing_center_url")
+        self.__llm_plugins = self.load_llm_plugins(configer.value.get("PLUGINS", {}).get("llm_plugins", []))
         return self
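load_llm_plugins resolves each dotted path in the PLUGINS section to a callable with importlib. The same logic, runnable standalone against a standard-library function to show the rsplit convention:

import importlib

def load_llm_plugins(plugin_modules):
    # Identical logic to AppConfiger.load_llm_plugins above:
    # "pkg.module.attr" -> import pkg.module, then fetch attr.
    funcs = []
    for item in plugin_modules:
        module_name, func_name = item.rsplit('.', 1)
        module = importlib.import_module(module_name)
        funcs.append(getattr(module, func_name))
    return funcs

dumps, = load_llm_plugins(["json.dumps"])
print(dumps({"ok": True}))  # {"ok": true}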
agentuniverse/base/util/billing_center.py (file deleted; path inferred from the import statements elsewhere in this diff)
@@ -1,184 +0,0 @@
import asyncio
import functools
import json
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Any

import httpx

from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.llm.llm_output import LLMOutput
from pydantic import BaseModel

from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.util.monitor.monitor import Monitor
from agentuniverse.base.util.tracing.au_trace_manager import AuTraceManager

client = httpx.Client(timeout=60)
thread_pool = ThreadPoolExecutor(max_workers=5)


class BillingCenterInfo(BaseModel):
    app_id: Optional[str] = None
    agent_id: Optional[str] = None
    session_id: Optional[str] = None
    trace_id: Optional[str] = None
    scene_code: Optional[str] = None
    base_url: Optional[str] = ""
    model: Optional[str] = None
    usage: Optional[dict] = None
    input: Optional[dict] = None
    output: Optional[dict] = None

    def __init__(self, **kwargs):
        params = kwargs
        app_id = ApplicationConfigManager().app_configer.base_info_appname
        params['app_id'] = app_id
        agent_id = kwargs.get("agent_id", None)
        if not agent_id:
            agent_id = self._get_caller_agent()
        params['agent_id'] = agent_id
        if not kwargs.get("trace_id"):
            trace_id = AuTraceManager().get_trace_id()
            params['trace_id'] = trace_id
        if not kwargs.get("session_id"):
            session_id = AuTraceManager().get_session_id()
            params['session_id'] = session_id
        if not kwargs.get("scene_code"):
            scene_code = AuTraceManager().get_scene_code()
            params["scene_code"] = scene_code
        super().__init__(**params)

    @staticmethod
    def _get_caller_agent():
        source_list = Monitor.get_invocation_chain()
        if len(source_list) > 0:
            # iterate in reverse order
            for item in reversed(source_list):
                if item.get("type") == "agent":
                    return item.get("source", None)
        return "unknown"

    def push_billing_center_info(self):
        params = self.model_dump()
        LOGGER.info(json.dumps(params, ensure_ascii=False, indent=4))
        endpoint = ApplicationConfigManager().app_configer.billing_center_url + "/billing/token/usage"
        response = client.post(
            endpoint, json=params, headers={
                "content-type": "application/json"
            }
        )
        if response.status_code != 200:
            LOGGER.error(f"Push billing center info error {params}")
        else:
            LOGGER.info(
                f"Push billing center success {json.dumps(params, ensure_ascii=False, indent=4)}")


def trace_billing(func):
    @functools.wraps(func)
    async def wrapper_async(*args, **kwargs):
        billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
        result = await func(*args, **kwargs)
        if not ApplicationConfigManager().app_configer.use_billing_center:
            return result
        if billing_type == "proxy":
            return result
        billing_center_params = kwargs.pop("billing_center_params", None)
        if not billing_center_params:
            billing_center_params = BillingCenterInfo()
        llm = args[0]
        llm.update_billing_center_params(billing_center_params, kwargs)
        if isinstance(result, LLMOutput):
            billing_center_params.usage = llm.get_billing_tokens(kwargs, result)
            billing_center_params.output = result.model_dump()
            LOGGER.info(f"Billing info {billing_center_params.usage}")
            thread_pool.submit(billing_center_params.push_billing_center_info)
        else:
            async def generator():
                async for item in llm.async_billing_tokens_from_stream(result, billing_center_params):
                    yield item
                LOGGER.info(f"Billing info {billing_center_params.usage}")
                thread_pool.submit(billing_center_params.push_billing_center_info)

            return generator()
        return result

    @functools.wraps(func)
    def wrapper_sync(*args, **kwargs):
        billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
        result = func(*args, **kwargs)
        if not ApplicationConfigManager().app_configer.use_billing_center:
            return result
        if billing_type == "proxy":
            return result
        billing_center_params = kwargs.pop("billing_center_params", None)
        if not billing_center_params:
            billing_center_params = BillingCenterInfo()
        llm = args[0]
        llm.update_billing_center_params(billing_center_params, kwargs)
        if isinstance(result, LLMOutput):
            billing_center_params.usage = llm.get_billing_tokens(kwargs, result)
            billing_center_params.output = result.model_dump()
            LOGGER.info(f"Billing info {billing_center_params.usage}")
            thread_pool.submit(billing_center_params.push_billing_center_info)
            return result
        else:
            def generator():
                for item in llm.billing_tokens_from_stream(result, billing_center_params):
                    yield item
                LOGGER.info(f"Billing info {billing_center_params.usage}")
                thread_pool.submit(billing_center_params.push_billing_center_info)

            return generator()

    if asyncio.iscoroutinefunction(func):
        # async function
        return wrapper_async
    else:
        # sync function
        return wrapper_sync


class BillingCenter(BaseModel):

    @classmethod
    def get_base_url(cls, base_url: str):
        billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
        if ApplicationConfigManager().app_configer.use_billing_center and billing_type == "proxy":
            return ApplicationConfigManager().app_configer.billing_center_url
        return base_url

    @classmethod
    def billing_center_openai_channel_headers(cls, llm, params_map: dict[str, Any]):
        billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
        if not ApplicationConfigManager().app_configer.use_billing_center:
            return llm.channel_ext_headers, params_map
        if billing_type == "push":
            params_map.pop("billing_center_params")
            return llm.channel_ext_headers, params_map
        billing_center_params: BillingCenterInfo = params_map.pop("billing_center_params")
        extra_headers = billing_center_params.model_dump(
            include={'agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id'}
        )
        extra_headers["OriginalUrl"] = llm.channel_api_base
        if llm.channel_proxy:
            extra_headers["proxy"] = llm.channel_proxy if llm.channel_proxy else ""
        return {**llm.channel_ext_headers, **extra_headers}, params_map

    @classmethod
    def billing_center_openai_headers(cls, llm, params_map: dict[str, Any]):
        billing_type = ApplicationConfigManager().app_configer.billing_center.get("billing_type")
        if not ApplicationConfigManager().app_configer.use_billing_center:
            return llm.ext_headers, params_map
        if billing_type == "push":
            params_map.pop("billing_center_params")
            return llm.ext_headers, params_map
        billing_center_params: BillingCenterInfo = params_map.pop("billing_center_params")
        extra_headers = billing_center_params.model_dump(
            include={'agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id'}
        )
        extra_headers["OriginalUrl"] = llm.api_base
        if llm.proxy:
            extra_headers["proxy"] = llm.proxy
        return {**llm.ext_headers, **extra_headers}, params_map
agentuniverse/base/util/manager/__init__.py (new file, 0 lines)
agentuniverse/base/util/manager/billing_center.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import asyncio
import functools
from concurrent.futures.thread import ThreadPoolExecutor
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.base.util.manager.billing_center_manager import BillingCenterManager, BillingCenter, \
    BillingCenterInfo
from agentuniverse.llm.llm_output import LLMOutput

thread_pool = ThreadPoolExecutor(max_workers=5)


def trace_billing(func):
    @functools.wraps(func)
    async def wrapper_async(*args, **kwargs):
        billing_type = BillingCenterManager().billing_center_type
        if not BillingCenterManager().use_billing_center:
            result = func(*args, **kwargs)
            return result
        billing_center_info = BillingCenterInfo()
        if billing_type == "proxy":
            extra_headers = BillingCenter.llm_headers(args[0], billing_center_info)
            api_base = BillingCenterManager().billing_center_url
            kwargs['extra_headers'] = extra_headers
            kwargs['api_base'] = api_base
            return func(*args, **kwargs)
        llm = args[0]
        BillingCenter.update_billing_center_params(llm, billing_center_info, kwargs)
        result = func(*args, **kwargs)
        if isinstance(result, LLMOutput):
            billing_center_info.usage = BillingCenter.get_billing_tokens(kwargs, result)
            billing_center_info.output = result.model_dump()
            LOGGER.info(f"Billing info {billing_center_info.usage}")
            thread_pool.submit(billing_center_info.push_billing_center_info)
        else:
            async def generator():
                async for item in BillingCenter.async_billing_tokens_from_stream(llm, result, billing_center_info):
                    yield item
                LOGGER.info(f"Billing info {billing_center_info.usage}")
                thread_pool.submit(billing_center_info.push_billing_center_info)

            return generator()
        return result

    @functools.wraps(func)
    def wrapper_sync(*args, **kwargs):
        billing_type = BillingCenterManager().billing_center_type
        if not BillingCenterManager().use_billing_center:
            result = func(*args, **kwargs)
            return result
        billing_center_info = BillingCenterInfo()
        if billing_type == "proxy":
            extra_headers = BillingCenter.llm_headers(args[0], billing_center_info)
            api_base = BillingCenterManager().billing_center_url
            kwargs['extra_headers'] = extra_headers
            kwargs['api_base'] = api_base
            return func(*args, **kwargs)
        result = func(*args, **kwargs)
        llm = args[0]
        BillingCenter.update_billing_center_params(llm, billing_center_info, kwargs)
        if isinstance(result, LLMOutput):
            billing_center_info.usage = BillingCenter.get_billing_tokens(llm, kwargs, result)
            billing_center_info.output = result.model_dump()
            LOGGER.info(f"Billing info {billing_center_info.usage}")
            thread_pool.submit(billing_center_info.push_billing_center_info)
            return result
        else:
            def generator():
                for item in BillingCenter.billing_tokens_from_stream(llm, result, billing_center_info):
                    yield item
                LOGGER.info(f"Billing info {billing_center_info.usage}")
                thread_pool.submit(billing_center_info.push_billing_center_info)

            return generator()

    if asyncio.iscoroutinefunction(func):
        # async function
        return wrapper_async
    else:
        # sync function
        return wrapper_sync
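trace_billing follows the same sync/async dual-wrapper pattern as trace_llm. One thing worth flagging: in wrapper_async above, func is called without await in every branch, so a decorated coroutine function returns an un-awaited coroutine from the disabled and proxy paths. A minimal sketch of the selection pattern with the awaits in place (illustrative only, not the committed code):

import asyncio
import functools

def dual_wrapper(func):
    # Sketch: one decorator that serves both sync and async callables,
    # as trace_billing and trace_llm both do.
    @functools.wraps(func)
    async def wrapper_async(*args, **kwargs):
        return await func(*args, **kwargs)  # awaited, unlike the committed body

    @functools.wraps(func)
    def wrapper_sync(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper_async if asyncio.iscoroutinefunction(func) else wrapper_sync

@dual_wrapper
async def acall():
    return "async ok"

@dual_wrapper
def call():
    return "sync ok"

print(call())                # sync ok
print(asyncio.run(acall()))  # async ok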
agentuniverse/base/util/manager/billing_center_manager.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time    : 2025/5/9 09:32
# @Author  : weizjajj
# @Email   : weizhongjie.wzj@antgroup.com
# @FileName: billing_center_manager.py
import json
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Optional

import httpx

from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
from agentuniverse.base.context.framework_context_manager import FrameworkContextManager
from agentuniverse.base.util.logging.logging_util import LOGGER
from agentuniverse.base.util.monitor.monitor import Monitor
from agentuniverse.base.util.tracing.au_trace_manager import AuTraceManager
from agentuniverse.llm.llm_channel.llm_channel import LLMChannel
from agentuniverse.llm.openai_style_llm import OpenAIStyleLLM

from agentuniverse.llm.llm_output import LLMOutput
from pydantic import BaseModel

from agentuniverse.base.annotation.singleton import singleton

client = httpx.Client(timeout=60)


class BillingCenterInfo(BaseModel):
    app_id: Optional[str] = None
    agent_id: Optional[str] = None
    session_id: Optional[str] = None
    trace_id: Optional[str] = None
    scene_code: Optional[str] = None
    base_url: Optional[str] = ""
    model: Optional[str] = None
    usage: Optional[dict] = None
    input: Optional[dict] = None
    output: Optional[dict] = None

    def __init__(self, **kwargs):
        params = kwargs
        app_id = ApplicationConfigManager().app_configer.base_info_appname
        params['app_id'] = app_id
        agent_id = kwargs.get("agent_id", None)
        if not agent_id:
            agent_id = self._get_caller_agent()
        params['agent_id'] = agent_id
        if not kwargs.get("trace_id"):
            trace_id = AuTraceManager().get_trace_id()
            params['trace_id'] = trace_id
        if not kwargs.get("session_id"):
            session_id = AuTraceManager().get_session_id()
            params['session_id'] = session_id
        if not kwargs.get("scene_code"):
            scene_code = FrameworkContextManager().get_context("scene_code")
            params["scene_code"] = scene_code
        super().__init__(**params)

    @staticmethod
    def _get_caller_agent():
        source_list = Monitor.get_invocation_chain()
        if len(source_list) > 0:
            # iterate in reverse order
            for item in reversed(source_list):
                if item.get("type") == "agent":
                    return item.get("source", None)
        return "unknown"

    def push_billing_center_info(self):
        params = self.model_dump()
        LOGGER.info(json.dumps(params, ensure_ascii=False, indent=4))
        endpoint = BillingCenterManager().billing_center_url + "/billing/token/usage"
        response = client.post(
            endpoint, json=params, headers={
                "content-type": "application/json"
            }
        )
        if response.status_code != 200:
            LOGGER.error(f"Push billing center info error {params}")
        else:
            LOGGER.info(
                f"Push billing center success {json.dumps(params, ensure_ascii=False, indent=4)}")


@singleton
class BillingCenterManager:
    def __init__(self, configer=None):
        billing_center_info = configer.value.get("BILLING_CENTER")
        self.use_billing_center = billing_center_info.get("use_billing_center")
        self.billing_center_url = billing_center_info.get("billing_center_url")
        self.billing_center_type = billing_center_info.get("billing_type")


class BillingCenter(BaseModel):

    @classmethod
    def llm_headers(cls, llm, billing_center_info: BillingCenterInfo):
        extra_headers = None
        if hasattr(llm, "ext_headers"):
            extra_headers = llm.ext_headers
        elif hasattr(llm, "headers"):
            extra_headers = llm.headers
        if not extra_headers:
            extra_headers = {
                "content-type": "application/json"
            }
        extra_headers = extra_headers.copy()
        keys = ['agent_id', 'app_id', 'scene_code', 'session_id', 'trace_id']
        billing_center_info_dict = billing_center_info.model_dump()
        for key in keys:
            extra_headers[key] = billing_center_info_dict.get(key, "") if billing_center_info_dict.get(key) else ""
        if isinstance(llm, OpenAIStyleLLM):
            extra_headers["base_url"] = llm.api_base
            if llm.proxy:
                extra_headers["proxy"] = llm.proxy
            return extra_headers
        elif isinstance(llm, LLMChannel):
            extra_headers["base_url"] = llm.channel_api_base
            if llm.channel_proxy:
                extra_headers["proxy"] = llm.channel_proxy
        else:
            extra_headers["base_url"] = llm.model_name
        return extra_headers

    @classmethod
    def get_base_url(cls, base_url: str):
        billing_type = BillingCenterManager().billing_center_type
        if BillingCenterManager().use_billing_center and billing_type == "proxy":
            return BillingCenterManager().billing_center_url
        return base_url

    @classmethod
    def billing_tokens_from_stream(cls, llm, generator: Iterator[LLMOutput],
                                   billing_center_params: BillingCenterInfo):
        content = ""
        usage = None
        for llm_output in generator:
            content += llm_output.text
            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
                usage = llm_output.raw.get("usage")
            yield llm_output
        llm_output = LLMOutput(
            text=content,
            raw={}
        )
        billing_center_params.output = llm_output.model_dump()
        if usage:
            billing_center_params.usage = usage
            return

        billing_center_params.usage = cls.get_billing_tokens(llm, billing_center_params.input,
                                                             llm_output)

    @classmethod
    async def async_billing_tokens_from_stream(cls, llm, generator: AsyncIterator[LLMOutput],
                                               billing_center_params: BillingCenterInfo):
        content = ""
        usage = None
        async for llm_output in generator:
            content += llm_output.text
            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
                usage = llm_output.raw.get("usage")
            yield llm_output
        llm_output = LLMOutput(
            text=content,
            raw={}
        )
        billing_center_params.output = llm_output.model_dump()
        if usage:
            billing_center_params.usage = usage
            return
        billing_center_params.usage = cls.get_billing_tokens(llm, billing_center_params.input,
                                                             llm_output)

    @classmethod
    def update_billing_center_params(cls, llm, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
        if "model" in input:
            params.model = input.get("model")
        else:
            params.model = llm.model_name

        if isinstance(llm, LLMChannel):
            params.base_url = llm.channel_api_base
        elif isinstance(llm, OpenAIStyleLLM):
            params.base_url = llm.api_base
        elif getattr(llm, "service_id"):
            params.base_url = llm.service_id
        elif getattr(llm, "serviceId"):
            params.base_url = llm.serviceId
        elif getattr(llm, "endpoint"):
            params.base_url = llm.endpoint
        else:
            params = llm.model_name
        return params

    @classmethod
    def _get_billing_tokens(cls, llm, input: dict, output: LLMOutput) -> dict:
        text = ""
        if "messages" in input:
            messages = input.get("messages")
            for message in messages:
                content = message.get("content")
                if isinstance(content, str):
                    text += content
                elif isinstance(content, list):
                    for item in content:
                        if item.get("type") == "text":
                            text += item.get("content")
        elif "prompt" in input:
            text = str(input.get("prompt"))
        prompt_tokens = llm.get_num_tokens(text)
        output = output.text
        completion_tokens = llm.get_num_tokens(output)
        total_tokens = prompt_tokens + completion_tokens
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens
        }

    @classmethod
    def get_billing_tokens(cls, llm, input: dict, output: LLMOutput):
        output_json = output.raw
        if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
            return output_json['usage']
        return cls._get_billing_tokens(llm, input, output)
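BillingCenterManager is a @singleton constructed once from the BILLING_CENTER table (the sample config at the end of this diff registers it through EXTENSION_MODULES). A hedged sketch of that initialization, with a stub standing in for agentUniverse's Configer (only the .value attribute is assumed):

# Stub Configer: only .value, a dict of config tables, is assumed here.
class StubConfiger:
    value = {
        "BILLING_CENTER": {
            "use_billing_center": True,
            "billing_center_url": "http://localhost:8888/v1",
            "billing_type": "proxy",
        }
    }

class BillingCenterManagerSketch:
    # Same field extraction as BillingCenterManager.__init__ above.
    def __init__(self, configer=None):
        billing_center_info = configer.value.get("BILLING_CENTER")
        self.use_billing_center = billing_center_info.get("use_billing_center")
        self.billing_center_url = billing_center_info.get("billing_center_url")
        self.billing_center_type = billing_center_info.get("billing_type")

mgr = BillingCenterManagerSketch(StubConfiger())
print(mgr.use_billing_center, mgr.billing_center_type)  # True proxy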
@@ -65,6 +65,7 @@ class AuTraceContext:
         self._span_id_counter += 1
         return child_span_id

+
     def to_dict(self) -> dict:
         return {
             "session_id": self.session_id,
@@ -18,7 +18,6 @@ from agentuniverse.base.component.component_base import ComponentBase
 from agentuniverse.base.component.component_enum import ComponentEnum
 from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
 from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
-from agentuniverse.base.util.billing_center import BillingCenterInfo
 from agentuniverse.base.util.logging.logging_util import LOGGER
 from agentuniverse.llm.llm_channel.llm_channel import LLMChannel
 from agentuniverse.llm.llm_channel.llm_channel_manager import LLMChannelManager
@@ -186,15 +185,10 @@ class LLM(ComponentBase):
         Returns:
             The integer number of tokens in the text.
         """
-
-    def get_num_tokens_from_model(self, text: str, model=None):
-        if model:
-            return self.get_num_tokens(text)
-        encoding = tiktoken.get_encoding("cl100k_base")
         try:
             encoding = tiktoken.encoding_for_model(self.model_name)
         except KeyError:
-            LOGGER.error("get_num_tokens_from_model error")
+            encoding = tiktoken.get_encoding("cl100k_base")
         return len(encoding.encode(text))

     def as_langchain_runnable(self, params=None) -> Runnable:
@@ -236,63 +230,3 @@ class LLM(ComponentBase):
         copied.async_client = self.async_client
         copied.langchain_instance = self.langchain_instance
         return copied
-
-    def get_billing_tokens(self, input: dict, output: LLMOutput) -> dict:
-        text = ""
-        if "messages" in input:
-            messages = input.get("messages")
-            for message in messages:
-                content = message.get("content")
-                if isinstance(content, str):
-                    text += content
-                elif isinstance(content, list):
-                    for item in content:
-                        if item.get("type") == "text":
-                            content += item.get("content")
-        elif "prompt" in input:
-            text = str(input.get("prompt"))
-        prompt_tokens = self.get_num_tokens_from_model(text, input.get("model"))
-        output = output.text
-        completion_tokens = self.get_num_tokens_from_model(output, input.get("model"))
-        total_tokens = prompt_tokens + completion_tokens
-        return {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens
-        }
-
-    def update_billing_center_params(self, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
-        if "model" in input:
-            params.model = input.get("model")
-        else:
-            params.model = self.model_name
-        params.input = input
-        return params
-
-    def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
-                                   billing_center_params: BillingCenterInfo):
-        content = ""
-        for llm_output in generator:
-            content += llm_output.text
-            yield llm_output
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
-
-    async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
-                                               billing_center_params: BillingCenterInfo):
-        content = ""
-        async for llm_output in generator:
-            content += llm_output.text
-            yield llm_output
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
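The token counting that remains in llm.py relies on tiktoken's model lookup with a cl100k_base fallback; the removed billing helpers now live in BillingCenter above. The fallback pattern in isolation (model names are illustrative):

import tiktoken

def count_tokens(text: str, model_name: str) -> int:
    # Prefer the model-specific encoding; unknown models raise KeyError,
    # in which case fall back to the generic cl100k_base encoding.
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_tokens("hello world", "gpt-4"))                 # known model
print(count_tokens("hello world", "qwen2.5-72b-instruct"))  # falls back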
@@ -19,7 +19,6 @@ from agentuniverse.base.component.component_base import ComponentBase
 from agentuniverse.base.component.component_enum import ComponentEnum
 from agentuniverse.base.config.application_configer.application_config_manager import ApplicationConfigManager
 from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
-from agentuniverse.base.util.billing_center import BillingCenter, BillingCenterInfo
 from agentuniverse.llm.llm_channel.langchain_instance.default_channel_langchain_instance import \
     DefaultChannelLangchainInstance
 from agentuniverse.llm.llm_output import LLMOutput
@@ -33,8 +32,8 @@ class LLMChannel(ComponentBase):
     channel_proxy: Optional[str] = None
     channel_model_name: Optional[str] = None
     channel_ext_info: Optional[dict] = None
-    channel_ext_headers: Optional[dict] = {}
-    channel_ext_params: Optional[dict] = {}
+    ext_headers: Optional[dict] = {}
+    ext_params: Optional[dict] = {}

     model_support_stream: Optional[bool] = None
     model_support_max_context_length: Optional[int] = None
@@ -71,15 +70,15 @@ class LLMChannel(ComponentBase):
             self.model_support_max_tokens = component_configer.model_support_max_tokens
         if hasattr(component_configer, "model_is_openai_protocol_compatible"):
             self.model_is_openai_protocol_compatible = component_configer.model_is_openai_protocol_compatible
-        if component_configer.configer.value.get("extra_headers"):
-            self.channel_ext_headers = component_configer.configer.value.get("extra_headers", {})
-        if component_configer.configer.value.get("extra_params"):
-            self.channel_ext_params = component_configer.configer.value.get("extra_params", {})
-            self.channel_ext_params["stream_options"] = {
+        if component_configer.configer.value.get("ext_headers"):
+            self.ext_headers = component_configer.configer.value.get("extra_headers", {})
+        if component_configer.configer.value.get("ext_params"):
+            self.ext_params = component_configer.configer.value.get("extra_params", {})
+            self.ext_params["stream_options"] = {
                 "include_usage": True
             }
         else:
-            self.channel_ext_params = {
+            self.ext_params = {
                 "stream_options": {
                     "include_usage": True
                 }
@@ -129,25 +128,24 @@ class LLMChannel(ComponentBase):
             streaming = kwargs.pop('stream')
         if self.model_support_stream is False and streaming is True:
             streaming = False
-        extra_headers, kwargs = BillingCenter.billing_center_openai_channel_headers(self, kwargs)
         support_max_tokens = self.model_support_max_tokens
         max_tokens = kwargs.pop('max_tokens', None) or self.channel_model_config.get('max_tokens',
                                                                                      None) or support_max_tokens
         if support_max_tokens:
            max_tokens = min(support_max_tokens, max_tokens)

-        ext_params = self.channel_ext_params.copy()
+        ext_params = self.ext_params.copy()
         if not streaming:
-            ext_params.pop("stream_options")
-        self.client = self._new_client()
+            ext_params.pop("stream_options", "")
+        self.client = self._new_client(kwargs.pop("api_base", None))
         chat_completion = self.client.chat.completions.create(
             messages=messages,
             model=kwargs.pop('model', self.channel_model_name),
             temperature=kwargs.pop('temperature', self.channel_model_config.get('temperature')),
             stream=kwargs.pop('stream', streaming),
             max_tokens=max_tokens,
-            extra_headers=extra_headers,
             extra_body=ext_params,
+            extra_headers=kwargs.pop("extra_headers", self.ext_headers),
             **kwargs,
         )
         if not streaming:
@@ -172,16 +170,14 @@ class LLMChannel(ComponentBase):
         ext_params = self.channel_ext_params.copy()
         if not streaming:
             ext_params.pop("stream_options")
-
-        extra_headers, kwargs = BillingCenter.billing_center_openai_channel_headers(self, kwargs)
-        self.async_client = self._new_async_client()
+        self.async_client = self._new_async_client(kwargs.pop("api_base", None))
         chat_completion = await self.async_client.chat.completions.create(
             messages=messages,
             model=kwargs.pop('model', self.channel_model_name),
             temperature=kwargs.pop('temperature', self.channel_model_config.get('temperature')),
             stream=kwargs.pop('stream', streaming),
             max_tokens=max_tokens,
-            extra_headers=extra_headers,
+            extra_headers=kwargs.pop("extra_headers", self.ext_headers),
             extra_body=ext_params,
             **kwargs,
         )
@@ -214,28 +210,28 @@ class LLMChannel(ComponentBase):
     def max_context_length(self) -> int:
         return self.channel_model_config.get('max_context_length')

-    def _new_client(self):
+    def _new_client(self, api_base: str = None):
         """Initialize the openai client."""
         if self.client is not None:
             return self.client
         return OpenAI(
             api_key=self.channel_api_key,
             organization=self.channel_organization,
-            base_url=BillingCenter.get_base_url(self.channel_api_base),
+            base_url=api_base if api_base else self.channel_api_base,
             timeout=self.channel_model_config.get('request_timeout'),
             max_retries=self.channel_model_config.get('max_retries'),
             http_client=httpx.Client(proxy=self.channel_proxy) if self.channel_proxy else None,
             **(self.channel_model_config.get('client_args') or {}),
         )

-    def _new_async_client(self):
+    def _new_async_client(self, api_base=None):
         """Initialize the openai async client."""
         if self.async_client is not None:
             return self.async_client
         return AsyncOpenAI(
             api_key=self.channel_api_key,
             organization=self.channel_organization,
-            base_url=BillingCenter.get_base_url(self.channel_api_base),
+            base_url=api_base if api_base else self.channel_api_base,
             timeout=self.channel_model_config.get('request_timeout'),
             max_retries=self.channel_model_config.get('max_retries'),
             http_client=httpx.AsyncClient(proxy=self.channel_proxy) if self.channel_proxy else None,
@@ -276,77 +272,3 @@ class LLMChannel(ComponentBase):
         """Return the full name of the component."""
         appname = ApplicationConfigManager().app_configer.base_info_appname
         return f'{appname}.{self.component_type.value.lower()}.{self.channel_name}'
-
-    def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
-                                   billing_center_params: BillingCenterInfo):
-        content = ""
-        usage = None
-        for llm_output in generator:
-            content += llm_output.text
-            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
-                usage = llm_output.raw.get("usage")
-            yield llm_output
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        if usage:
-            billing_center_params.usage = usage
-            return
-
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
-
-    async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
-                                               billing_center_params: BillingCenterInfo):
-        content = ""
-        usage = None
-        async for llm_output in generator:
-            content += llm_output.text
-            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
-                usage = llm_output.raw.get("usage")
-            yield llm_output
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        if usage:
-            billing_center_params.usage = usage
-            return
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
-
-    def update_billing_center_params(self, params: BillingCenterInfo, input: dict) -> BillingCenterInfo:
-        if "model" in input:
-            params.model = input.get("model")
-        else:
-            params.model = self.model_name
-        params.base_url = self.channel_api_base
-        params.input = input
-        return params
-
-    def get_billing_tokens(self, input: dict, output: LLMOutput):
-        output_json = output.raw
-        if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
-            return output_json['usage']
-        messages = input.get("messages")
-        text = ""
-        for message in messages:
-            content = message.get("content")
-            if isinstance(content, str):
-                text += content
-            elif isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text":
-                        content += item.get("content")
-        prompt_tokens = self.get_num_tokens(text)
-        output = output.text
-        completion_tokens = self.get_num_tokens(output)
-        total_tokens = prompt_tokens + completion_tokens
-        return {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens
-        }
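With BillingCenter.get_base_url gone from the client constructors, proxy mode now works by injection: trace_billing puts the billing-center URL into kwargs['api_base'] and the real endpoint into the headers, and _new_client prefers that api_base when constructing the OpenAI client. Note that both _new_client variants return a cached client unchanged when one exists, so the override only applies to the first client created. The override in isolation (URLs illustrative):

# Sketch of the api_base preference added to _new_client/_new_async_client.
def resolve_base_url(api_base, channel_api_base):
    # In proxy mode, api_base is the billing-center URL injected by
    # trace_billing; the original endpoint travels in the "base_url" header.
    return api_base if api_base else channel_api_base

assert resolve_base_url(None, "https://api.example.com/v1") == "https://api.example.com/v1"
assert resolve_base_url("http://localhost:8888/v1",
                        "https://api.example.com/v1") == "http://localhost:8888/v1"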
@@ -15,7 +15,6 @@ from langchain_core.language_models.base import BaseLanguageModel
 from openai import OpenAI, AsyncOpenAI

 from agentuniverse.base.config.component_configer.configers.llm_configer import LLMConfiger
-from agentuniverse.base.util.billing_center import BillingCenter, BillingCenterInfo
 from agentuniverse.base.util.env_util import get_from_env
 from agentuniverse.base.util.system_util import process_yaml_func
 from agentuniverse.llm.llm import LLM, LLMOutput
@@ -45,14 +44,14 @@ class OpenAIStyleLLM(LLM):
     ext_params: Optional[dict] = {}
     ext_headers: Optional[dict] = {}

-    def _new_client(self):
+    def _new_client(self, api_base=None):
         """Initialize the openai client."""
         if self.client is not None:
             return self.client
         return OpenAI(
             api_key=self.api_key,
             organization=self.organization,
-            base_url=BillingCenter.get_base_url(self.api_base),
+            base_url=api_base if api_base else self.api_base,
             timeout=self.request_timeout,
             max_retries=self.max_retries,
             http_client=httpx.Client(proxy=self.proxy) if self.proxy else None,
@@ -60,14 +59,14 @@ class OpenAIStyleLLM(LLM):
             **(self.client_args or {}),
         )

-    def _new_async_client(self):
+    def _new_async_client(self, api_base: str = None):
         """Initialize the openai async client."""
         if self.async_client is not None:
             return self.async_client
         return AsyncOpenAI(
             api_key=self.api_key,
             organization=self.organization,
-            base_url=BillingCenter.get_base_url(self.api_base),
+            base_url=api_base if api_base else self.api_base,
             timeout=self.request_timeout,
             max_retries=self.max_retries,
             http_client=httpx.AsyncClient(proxy=self.proxy) if self.proxy else None,
@@ -84,17 +83,21 @@ class OpenAIStyleLLM(LLM):
         streaming = kwargs.pop("streaming") if "streaming" in kwargs else self.streaming
         if 'stream' in kwargs:
             streaming = kwargs.pop('stream')
-        self.client = self._new_client()
+        self.client = self._new_client(kwargs.pop("api_base", None))
+        ext_params = self.ext_params.copy()
+        if streaming and "stream_options" not in ext_params:
+            ext_params["stream_options"] = {
+                "include_usage": True
+            }
         client = self.client
-        extra_headers, kwargs = BillingCenter.billing_center_openai_headers(self, kwargs)
         chat_completion = client.chat.completions.create(
             messages=messages,
             model=kwargs.pop('model', self.model_name),
             temperature=kwargs.pop('temperature', self.temperature),
             stream=kwargs.pop('stream', streaming),
             max_tokens=kwargs.pop('max_tokens', self.max_tokens),
-            extra_headers=extra_headers,
-            extra_body=self.ext_params,
+            extra_headers=kwargs.pop("extra_headers", self.ext_headers),
+            extra_body=ext_params,
             **kwargs,
         )
         if not streaming:
@@ -112,17 +115,21 @@ class OpenAIStyleLLM(LLM):
         streaming = kwargs.pop("streaming") if "streaming" in kwargs else self.streaming
         if 'stream' in kwargs:
             streaming = kwargs.pop('stream')
-        self.async_client = self._new_async_client()
+        self.async_client = self._new_async_client(kwargs.pop("api_base", None))
         async_client = self.async_client
-        extra_headers, kwargs = BillingCenter.billing_center_openai_headers(self, kwargs)
+        ext_params = self.ext_params.copy()
+        if streaming and "stream_options" not in ext_params:
+            ext_params["stream_options"] = {
+                "include_usage": True
+            }
         chat_completion = await async_client.chat.completions.create(
             messages=messages,
             model=kwargs.pop('model', self.model_name),
             temperature=kwargs.pop('temperature', self.temperature),
             stream=kwargs.pop('stream', streaming),
             max_tokens=kwargs.pop('max_tokens', self.max_tokens),
-            extra_headers=extra_headers,
-            extra_body=self.ext_params,
+            extra_headers=kwargs.pop("extra_headers", self.ext_headers),
+            extra_body=ext_params,
             **kwargs,
         )
         if not streaming:
@@ -196,10 +203,10 @@ class OpenAIStyleLLM(LLM):
         if 'proxy' in component_configer.configer.value:
             proxy = component_configer.configer.value.get('proxy')
             self.proxy = process_yaml_func(proxy, component_configer.yaml_func_instance)
-        if component_configer.configer.value.get("extra_headers"):
-            self.ext_headers = component_configer.configer.value.get("extra_headers")
-        if component_configer.configer.value.get("extra_params"):
-            self.ext_params = component_configer.configer.value.get("extra_params")
+        if component_configer.configer.value.get("ext_headers"):
+            self.ext_headers = component_configer.configer.value.get("ext_headers")
+        if component_configer.configer.value.get("ext_params"):
+            self.ext_params = component_configer.configer.value.get("ext_params")

         return super().initialize_by_component_configer(component_configer)
@@ -224,71 +231,3 @@ class OpenAIStyleLLM(LLM):
         """Return the maximum length of the context."""
         if super().max_context_length():
             return super().max_context_length()
-
-    def get_billing_tokens(self, input: dict, output: LLMOutput):
-        output_json = output.raw
-        if "usage" in output_json and output_json['usage'].get("prompt_tokens", 0) > 0:
-            return output_json['usage']
-        messages = input.get("messages")
-        text = ""
-        for message in messages:
-            content = message.get("content")
-            if isinstance(content, str):
-                text += content
-            elif isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "text":
-                        content += item.get("content")
-        prompt_tokens = self.get_num_tokens_from_model(text, model=input.get("model"))
-        output = output.text
-        completion_tokens = self.get_num_tokens_from_model(output, model=input.get("model"))
-        total_tokens = prompt_tokens + completion_tokens
-        return {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens
-        }
-
-    def billing_tokens_from_stream(self, generator: Iterator[LLMOutput],
-                                   billing_center_params: BillingCenterInfo):
-        content = ""
-        usage = None
-        for llm_output in generator:
-            content += llm_output.text
-            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
-                usage = llm_output.raw.get("usage")
-            yield llm_output
-        if usage:
-            billing_center_params.usage = usage
-            return
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
-
-    async def async_billing_tokens_from_stream(self, generator: AsyncIterator[LLMOutput],
-                                               billing_center_params: BillingCenterInfo):
-        content = ""
-        usage = None
-        async for llm_output in generator:
-            content += llm_output.text
-            if "usage" in llm_output.raw and llm_output.raw.get("usage"):
-                usage = llm_output.raw.get("usage")
-            yield llm_output
-        if usage:
-            billing_center_params.usage = usage
-            return
-        llm_output = LLMOutput(
-            text=content,
-            raw={}
-        )
-        billing_center_params.output = llm_output.model_dump()
-        billing_center_params.usage = self.get_billing_tokens(billing_center_params.input,
-                                                              llm_output)
-
-    def update_billing_center_params(self, params: BillingCenterInfo, input: dict):
-        params.base_url = self.api_base
-        return super().update_billing_center_params(params, input)
@@ -76,5 +76,16 @@ dir = './monitor'
 [EXTENSION_MODULES]
 class_list = [
     '${ROOT_PACKAGE}.config.config_extension.ConfigExtension',
-    '${ROOT_PACKAGE}.config.yaml_func_extension.YamlFuncExtension'
+    '${ROOT_PACKAGE}.config.yaml_func_extension.YamlFuncExtension',
+    'agentuniverse_ant_ext.manager.billing_center.billing_center_manager.BillingCenterManager'
 ]
+
+[BILLING_CENTER]
+use_billing_center = true
+billing_center_url = "http://localhost:8888/v1"
+billing_type = "proxy"
+
+[PLUGINS]
+llm_plugins = [
+    "agentuniverse_ant_ext.manager.billing_center.billing_center.trace_billing"
+]
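The sample config wires everything together: EXTENSION_MODULES instantiates BillingCenterManager from the BILLING_CENTER table at startup, and trace_llm composes the callables named in [PLUGINS].llm_plugins around every LLM call. A hedged sketch of parsing these sections (tomllib is used here for illustration; agentUniverse reads config through its own Configer):

import tomllib  # Python 3.11+; stands in for agentUniverse's Configer

config = tomllib.loads('''
[BILLING_CENTER]
use_billing_center = true
billing_center_url = "http://localhost:8888/v1"
billing_type = "proxy"

[PLUGINS]
llm_plugins = [
    "agentuniverse_ant_ext.manager.billing_center.billing_center.trace_billing"
]
''')

print(config["BILLING_CENTER"]["billing_type"])  # proxy
# Each plugin entry splits into a module path and a function name,
# exactly as AppConfiger.load_llm_plugins does:
print(config["PLUGINS"]["llm_plugins"][0].rsplit('.', 1))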
@@ -35,7 +35,7 @@ profile:
 #
 # Note: The current configuration uses Way 1, as shown below
 llm_model:
-  name: 'deepseek-chat'
+  name: 'qwen2.5-72b-instruct'
 action:
   # Please select the tools and knowledge base.
   tool:
@@ -36,7 +36,7 @@ class DemoAgentTest(unittest.TestCase):
         output_stream = queue.Queue(10)
         instance: Agent = AgentManager().get_instance_obj('demo_agent')
         Thread(target=self.read_output, args=(output_stream,)).start()
-        result = instance.run(input='你来自哪里,名字是什么,请详细介绍一下数据库', output_stream=output_stream)
+        result = instance.run(input='你来自哪里,名字是什么,请详细介绍一下数据库', output_stream=output_stream, scene_code="billing_center_test")
         print(result)