feat: add Ollama embedding

This commit is contained in:
Veteran-ChengQin
2025-10-06 14:38:04 +00:00
parent 41c5a06600
commit 04fa5aa6ea
3 changed files with 291 additions and 0 deletions

View File

@@ -0,0 +1,153 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
#
# @Time : 2025-10-05 21:40 PM
# @Author : Cascade AI
# @Email : cascade@windsurf.ai
# @FileName: ollama_embedding.py
from typing import Any, Optional, List
from pydantic import Field
import httpx
import asyncio
from agentuniverse.base.util.env_util import get_from_env
from agentuniverse.agent.action.knowledge.embedding.embedding import Embedding
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
class OllamaEmbedding(Embedding):
    """The Ollama embedding class.

    Wraps the Ollama ``/api/embeddings`` HTTP endpoint, offering a
    synchronous interface, a concurrent asynchronous interface, and a
    LangChain adapter. HTTP clients are created lazily and cached.
    """

    # Base URL of the Ollama server; env OLLAMA_BASE_URL or the local default.
    ollama_base_url: Optional[str] = Field(
        default_factory=lambda: get_from_env("OLLAMA_BASE_URL") or "http://localhost:11434")
    # Optional bearer token (env OLLAMA_API_KEY); a stock local Ollama needs none.
    ollama_api_key: Optional[str] = Field(
        default_factory=lambda: get_from_env("OLLAMA_API_KEY"))
    # Name of the Ollama embedding model, e.g. 'mxbai-embed-large'.
    embedding_model_name: Optional[str] = None
    # Declared embedding dimensionality; informational only, not enforced here.
    embedding_dims: Optional[int] = None
    # Per-request timeout in seconds.
    timeout: Optional[int] = Field(default=30)
    # Lazily created httpx clients, cached after the first call.
    client: Any = None
    async_client: Any = None

    def get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Retrieve text embeddings for a list of input texts using Ollama API.

        Args:
            texts (List[str]): A list of input texts to be embedded.

        Returns:
            List[List[float]]: A list of embeddings corresponding to the input texts.

        Raises:
            Exception: If the API call fails or if required configuration is missing.
        """
        self._initialize_clients()
        try:
            embeddings = []
            # The /api/embeddings endpoint accepts one prompt per request,
            # so each text is embedded with its own call.
            for text in texts:
                response = self.client.post(
                    f"{self.ollama_base_url}/api/embeddings",
                    json={
                        "model": self.embedding_model_name,
                        "prompt": text
                    },
                    timeout=self.timeout
                )
                response.raise_for_status()
                result = response.json()
                embeddings.append(result["embedding"])
            return embeddings
        except Exception as e:
            # Chain the original exception so callers keep the root cause.
            raise Exception(f"Failed to get embeddings: {e}") from e

    async def async_get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Retrieve text embeddings for a list of input texts using Ollama API asynchronously.

        All texts are embedded concurrently; the result order matches the
        input order because ``asyncio.gather`` preserves task order.

        Args:
            texts (List[str]): A list of input texts to be embedded.

        Returns:
            List[List[float]]: A list of embeddings corresponding to the input texts.

        Raises:
            Exception: If the API call fails or if required configuration is missing.
        """
        self._initialize_clients()
        try:
            async def get_single_embedding(text: str) -> List[float]:
                # One request per text, mirroring the synchronous path.
                response = await self.async_client.post(
                    f"{self.ollama_base_url}/api/embeddings",
                    json={
                        "model": self.embedding_model_name,
                        "prompt": text
                    },
                    timeout=self.timeout
                )
                response.raise_for_status()
                result = response.json()
                return result["embedding"]

            tasks = [get_single_embedding(text) for text in texts]
            embeddings = await asyncio.gather(*tasks)
            return embeddings
        except Exception as e:
            # Chain the original exception so callers keep the root cause.
            raise Exception(f"Failed to get embeddings: {e}") from e

    def as_langchain(self) -> Any:
        """
        Convert the OllamaEmbedding instance to a LangChain OllamaEmbeddings instance.

        Raises:
            Exception: If langchain_community is not installed.
        """
        self._initialize_clients()
        try:
            from langchain_community.embeddings import OllamaEmbeddings
            return OllamaEmbeddings(
                base_url=self.ollama_base_url,
                model=self.embedding_model_name
            )
        except ImportError as e:
            # Chain so the underlying import failure remains inspectable.
            raise Exception("langchain_community is required for LangChain integration") from e

    def _initialize_by_component_configer(self, embedding_configer: ComponentConfiger) -> 'Embedding':
        """
        Initialize the embedding by the ComponentConfiger object.

        Args:
            embedding_configer(ComponentConfiger): A configer contains embedding configuration.

        Returns:
            Embedding: A OllamaEmbedding instance.
        """
        super()._initialize_by_component_configer(embedding_configer)
        # Only override attributes the configer explicitly provides, so
        # env-derived defaults survive a partial configuration.
        if hasattr(embedding_configer, "ollama_base_url"):
            self.ollama_base_url = embedding_configer.ollama_base_url
        if hasattr(embedding_configer, "ollama_api_key"):
            self.ollama_api_key = embedding_configer.ollama_api_key
        if hasattr(embedding_configer, "timeout"):
            self.timeout = embedding_configer.timeout
        return self

    def _initialize_clients(self) -> None:
        """Validate configuration and lazily create the cached httpx clients.

        Raises:
            Exception: If the base URL or model name is missing.
        """
        if not self.ollama_base_url:
            raise Exception("OLLAMA_BASE_URL is missing")
        if not self.embedding_model_name:
            raise Exception("embedding_model_name is missing")
        headers = {}
        # Only attach Authorization when an API key is configured.
        if self.ollama_api_key:
            headers["Authorization"] = f"Bearer {self.ollama_api_key}"
        if self.client is None:
            self.client = httpx.Client(
                base_url=self.ollama_base_url,
                headers=headers,
                timeout=self.timeout
            )
        if self.async_client is None:
            self.async_client = httpx.AsyncClient(
                base_url=self.ollama_base_url,
                headers=headers,
                timeout=self.timeout
            )

