add youtube tool

This commit is contained in:
xmhu2001
2025-07-13 00:02:08 +08:00
parent 45ded8f467
commit 5292c383d4
3 changed files with 268 additions and 0 deletions

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2025/7/12 23:00
# @Author : xmhu2001
# @Email : xmhu2001@qq.com
# @FileName: youtube_tool.py
from typing import Optional, List, Dict, Any
from enum import Enum
from pydantic import Field
from agentuniverse.agent.action.tool.tool import Tool
from agentuniverse.base.annotation.retry import retry
from agentuniverse.base.util.env_util import get_from_env
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
import re
service_name = "youtube"
api_version = "v3"
class Mode(Enum):
VIDEO_SEARCH = "search"
TRENDING_VIDEOS = "trending"
CHANNEL_INFO = "channel_info"
class YouTubeTool(Tool):
service: Optional[Any] = None
api_key: Optional[str] = Field(default_factory=lambda: get_from_env("YOUTUBE_API_KEY"))
max_results: int = Field(10, description="Maximum video results to return")
def _initialize_service(self):
if not self.api_key:
raise ValueError("YouTube API key not provided, please set the YOUTUBE_API_KEY environment variable.")
if self.service is None:
self.service = build(service_name, api_version, developerKey=self.api_key)
return self.service
def parse_duration(self, duration_str):
"""Converts ISO 8601 duration format to seconds"""
match = re.match(r'PT(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?', duration_str)
if not match:
return 0
hours = int(match.group(1)) if match.group(1) else 0
minutes = int(match.group(2)) if match.group(2) else 0
seconds = int(match.group(3)) if match.group(3) else 0
return hours * 3600 + minutes * 60 + seconds
@retry(3, 1.0)
def _search_videos(self, query: str) -> List[Dict]:
try:
search_response = self.service.search().list(
q=query,
part='id',
type='video',
maxResults=self.max_results
).execute()
video_ids = [item['id']['videoId'] for item in search_response.get('items', [])]
if not video_ids:
return []
video_response = self.service.videos().list(
id=','.join(video_ids),
part='snippet,statistics,contentDetails'
).execute()
results = []
for item in video_response.get('items', []):
results.append({
'id': item['id'],
'title': item['snippet']['title'],
'view_count': int(item['statistics'].get('viewCount', 0)),
'like_count': int(item['statistics'].get('likeCount', 0)),
'comment_count': int(item['statistics'].get('commentCount', 0)),
'duration_seconds': self.parse_duration(item['contentDetails']['duration']),
'url': f"https://www.youtube.com/watch?v={item['id']}"
})
return results
except HttpError as e:
error_msg = f"Error searching for videos: {str(e)}"
if "quotaExceeded" in str(e):
error_msg += " (API quota may be exhausted)"
return [{"error": error_msg}]
@retry(3, 1.0)
def _analyze_channel(self, channel_id: str) -> Dict:
try:
response = self.service.channels().list(
id=channel_id,
part='snippet,statistics,contentDetails'
).execute()
if not response.get('items'):
return {"error": "Channel not found"}
channel_info = response['items'][0]
playlist_id = channel_info['contentDetails']['relatedPlaylists']['uploads']
video_list = []
next_page_token = None
for _ in range(self.max_results):
playlist_items = self.service.playlistItems().list(
playlistId=playlist_id,
part='snippet,contentDetails',
maxResults=self.max_results,
pageToken=next_page_token
).execute()
for item in playlist_items.get('items', []):
video_list.append({
'id': item['contentDetails']['videoId'],
'title': item['snippet']['title'],
'published_at': item['snippet']['publishedAt']
})
next_page_token = playlist_items.get('nextPageToken')
if not next_page_token:
break
return {
'name': channel_info['snippet']['title'],
'description': channel_info['snippet'].get('description', ''),
'subscriber_count': int(channel_info['statistics'].get('subscriberCount', 0)),
'total_view_count': int(channel_info['statistics'].get('viewCount', 0)),
'total_video_count': int(channel_info['statistics'].get('videoCount', 0)),
'latest_video_list': video_list
}
except HttpError as e:
error_msg = f"Error analyzing channel: {str(e)}"
if "quotaExceeded" in str(e):
error_msg += " (API quota may be exhausted)"
return {"error": error_msg}
@retry(3, 1.0)
def _get_trending_videos(self, region_code: Optional[str] = None) -> List[Dict]:
try:
request_param = {
'part': 'snippet,statistics,contentDetails',
'chart': 'mostPopular',
'maxResults': self.max_results
}
if region_code:
request_param['regionCode'] = region_code
response = self.service.videos().list(**request_param).execute()
results = []
for item in response.get('items', []):
view_count = int(item['statistics'].get('viewCount', 0))
results.append({
'id': item['id'],
'title': item['snippet']['title'],
'channel_title': item['snippet']['channelTitle'],
'published_at': item['snippet']['publishedAt'],
'view_count': view_count,
'like_count': int(item['statistics'].get('likeCount', 0)),
'comment_count': int(item['statistics'].get('commentCount', 0)),
'duration_seconds': self.parse_duration(item['contentDetails']['duration']),
'url': f"https://www.youtube.com/watch?v={item['id']}"
})
return results
except HttpError as e:
error_msg = f"Error getting trending videos: {str(e)}"
if "quotaExceeded" in str(e):
error_msg += " (API quota may be exhausted)"
return [{"error": error_msg}]
def execute(self,
mode: str,
input: Optional[str] = None
) -> List[Dict] | Dict:
if self.service is None:
self._initialize_service()
if mode == Mode.VIDEO_SEARCH.value:
if input is None:
raise ValueError("Query string is required for video search mode.")
return self._search_videos(input)
elif mode == Mode.CHANNEL_INFO.value:
if input is None:
raise ValueError("Channel ID is required for channel info mode.")
return self._analyze_channel(input)
elif mode == Mode.TRENDING_VIDEOS.value:
return self._get_trending_videos(input)
else:
raise ValueError(f"Invalid mode: {mode}. Must be one of {[m.value for m in Mode]}")

