mirror of
https://github.com/agentuniverse-ai/agentUniverse.git
synced 2026-02-09 01:59:19 +08:00
Merge pull request #432 from xmhu2001/feature/jina_reranker
implement jina reranker
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
# !/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
# @Time : 2025/8/10 23:00
|
||||
# @Author : xmhu2001
|
||||
# @Email : xmhu2001@qq.com
|
||||
# @FileName: jina_reranker.py
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
|
||||
from agentuniverse.agent.action.knowledge.doc_processor.doc_processor import DocProcessor
|
||||
from agentuniverse.agent.action.knowledge.store.document import Document
|
||||
from agentuniverse.agent.action.knowledge.store.query import Query
|
||||
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
|
||||
from agentuniverse.base.util.env_util import get_from_env
|
||||
|
||||
api_base = "https://api.jina.ai/v1/rerank"
|
||||
|
||||
class JinaReranker(DocProcessor):
    """Document reranker backed by Jina AI's Rerank API.

    The processor posts the query string and the texts of the candidate
    documents to the rerank endpoint and returns the documents in the order
    given by the API response, storing each returned relevance score in the
    document's metadata.

    Attributes:
        api_key: Bearer token sent in the Authorization header.
        model_name: Name of the rerank model sent in the request payload.
        top_n: Value of the ``top_n`` field sent in the request payload.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled connection cannot block the caller indefinitely.
    """
    api_key: Optional[str] = None
    model_name: str = "jina-reranker-v2-base-multilingual"
    top_n: int = 10
    timeout: float = 30.0

    def _process_docs(self, origin_docs: List[Document], query: Query = None) -> List[Document]:
        """Rerank documents based on their relevance to the query.

        Note that the documents in ``origin_docs`` are updated in place: the
        ``relevance_score`` reported by the API is written into each returned
        document's ``metadata``.

        Args:
            origin_docs: List of documents to be reranked.
            query: Query object containing the search query string.

        Returns:
            List[Document]: The documents referenced by the API response, in
            response order. Result entries without an index or score, or with
            an index outside ``origin_docs``, are skipped.

        Raises:
            Exception: If the query is missing, the API key is not set, or
                the API call fails (the underlying error is chained as
                ``__cause__``).
        """
        if not query or not query.query_str:
            raise Exception("Jina AI reranker needs an origin string query.")
        if not self.api_key:
            raise Exception(
                "Jina AI API key is not set. Please configure it in the component or environment variables.")
        if not origin_docs:
            return []

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        payload = {
            "model": self.model_name,
            "query": query.query_str,
            "documents": [doc.text for doc in origin_docs],
            "top_n": self.top_n,
        }

        try:
            response = requests.post(api_base, headers=headers, json=payload, timeout=self.timeout)
            response.raise_for_status()
            # `or []` also covers an explicit `"results": null` in the body.
            results = response.json().get("results") or []
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures on requests versions
            # whose decode error is not a RequestException subclass.
            raise Exception(f"Jina AI rerank API call error: {e}") from e

        rerank_docs = []
        doc_count = len(origin_docs)
        for result in results:
            index = result.get("index")
            relevance_score = result.get("relevance_score")

            if index is None or relevance_score is None:
                continue
            # Guard against indexes that do not address a submitted document;
            # a negative value would otherwise silently wrap around.
            if not isinstance(index, int) or not 0 <= index < doc_count:
                continue

            doc = origin_docs[index]
            if doc.metadata:
                doc.metadata["relevance_score"] = relevance_score
            else:
                doc.metadata = {"relevance_score": relevance_score}

            rerank_docs.append(doc)

        return rerank_docs

    def _initialize_by_component_configer(self, doc_processor_configer: ComponentConfiger) -> 'DocProcessor':
        """Initialize reranker parameters from component configuration.

        The API key is read from the ``JINA_API_KEY`` environment variable
        and replaced by the configured ``api_key`` only when that value is
        non-empty, so a blank config entry does not discard the environment
        key. ``model_name`` and ``top_n`` are taken from the configuration
        when present; ``timeout`` is taken when present and not ``None``.

        Args:
            doc_processor_configer: Configuration object for the doc processor.

        Returns:
            DocProcessor: The initialized document processor instance.
        """
        super()._initialize_by_component_configer(doc_processor_configer)

        self.api_key = get_from_env("JINA_API_KEY")

        configured_key = getattr(doc_processor_configer, "api_key", None)
        if configured_key:
            self.api_key = configured_key
        if hasattr(doc_processor_configer, "model_name"):
            self.model_name = doc_processor_configer.model_name
        if hasattr(doc_processor_configer, "top_n"):
            self.top_n = doc_processor_configer.top_n
        configured_timeout = getattr(doc_processor_configer, "timeout", None)
        if configured_timeout is not None:
            self.timeout = configured_timeout

        return self