View File

@@ -0,0 +1,8 @@
name: 'ollama_embedding'
description: 'Embedding using the Ollama API, with support for the mxbai-embed-large, nomic-embed-text, and all-minilm models'
embedding_model_name: 'mxbai-embed-large'
metadata:
type: 'EMBEDDING'
module: 'agentuniverse.agent.action.knowledge.embedding.ollama_embedding'
class: 'OllamaEmbedding'

View File

@@ -0,0 +1,130 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
#
# @Time : 2025-10-05 21:40 PM
# @Author : Cascade AI
# @Email : cascade@windsurf.ai
# @FileName: test_ollama_embedding.py
import asyncio
import unittest
from unittest.mock import Mock, patch, AsyncMock
import httpx
from agentuniverse.agent.action.knowledge.embedding.ollama_embedding import OllamaEmbedding
class OllamaEmbeddingTest(unittest.TestCase):
    """
    Test cases for OllamaEmbedding class
    """

    def setUp(self) -> None:
        """Build an embedding instance pointed at a local Ollama server."""
        instance = OllamaEmbedding()
        instance.ollama_base_url = "http://localhost:11434"
        instance.embedding_model_name = "mxbai-embed-large"
        instance.timeout = 30
        self.embedding = instance

    def test_get_embeddings_real_api(self) -> None:
        """
        Exercise get_embeddings against a live Ollama server.

        Requires Ollama to be running locally with the models pulled;
        skipped automatically when the server or a model is unavailable.
        """
        try:
            # Test with different models
            models_to_test = ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]
            for model in models_to_test:
                with self.subTest(model=model):
                    self.embedding.embedding_model_name = model
                    res = self.embedding.get_embeddings(texts=["hello world"])
                    print(f"Model {model} - Embedding shape: {len(res[0]) if res and res[0] else 'None'}")
                    self.assertIsInstance(res, list)
                    self.assertIsInstance(res[0], list)
                    self.assertEqual(len(res), 1)
                    self.assertGreater(len(res[0]), 0)
        except Exception as e:
            self.skipTest(f"Ollama API not available or model not found: {e}")

    def test_async_get_embeddings_real_api(self) -> None:
        """
        Exercise async_get_embeddings against a live Ollama server.

        Requires Ollama to be running locally; skipped when unavailable.
        """
        try:
            coro = self.embedding.async_get_embeddings(texts=["hello world"])
            res = asyncio.run(coro)
            print(f"Async embedding result shape: {len(res[0]) if res and res[0] else 'None'}")
            self.assertIsInstance(res, list)
            self.assertIsInstance(res[0], list)
            self.assertEqual(len(res), 1)
            self.assertGreater(len(res[0]), 0)
        except Exception as e:
            self.skipTest(f"Ollama API not available: {e}")

    def test_as_langchain(self) -> None:
        """
        Exercise the LangChain adapter; needs langchain_community installed.
        """
        try:
            langchain_embedding = self.embedding.as_langchain()
            self.assertIsNotNone(langchain_embedding)
            print(f"LangChain embedding type: {type(langchain_embedding)}")
            # Exercise the LangChain interface too when Ollama is reachable.
            try:
                res = langchain_embedding.embed_documents(texts=["hello world"])
                print(f"LangChain embedding result shape: {len(res[0]) if res and res[0] else 'None'}")
                self.assertIsInstance(res, list)
            except Exception as e:
                print(f"LangChain embedding test skipped: {e}")
        except ImportError:
            self.skipTest("langchain_community not available")
        except Exception as e:
            self.skipTest(f"LangChain integration test failed: {e}")

    def test_initialization_errors(self) -> None:
        """Verify that missing configuration raises a descriptive error."""
        cases = [
            # (attribute overrides, expected error fragment)
            ({"ollama_base_url": None, "embedding_model_name": "test-model"},
             "OLLAMA_BASE_URL is missing"),
            ({"ollama_base_url": "http://localhost:11434", "embedding_model_name": None},
             "embedding_model_name is missing"),
        ]
        for overrides, expected in cases:
            embedding = OllamaEmbedding()
            for attr, value in overrides.items():
                setattr(embedding, attr, value)
            with self.assertRaises(Exception) as context:
                embedding.get_embeddings(["test"])
            self.assertIn(expected, str(context.exception))

    def test_multiple_models_configuration(self) -> None:
        """Verify configuration round-trips for each supported model name."""
        for model in ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]:
            with self.subTest(model=model):
                configured = OllamaEmbedding()
                configured.ollama_base_url = "http://localhost:11434"
                configured.embedding_model_name = model
                configured.timeout = 30
                # The attributes should read back exactly as assigned.
                self.assertEqual(configured.embedding_model_name, model)
                self.assertEqual(configured.ollama_base_url, "http://localhost:11434")
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()