View File

@@ -0,0 +1,25 @@
name: 'youtube_tool'
description: |
该工具用于在 youtube 上检索视频和分析频道信息。工具支持三种操作模式:
1. 搜索相关视频:通过关键词搜索相关视频
2. 获取热榜视频:获取 youtube 热榜视频,可指定地区
3. 获取频道信息:通过频道 ID 获取频道名、订阅数、总播放量和视频数等信息
工具输入示例:
模式1 - 搜索视频:
input: "machine learning"
mode: "search"
模式3 - 获取热榜视频:
input: "US"
mode: "trending"
模式3 - 获取频道信息:
input: "UC_x5XG1OV2P6uZZ5FSM9Ttw"
mode: "channel_info"
tool_type: 'api'
input_keys: ['mode', 'input']
metadata:
type: 'TOOL'
module: 'agentuniverse.agent.action.tool.common_tool.youtube_tool'
class: 'YouTubeTool'

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2025/7/12 23:00
# @Author : xmhu2001
# @Email : xmhu2001@qq.com
# @FileName: test_youtube_tool.py
import unittest
import os
from agentuniverse.agent.action.tool.common_tool.youtube_tool import YouTubeTool, Mode
from agentuniverse.agent.action.tool.tool import ToolInput
class YouTubeToolTest(unittest.TestCase):
"""
Test cases for YouTubeTool class
"""
def setUp(self) -> None:
self.tool = YouTubeTool()
def test_search_videos(self) -> None:
tool_input = ToolInput({
'mode': Mode.VIDEO_SEARCH.value,
'input': 'machine learning'
})
result = self.tool.execute(tool_input.mode, tool_input.input)
self.assertTrue(result != [])
def test_analyze_channel(self) -> None:
tool_input = ToolInput({
'mode': Mode.CHANNEL_ANALYZE.value,
'input': 'UC_x5XG1OV2P6uZZ5FSM9Ttw'
})
result = self.tool.execute(tool_input.mode, tool_input.input)
self.assertTrue(result != {})
def test_get_trending_videos_with_region(self) -> None:
tool_input = ToolInput({
'mode': Mode.TRENDING_VIDEOS.value,
'input': 'US'
})
result = self.tool.execute(tool_input.mode, tool_input.input)
self.assertTrue(result != [])
def test_get_trending_videos(self) -> None:
tool_input = ToolInput({
'mode': Mode.TRENDING_VIDEOS.value
})
result = self.tool.execute(mode=tool_input.mode)
self.assertTrue(result != [])
if __name__ == '__main__':
unittest.main()