From 04fa5aa6ea5a26355f5bf85cf047f706ed6dd0c5 Mon Sep 17 00:00:00 2001
From: Veteran-ChengQin <1040242795@qq.com>
Date: Mon, 6 Oct 2025 14:38:04 +0000
Subject: [PATCH] feat: add Ollama embedding

---
 .../knowledge/embedding/ollama_embedding.py   | 154 ++++++++++++++++++
 .../knowledge/embedding/ollama_embedding.yaml |   8 +
 .../embedding/test_ollama_embedding.py        | 130 +++++++++++++++
 3 files changed, 292 insertions(+)
 create mode 100644 agentuniverse/agent/action/knowledge/embedding/ollama_embedding.py
 create mode 100644 agentuniverse/agent/action/knowledge/embedding/ollama_embedding.yaml
 create mode 100644 tests/test_agentuniverse/unit/agent/action/knowledge/embedding/test_ollama_embedding.py

diff --git a/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.py b/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.py
new file mode 100644
index 00000000..14bb5269
--- /dev/null
+++ b/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#
+# @Time    : 2025-10-05 21:40 PM
+# @Author  : Cascade AI
+# @Email   : cascade@windsurf.ai
+# @FileName: ollama_embedding.py
+
+import asyncio
+from typing import Any, List, Optional
+
+import httpx
+from pydantic import Field
+
+from agentuniverse.base.util.env_util import get_from_env
+from agentuniverse.agent.action.knowledge.embedding.embedding import Embedding
+from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
+
+
+class OllamaEmbedding(Embedding):
+    """The Ollama embedding class.
+
+    Calls the Ollama ``/api/embeddings`` HTTP endpoint (one request per
+    input text) and exposes both sync and async embedding APIs.
+    """
+
+    # Base URL of the Ollama server; falls back to the local default.
+    ollama_base_url: Optional[str] = Field(
+        default_factory=lambda: get_from_env("OLLAMA_BASE_URL") or "http://localhost:11434")
+
+    # Optional bearer token for deployments behind an auth proxy.
+    ollama_api_key: Optional[str] = Field(
+        default_factory=lambda: get_from_env("OLLAMA_API_KEY"))
+
+    embedding_model_name: Optional[str] = None
+    embedding_dims: Optional[int] = None
+    timeout: Optional[int] = Field(default=30)
+
+    # Lazily created httpx clients, reused across calls.
+    client: Any = None
+    async_client: Any = None
+
+    def get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
+        """Retrieve embeddings for the input texts using the Ollama API.
+
+        Args:
+            texts (List[str]): A list of input texts to be embedded.
+        Returns:
+            List[List[float]]: One embedding vector per input text.
+        Raises:
+            Exception: If the API call fails or required configuration is missing.
+        """
+        self._initialize_clients()
+        try:
+            embeddings = []
+            for text in texts:
+                response = self.client.post(
+                    f"{self.ollama_base_url}/api/embeddings",
+                    json={"model": self.embedding_model_name, "prompt": text},
+                    timeout=self.timeout)
+                response.raise_for_status()
+                embeddings.append(response.json()["embedding"])
+            return embeddings
+        except Exception as e:
+            # Chain the original error so the root cause is preserved.
+            raise Exception(f"Failed to get embeddings: {e}") from e
+
+    async def async_get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
+        """Retrieve embeddings for the input texts asynchronously.
+
+        All requests are issued concurrently via ``asyncio.gather``.
+
+        Args:
+            texts (List[str]): A list of input texts to be embedded.
+        Returns:
+            List[List[float]]: One embedding vector per input text.
+        Raises:
+            Exception: If the API call fails or required configuration is missing.
+        """
+        self._initialize_clients()
+        try:
+            async def get_single_embedding(text: str) -> List[float]:
+                response = await self.async_client.post(
+                    f"{self.ollama_base_url}/api/embeddings",
+                    json={"model": self.embedding_model_name, "prompt": text},
+                    timeout=self.timeout)
+                response.raise_for_status()
+                return response.json()["embedding"]
+
+            return list(await asyncio.gather(
+                *(get_single_embedding(text) for text in texts)))
+        except Exception as e:
+            raise Exception(f"Failed to get embeddings: {e}") from e
+
+    def as_langchain(self) -> Any:
+        """Convert this instance to a LangChain ``OllamaEmbeddings`` instance.
+
+        Raises:
+            Exception: If required configuration is missing or
+                ``langchain_community`` is not installed.
+        """
+        # Only validate configuration here: the LangChain adapter manages its
+        # own HTTP transport, so no httpx clients need to be created.
+        if not self.ollama_base_url:
+            raise Exception("OLLAMA_BASE_URL is missing")
+        if not self.embedding_model_name:
+            raise Exception("embedding_model_name is missing")
+        try:
+            from langchain_community.embeddings import OllamaEmbeddings
+        except ImportError as e:
+            raise Exception(
+                "langchain_community is required for LangChain integration") from e
+        return OllamaEmbeddings(
+            base_url=self.ollama_base_url,
+            model=self.embedding_model_name)
+
+    def _initialize_by_component_configer(self, embedding_configer: ComponentConfiger) -> 'Embedding':
+        """Initialize the embedding by the ComponentConfiger object.
+
+        Args:
+            embedding_configer(ComponentConfiger): A configer containing embedding configuration.
+        Returns:
+            Embedding: A OllamaEmbedding instance.
+        """
+        super()._initialize_by_component_configer(embedding_configer)
+        if hasattr(embedding_configer, "ollama_base_url"):
+            self.ollama_base_url = embedding_configer.ollama_base_url
+        if hasattr(embedding_configer, "ollama_api_key"):
+            self.ollama_api_key = embedding_configer.ollama_api_key
+        if hasattr(embedding_configer, "timeout"):
+            self.timeout = embedding_configer.timeout
+        return self
+
+    def _initialize_clients(self) -> None:
+        """Validate configuration and lazily create the httpx clients."""
+        if not self.ollama_base_url:
+            raise Exception("OLLAMA_BASE_URL is missing")
+        if not self.embedding_model_name:
+            raise Exception("embedding_model_name is missing")
+
+        headers = {}
+        if self.ollama_api_key:
+            headers["Authorization"] = f"Bearer {self.ollama_api_key}"
+
+        if self.client is None:
+            self.client = httpx.Client(
+                base_url=self.ollama_base_url,
+                headers=headers,
+                timeout=self.timeout)
+        if self.async_client is None:
+            self.async_client = httpx.AsyncClient(
+                base_url=self.ollama_base_url,
+                headers=headers,
+                timeout=self.timeout)
diff --git a/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.yaml b/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.yaml
new file mode 100644
index 00000000..26c20277
--- /dev/null
+++ b/agentuniverse/agent/action/knowledge/embedding/ollama_embedding.yaml
@@ -0,0 +1,8 @@
+name: 'ollama_embedding'
+description: 'embedding use ollama api with
+  support for mxbai-embed-large, nomic-embed-text, all-minilm models'
+embedding_model_name: 'mxbai-embed-large'
+
+metadata:
+  type: 'EMBEDDING'
+  module: 'agentuniverse.agent.action.knowledge.embedding.ollama_embedding'
+  class: 'OllamaEmbedding'
diff --git a/tests/test_agentuniverse/unit/agent/action/knowledge/embedding/test_ollama_embedding.py b/tests/test_agentuniverse/unit/agent/action/knowledge/embedding/test_ollama_embedding.py
new file mode 100644
index 00000000..49172a4b
--- /dev/null
+++ b/tests/test_agentuniverse/unit/agent/action/knowledge/embedding/test_ollama_embedding.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#
+# @Time    : 2025-10-05 21:40 PM
+# @Author  : Cascade AI
+# @Email   : cascade@windsurf.ai
+# @FileName: test_ollama_embedding.py
+
+import asyncio
+import unittest
+from unittest.mock import Mock, patch, AsyncMock
+
+import httpx
+
+from agentuniverse.agent.action.knowledge.embedding.ollama_embedding import OllamaEmbedding
+
+
+class OllamaEmbeddingTest(unittest.TestCase):
+    """Test cases for the OllamaEmbedding class."""
+
+    def setUp(self) -> None:
+        self.embedding = OllamaEmbedding()
+        self.embedding.ollama_base_url = "http://localhost:11434"
+        self.embedding.embedding_model_name = "mxbai-embed-large"
+        self.embedding.timeout = 30
+
+    def test_get_embeddings_real_api(self) -> None:
+        """Test get_embeddings with a real Ollama API.
+
+        Requires Ollama to be running locally with the model available;
+        skipped otherwise.
+        """
+        try:
+            # Exercise every model the component is documented to support.
+            models_to_test = ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]
+
+            for model in models_to_test:
+                with self.subTest(model=model):
+                    self.embedding.embedding_model_name = model
+                    res = self.embedding.get_embeddings(texts=["hello world"])
+                    print(f"Model {model} - Embedding shape: {len(res[0]) if res and res[0] else 'None'}")
+
+                    self.assertIsInstance(res, list)
+                    self.assertEqual(len(res), 1)
+                    self.assertIsInstance(res[0], list)
+                    self.assertGreater(len(res[0]), 0)
+
+        except Exception as e:
+            self.skipTest(f"Ollama API not available or model not found: {e}")
+
+    def test_async_get_embeddings_real_api(self) -> None:
+        """Test async_get_embeddings with a real Ollama API.
+
+        Requires Ollama to be running locally with the model available;
+        skipped otherwise.
+        """
+        try:
+            res = asyncio.run(
+                self.embedding.async_get_embeddings(texts=["hello world"]))
+            print(f"Async embedding result shape: {len(res[0]) if res and res[0] else 'None'}")
+
+            self.assertIsInstance(res, list)
+            self.assertEqual(len(res), 1)
+            self.assertIsInstance(res[0], list)
+            self.assertGreater(len(res[0]), 0)
+
+        except Exception as e:
+            self.skipTest(f"Ollama API not available: {e}")
+
+    def test_as_langchain(self) -> None:
+        """Test LangChain integration.
+
+        Requires langchain_community to be installed; skipped otherwise.
+        """
+        try:
+            langchain_embedding = self.embedding.as_langchain()
+            self.assertIsNotNone(langchain_embedding)
+            print(f"LangChain embedding type: {type(langchain_embedding)}")
+
+            # Only exercise the LangChain interface when Ollama is reachable.
+            try:
+                res = langchain_embedding.embed_documents(texts=["hello world"])
+                print(f"LangChain embedding result shape: {len(res[0]) if res and res[0] else 'None'}")
+                self.assertIsInstance(res, list)
+            except Exception as e:
+                print(f"LangChain embedding test skipped: {e}")
+
+        except ImportError:
+            self.skipTest("langchain_community not available")
+        except Exception as e:
+            self.skipTest(f"LangChain integration test failed: {e}")
+
+    def test_initialization_errors(self) -> None:
+        """Test initialization error handling."""
+        # Missing base URL must raise before any HTTP call is attempted.
+        embedding = OllamaEmbedding()
+        embedding.ollama_base_url = None
+        embedding.embedding_model_name = "test-model"
+
+        with self.assertRaises(Exception) as context:
+            embedding.get_embeddings(["test"])
+        self.assertIn("OLLAMA_BASE_URL is missing", str(context.exception))
+
+        # Missing model name must raise as well.
+        embedding = OllamaEmbedding()
+        embedding.ollama_base_url = "http://localhost:11434"
+        embedding.embedding_model_name = None
+
+        with self.assertRaises(Exception) as context:
+            embedding.get_embeddings(["test"])
+        self.assertIn("embedding_model_name is missing", str(context.exception))
+
+    def test_multiple_models_configuration(self) -> None:
+        """Test configuration with different supported models."""
+        supported_models = ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]
+
+        for model in supported_models:
+            with self.subTest(model=model):
+                embedding = OllamaEmbedding()
+                embedding.ollama_base_url = "http://localhost:11434"
+                embedding.embedding_model_name = model
+                embedding.timeout = 30
+
+                # Configuration should round-trip through plain attributes.
+                self.assertEqual(embedding.embedding_model_name, model)
+                self.assertEqual(embedding.ollama_base_url, "http://localhost:11434")
+
+
+if __name__ == '__main__':
+    unittest.main()