mirror of
https://github.com/agentuniverse-ai/agentUniverse.git
synced 2026-02-09 01:59:19 +08:00
feat: add Ollama embedding
This commit is contained in:
@@ -0,0 +1,153 @@
|
||||
# !/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# @Time : 2025-10-05 21:40 PM
|
||||
# @Author : Cascade AI
|
||||
# @Email : cascade@windsurf.ai
|
||||
# @FileName: ollama_embedding.py
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from pydantic import Field
|
||||
import httpx
|
||||
import asyncio
|
||||
|
||||
from agentuniverse.base.util.env_util import get_from_env
|
||||
from agentuniverse.agent.action.knowledge.embedding.embedding import Embedding
|
||||
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
|
||||
|
||||
|
||||
class OllamaEmbedding(Embedding):
    """Embedding component backed by a local or remote Ollama server.

    Talks to the Ollama ``/api/embeddings`` REST endpoint via httpx, sending
    one request per input text. Configuration may come from the component
    configer or from the ``OLLAMA_BASE_URL`` / ``OLLAMA_API_KEY`` environment
    variables.
    """

    # Base URL of the Ollama server; falls back to the default local port.
    ollama_base_url: Optional[str] = Field(
        default_factory=lambda: get_from_env("OLLAMA_BASE_URL") or "http://localhost:11434")

    # Optional bearer token, only sent when set (e.g. behind an auth proxy).
    ollama_api_key: Optional[str] = Field(
        default_factory=lambda: get_from_env("OLLAMA_API_KEY"))

    # Name of the Ollama embedding model (e.g. 'mxbai-embed-large').
    embedding_model_name: Optional[str] = None
    # Declared output dimensionality; not validated against the model here.
    embedding_dims: Optional[int] = None
    # Per-request timeout, in seconds.
    timeout: Optional[int] = Field(default=30)

    # Lazily created httpx clients; see _initialize_clients().
    client: Any = None
    async_client: Any = None

    def get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Retrieve text embeddings for a list of input texts using Ollama API.

        Args:
            texts (List[str]): A list of input texts to be embedded.

        Returns:
            List[List[float]]: A list of embeddings corresponding to the input texts.

        Raises:
            Exception: If the API call fails or if required configuration is missing.
        """
        self._initialize_clients()

        try:
            embeddings = []
            # The /api/embeddings endpoint takes a single prompt per call,
            # so each text is embedded with its own request.
            for text in texts:
                response = self.client.post(
                    f"{self.ollama_base_url}/api/embeddings",
                    json={
                        "model": self.embedding_model_name,
                        "prompt": text
                    },
                    timeout=self.timeout
                )
                response.raise_for_status()
                result = response.json()
                embeddings.append(result["embedding"])

            return embeddings

        except Exception as e:
            # Chain the original error so callers can inspect the root cause.
            raise Exception(f"Failed to get embeddings: {e}") from e

    async def async_get_embeddings(self, texts: List[str], **kwargs) -> List[List[float]]:
        """
        Retrieve text embeddings for a list of input texts using Ollama API asynchronously.

        Args:
            texts (List[str]): A list of input texts to be embedded.

        Returns:
            List[List[float]]: A list of embeddings corresponding to the input texts.

        Raises:
            Exception: If the API call fails or if required configuration is missing.
        """
        self._initialize_clients()

        try:
            async def get_single_embedding(text: str) -> List[float]:
                # One request per text, mirroring the sync path.
                response = await self.async_client.post(
                    f"{self.ollama_base_url}/api/embeddings",
                    json={
                        "model": self.embedding_model_name,
                        "prompt": text
                    },
                    timeout=self.timeout
                )
                response.raise_for_status()
                result = response.json()
                return result["embedding"]

            # Fan out all requests concurrently; gather preserves input order.
            tasks = [get_single_embedding(text) for text in texts]
            embeddings = await asyncio.gather(*tasks)
            return embeddings

        except Exception as e:
            # Chain the original error so callers can inspect the root cause.
            raise Exception(f"Failed to get embeddings: {e}") from e

    def as_langchain(self) -> Any:
        """
        Convert the OllamaEmbedding instance to a LangChain OllamaEmbeddings instance.

        Raises:
            Exception: If langchain_community is not installed, or configuration
                is missing (via _initialize_clients).
        """
        # Validates required configuration before handing off to LangChain.
        self._initialize_clients()

        try:
            from langchain_community.embeddings import OllamaEmbeddings
            return OllamaEmbeddings(
                base_url=self.ollama_base_url,
                model=self.embedding_model_name
            )
        except ImportError as e:
            raise Exception("langchain_community is required for LangChain integration") from e

    def _initialize_by_component_configer(self, embedding_configer: ComponentConfiger) -> 'Embedding':
        """
        Initialize the embedding by the ComponentConfiger object.

        Args:
            embedding_configer(ComponentConfiger): A configer contains embedding configuration.

        Returns:
            Embedding: A OllamaEmbedding instance.
        """
        # Common fields (e.g. embedding_model_name) are handled by the base class.
        super()._initialize_by_component_configer(embedding_configer)
        if hasattr(embedding_configer, "ollama_base_url"):
            self.ollama_base_url = embedding_configer.ollama_base_url
        if hasattr(embedding_configer, "ollama_api_key"):
            self.ollama_api_key = embedding_configer.ollama_api_key
        if hasattr(embedding_configer, "timeout"):
            self.timeout = embedding_configer.timeout
        return self

    def _initialize_clients(self) -> None:
        """Validate configuration and lazily build the httpx clients.

        Raises:
            Exception: If the base URL or the model name is not configured.
        """
        if not self.ollama_base_url:
            raise Exception("OLLAMA_BASE_URL is missing")
        if not self.embedding_model_name:
            raise Exception("embedding_model_name is missing")

        headers = {}
        if self.ollama_api_key:
            headers["Authorization"] = f"Bearer {self.ollama_api_key}"

        # Clients are created once and reused across calls.
        if self.client is None:
            self.client = httpx.Client(
                base_url=self.ollama_base_url,
                headers=headers,
                timeout=self.timeout
            )
        if self.async_client is None:
            self.async_client = httpx.AsyncClient(
                base_url=self.ollama_base_url,
                headers=headers,
                timeout=self.timeout
            )
|
||||
@@ -0,0 +1,8 @@
|
||||
# Component configuration for the Ollama embedding (OllamaEmbedding class).
name: 'ollama_embedding'
description: 'embedding use ollama api with support for mxbai-embed-large, nomic-embed-text, all-minilm models'
# Default model pulled by the Ollama server; override per deployment as needed.
embedding_model_name: 'mxbai-embed-large'

metadata:
  type: 'EMBEDDING'
  module: 'agentuniverse.agent.action.knowledge.embedding.ollama_embedding'
  class: 'OllamaEmbedding'
|
||||
@@ -0,0 +1,130 @@
|
||||
# !/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# @Time : 2025-10-05 21:40 PM
|
||||
# @Author : Cascade AI
|
||||
# @Email : cascade@windsurf.ai
|
||||
# @FileName: test_ollama_embedding.py
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import httpx
|
||||
from agentuniverse.agent.action.knowledge.embedding.ollama_embedding import OllamaEmbedding
|
||||
|
||||
|
||||
class OllamaEmbeddingTest(unittest.TestCase):
    """
    Test cases for OllamaEmbedding class
    """

    def setUp(self) -> None:
        # Fresh embedding instance pointing at a local Ollama server.
        self.embedding = OllamaEmbedding()
        self.embedding.ollama_base_url = "http://localhost:11434"
        self.embedding.embedding_model_name = "mxbai-embed-large"
        self.embedding.timeout = 30

    def test_get_embeddings_real_api(self) -> None:
        """
        Test get_embeddings with real Ollama API.

        This test requires Ollama to be running locally with the model available.
        Skip if Ollama is not available.
        """
        try:
            # Test with different models
            models_to_test = ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]

            for model in models_to_test:
                with self.subTest(model=model):
                    self.embedding.embedding_model_name = model
                    res = self.embedding.get_embeddings(texts=["hello world"])
                    print(f"Model {model} - Embedding shape: {len(res[0]) if res and res[0] else 'None'}")

                    self.assertIsInstance(res, list)
                    self.assertEqual(len(res), 1)
                    self.assertIsInstance(res[0], list)
                    self.assertGreater(len(res[0]), 0)

        except AssertionError:
            # A failed assertion is a real test failure, never a skip.
            raise
        except Exception as e:
            self.skipTest(f"Ollama API not available or model not found: {e}")

    def test_async_get_embeddings_real_api(self) -> None:
        """
        Test async_get_embeddings with real Ollama API.

        This test requires Ollama to be running locally with the model available.
        Skip if Ollama is not available.
        """
        try:
            res = asyncio.run(
                self.embedding.async_get_embeddings(texts=["hello world"]))
            print(f"Async embedding result shape: {len(res[0]) if res and res[0] else 'None'}")

            self.assertIsInstance(res, list)
            self.assertEqual(len(res), 1)
            self.assertIsInstance(res[0], list)
            self.assertGreater(len(res[0]), 0)

        except AssertionError:
            # A failed assertion is a real test failure, never a skip.
            raise
        except Exception as e:
            self.skipTest(f"Ollama API not available: {e}")

    def test_as_langchain(self) -> None:
        """
        Test LangChain integration.

        This test requires langchain_community to be installed.
        """
        try:
            langchain_embedding = self.embedding.as_langchain()
            self.assertIsNotNone(langchain_embedding)
            print(f"LangChain embedding type: {type(langchain_embedding)}")

            # Test with LangChain interface if Ollama is available
            try:
                res = langchain_embedding.embed_documents(texts=["hello world"])
                print(f"LangChain embedding result shape: {len(res[0]) if res and res[0] else 'None'}")
                self.assertIsInstance(res, list)
            except AssertionError:
                raise
            except Exception as e:
                print(f"LangChain embedding test skipped: {e}")

        except AssertionError:
            # A failed assertion is a real test failure, never a skip.
            raise
        except ImportError:
            self.skipTest("langchain_community not available")
        except Exception as e:
            self.skipTest(f"LangChain integration test failed: {e}")

    def test_initialization_errors(self) -> None:
        """Test initialization error handling"""
        # Test missing base URL
        embedding = OllamaEmbedding()
        embedding.ollama_base_url = None
        embedding.embedding_model_name = "test-model"

        with self.assertRaises(Exception) as context:
            embedding.get_embeddings(["test"])
        self.assertIn("OLLAMA_BASE_URL is missing", str(context.exception))

        # Test missing model name
        embedding = OllamaEmbedding()
        embedding.ollama_base_url = "http://localhost:11434"
        embedding.embedding_model_name = None

        with self.assertRaises(Exception) as context:
            embedding.get_embeddings(["test"])
        self.assertIn("embedding_model_name is missing", str(context.exception))

    def test_multiple_models_configuration(self) -> None:
        """Test configuration with different supported models"""
        supported_models = ["mxbai-embed-large", "nomic-embed-text", "all-minilm"]

        for model in supported_models:
            with self.subTest(model=model):
                embedding = OllamaEmbedding()
                embedding.ollama_base_url = "http://localhost:11434"
                embedding.embedding_model_name = model
                embedding.timeout = 30

                # Test that configuration is set correctly
                self.assertEqual(embedding.embedding_model_name, model)
                self.assertEqual(embedding.ollama_base_url, "http://localhost:11434")
|
||||
|
||||
# Allow running this test module directly: `python test_ollama_embedding.py`.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user