|
||||
@@ -0,0 +1,6 @@
|
||||
name: 'jina_reranker'
|
||||
description: 'Reranker that uses the Jina AI rerank API'
|
||||
metadata:
|
||||
type: 'DOC_PROCESSOR'
|
||||
module: 'agentuniverse.agent.action.knowledge.doc_processor.jina_reranker'
|
||||
class: 'JinaReranker'
|
||||
@@ -0,0 +1,135 @@
|
||||
# !/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
# @Time : 2025/8/10 23:00
|
||||
# @Author : xmhu2001
|
||||
# @Email : xmhu2001@qq.com
|
||||
# @FileName: test_jina_reranker.py
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agentuniverse.agent.action.knowledge.doc_processor.jina_reranker import JinaReranker
|
||||
from agentuniverse.agent.action.knowledge.store.document import Document
|
||||
from agentuniverse.agent.action.knowledge.store.query import Query
|
||||
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
|
||||
from agentuniverse.base.config.configer import Configer
|
||||
|
||||
|
||||
class TestJinaReranker(unittest.TestCase):
    """Unit tests for JinaReranker; every HTTP call is mocked."""

    API_URL = 'https://api.jina.ai/v1/rerank'
    DEFAULT_MODEL = 'jina-reranker-v2-base-multilingual'

    def setUp(self):
        cfg = Configer()
        cfg.value = {
            'name': 'jina_reranker',
            'description': 'reranker use jina api',
            'api_key': 'test_api_key',
            'model_name': 'test_model',
            'top_n': 5
        }
        self.configer = ComponentConfiger()
        self.configer.load_by_configer(cfg)
        self.reranker = JinaReranker()

        self.test_docs = [
            Document(text='Document 1', metadata={'id': 1}),
            Document(text='Document 2', metadata={'id': 2}),
            Document(text='Document 3', metadata={'id': 3}),
            Document(text='Document 4', metadata={'id': 4}),
            Document(text='Document 5', metadata={'id': 5})
        ]

        self.test_query = Query(query_str='test query')

    def _assert_rerank_request(self, mock_post, top_n):
        """Check the single POST issued to the rerank endpoint.

        Only the URL, the headers and the JSON payload are compared, so
        additional transport keyword arguments do not affect the check.
        """
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        self.assertEqual(args, (self.API_URL,))
        self.assertEqual(kwargs['headers'], {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer test_api_key'
        })
        self.assertEqual(kwargs['json'], {
            'model': self.DEFAULT_MODEL,
            'query': 'test query',
            'documents': [doc.text for doc in self.test_docs],
            'top_n': top_n
        })

    def test_initialize_by_component_configer_with_env(self):
        # Patch the name where the reranker module looks it up; patching the
        # defining module would leave the already-imported reference intact.
        with patch('agentuniverse.agent.action.knowledge.doc_processor.jina_reranker.get_from_env') as mock_get_env:
            # A distinct value proves the configured key wins over the env.
            mock_get_env.return_value = 'env_api_key'
            self.reranker = JinaReranker()
            self.reranker._initialize_by_component_configer(self.configer)

            self.assertEqual(self.reranker.api_key, 'test_api_key')
            self.assertEqual(self.reranker.model_name, 'test_model')
            self.assertEqual(self.reranker.top_n, 5)

    @patch('requests.post')
    def test_process_docs(self, mock_post):
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'results': [
                {'index': 2, 'relevance_score': 0.9},
                {'index': 0, 'relevance_score': 0.8},
                {'index': 4, 'relevance_score': 0.7},
                {'index': 1, 'relevance_score': 0.6},
                {'index': 3, 'relevance_score': 0.5}
            ]
        }
        mock_post.return_value = mock_response

        self.reranker.api_key = 'test_api_key'

        result_docs = self.reranker._process_docs(self.test_docs, self.test_query)

        # The reranker always sends top_n; 10 is the class default.
        self._assert_rerank_request(mock_post, top_n=10)

        self.assertEqual(len(result_docs), 5)
        self.assertEqual(result_docs[0].metadata['id'], 3)
        self.assertEqual(result_docs[0].metadata['relevance_score'], 0.9)
        self.assertEqual(result_docs[1].metadata['id'], 1)
        self.assertEqual(result_docs[1].metadata['relevance_score'], 0.8)

    @patch('requests.post')
    def test_process_docs_with_top_n(self, mock_post):
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'results': [
                {'index': 2, 'relevance_score': 0.9},
                {'index': 0, 'relevance_score': 0.8}
            ]
        }
        mock_post.return_value = mock_response

        self.reranker.api_key = 'test_api_key'
        self.reranker.top_n = 2

        result_docs = self.reranker._process_docs(self.test_docs, self.test_query)

        self._assert_rerank_request(mock_post, top_n=2)

        self.assertEqual(len(result_docs), 2)

    def test_process_docs_no_api_key(self):
        with self.assertRaises(Exception) as context:
            self.reranker._process_docs(self.test_docs, self.test_query)

        self.assertIn('Jina AI API key is not set', str(context.exception))

    def test_process_docs_no_query(self):
        self.reranker.api_key = 'test_api_key'
        with self.assertRaises(Exception) as context:
            self.reranker._process_docs(self.test_docs, None)

        self.assertIn('needs an origin string query', str(context.exception))

    def test_process_docs_no_docs(self):
        self.reranker.api_key = 'test_api_key'
        result_docs = self.reranker._process_docs([], self.test_query)
        self.assertEqual(len(result_docs), 0)
|
||||
|
||||
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user