feat: Qdrant storage

Signed-off-by: Anush008 <anushshetty90@gmail.com>
Author: Anush008
Date:   2025-08-23 00:41:01 +05:30
Parent: 98252a20c2
Commit: 281dccd353

5 changed files with 436 additions and 9 deletions

View File

@@ -43,7 +43,7 @@ class Knowledge(ComponentBase):
description (str): The description of the knowledge.
stores (List[str]): The stores for the knowledge, which are used to store knowledge
-        and provide retrieval capabilities, such as ChromaDB store or Redis Store.
+        and provide retrieval capabilities, such as ChromaDB store, Redis Store or Qdrant Store.
query_paraphrasers (List[str]): Query paraphrasers used to paraphrase the original query string,
such as extracting keywords and splitting into sub-queries.

View File

@@ -0,0 +1,188 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2025/8/21
# @Author : Anush008
# @Email : anushshetty90@gmail.com
# @FileName: qdrant_store.py
from typing import Any, List, Optional, ClassVar
import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from agentuniverse.agent.action.knowledge.embedding.embedding_manager import (
EmbeddingManager,
)
from agentuniverse.agent.action.knowledge.store.document import Document
from agentuniverse.agent.action.knowledge.store.query import Query
from agentuniverse.agent.action.knowledge.store.store import Store
from agentuniverse.base.config.component_configer.component_configer import (
ComponentConfiger,
)
DEFAULT_CONNECTION_ARGS = {
"host": "localhost",
"port": 6333,
"https": False,
}
class QdrantStore(Store):
"""Qdrant-based vector store implementation.
Stores documents with vectors in a Qdrant collection and supports similarity search.
Attributes:
connection_args (Optional[dict]): Qdrant connection parameters.
collection_name (Optional[str]): Qdrant collection name.
        distance (Optional[str]): Distance metric, one of "COSINE", "EUCLID", "DOT" or "MANHATTAN".
embedding_model (Optional[str]): Embedding model key managed by `EmbeddingManager`.
similarity_top_k (Optional[int]): Default top-k for search.
with_vectors (bool): If True, include vectors in query results.
"""
connection_args: Optional[dict] = None
collection_name: Optional[str] = "qdrant_db"
distance: Optional[str] = "COSINE"
embedding_model: Optional[str] = None
similarity_top_k: Optional[int] = 10
with_vectors: bool = False
client: Optional[QdrantClient] = None
VECTOR_NAME: ClassVar[str] = "embedding"
def _metric_from_str(self) -> Distance:
return {
"COSINE": Distance.COSINE,
"EUCLID": Distance.EUCLID,
"DOT": Distance.DOT,
"MANHATTAN": Distance.MANHATTAN,
}.get((self.distance or "COSINE").upper(), Distance.COSINE)
def _new_client(self) -> Any:
args = self.connection_args or DEFAULT_CONNECTION_ARGS
self.client = QdrantClient(**args)
return self.client
def _initialize_by_component_configer(self, qdrant_store_configer: ComponentConfiger) -> "QdrantStore":
super()._initialize_by_component_configer(qdrant_store_configer)
if hasattr(qdrant_store_configer, "connection_args"):
self.connection_args = qdrant_store_configer.connection_args
else:
self.connection_args = DEFAULT_CONNECTION_ARGS
if hasattr(qdrant_store_configer, "collection_name"):
self.collection_name = qdrant_store_configer.collection_name
if hasattr(qdrant_store_configer, "distance"):
self.distance = qdrant_store_configer.distance
if hasattr(qdrant_store_configer, "embedding_model"):
self.embedding_model = qdrant_store_configer.embedding_model
if hasattr(qdrant_store_configer, "similarity_top_k"):
self.similarity_top_k = qdrant_store_configer.similarity_top_k
if hasattr(qdrant_store_configer, "with_vectors"):
self.with_vectors = bool(qdrant_store_configer.with_vectors)
return self
def _ensure_collection(self, dim: int):
if self.client is None:
self.client = self._new_client()
if not self.client.collection_exists(self.collection_name):
metric = self._metric_from_str()
self.client.create_collection(
collection_name=self.collection_name,
vectors_config={self.VECTOR_NAME: VectorParams(size=dim, distance=metric)},
)
def query(self, query: Query, **kwargs) -> List[Document]:
if self.client is None:
return []
embedding = query.embeddings
if self.embedding_model is not None and (not embedding or len(embedding) == 0):
model = EmbeddingManager().get_instance_obj(self.embedding_model)
embedding = model.get_embeddings([query.query_str], text_type="query")
limit = query.similarity_top_k if query.similarity_top_k else self.similarity_top_k
if embedding and len(embedding) > 0:
query_vector = embedding[0]
results = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
using=self.VECTOR_NAME,
limit=limit,
with_payload=True,
with_vectors=self.with_vectors,
).points
else:
results = []
return self.to_documents(results)
def insert_document(self, documents: List[Document], **kwargs):
self.upsert_document(documents, **kwargs)
def upsert_document(self, documents: List[Document], **kwargs):
if self.client is None:
return
points: List[PointStruct] = []
for document in documents:
vector = document.embedding
if (not vector or len(vector) == 0) and self.embedding_model:
vector = EmbeddingManager().get_instance_obj(self.embedding_model).get_embeddings([document.text])[0]
if not vector or len(vector) == 0:
continue
self._ensure_collection(dim=len(vector))
payload = {"text": document.text, "metadata": document.metadata}
try:
point_id = str(uuid.UUID(str(document.id)))
except Exception:
# fallback to deterministic UUID5 if document id is not UUID
point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, str(document.id)))
points.append(
PointStruct(
id=point_id,
vector={self.VECTOR_NAME: vector},
payload=payload,
)
)
if points:
self.client.upsert(collection_name=self.collection_name, points=points)
def update_document(self, documents: List[Document], **kwargs):
self.upsert_document(documents, **kwargs)
    def delete_document(self, document_id: str, **kwargs):
        if self.client is None:
            return
        try:
            point_id = str(uuid.UUID(str(document_id)))
        except Exception:  # mirror the UUID5 fallback used on upsert for non-UUID ids
            point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, str(document_id)))
        self.client.delete(collection_name=self.collection_name, points_selector=[point_id])
@staticmethod
def to_documents(results) -> List[Document]:
if results is None:
return []
documents: List[Document] = []
for scored_point in results:
payload = scored_point.payload or {}
text = payload.get("text")
metadata = payload.get("metadata")
vector = scored_point.vector
if vector and isinstance(vector, dict):
vector = vector.get(QdrantStore.VECTOR_NAME, [])
else:
vector = []
documents.append(
Document(
id=str(scored_point.id),
text=text,
embedding=vector,
metadata=metadata,
)
)
return documents
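
A minimal usage sketch for the `QdrantStore` above, assuming a Qdrant server is reachable at the default localhost:6333 and that `Document`/`Query` accept the same field names the store reads; embeddings are supplied directly, so no embedding model has to be registered with `EmbeddingManager`, and the private `_new_client()` call stands in for the framework's normal component initialization:

from agentuniverse.agent.action.knowledge.store.document import Document
from agentuniverse.agent.action.knowledge.store.query import Query

# Connects using DEFAULT_CONNECTION_ARGS (localhost:6333, plain HTTP).
store = QdrantStore(collection_name="demo_collection", distance="COSINE")
store._new_client()

# The collection is created lazily on first upsert, with the vector size
# inferred from the first non-empty embedding (4 here, purely illustrative).
store.upsert_document([
    Document(id="doc-1", text="Qdrant is a vector database.", embedding=[0.1, 0.2, 0.3, 0.4]),
    Document(id="doc-2", text="agentUniverse supports pluggable stores.", embedding=[0.4, 0.3, 0.2, 0.1]),
])

# Similarity search with a precomputed query embedding; the store only falls back
# to `embedding_model` when `query.embeddings` is empty.
for doc in store.query(Query(query_str="vector database", embeddings=[[0.1, 0.2, 0.3, 0.4]])):
    print(doc.id, doc.text)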

View File

@@ -20,7 +20,7 @@ class Store(ComponentBase):
Store of the knowledge, store class is used to store knowledge
and provide retrieval capabilities,
-    vector storage, such as ChromaDB store, or non-vector storage, such as Redis Store.
+    vector storage, such as ChromaDB store, Qdrant Store, or non-vector storage, such as Redis Store.
"""
component_type: ComponentEnum = ComponentEnum.STORE
name: Optional[str] = None

View File

@@ -0,0 +1,238 @@
# !/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2025/08/22
# @Author : Anush008
# @Email : anushshetty90@gmail.com
# @FileName: qdrant_memory_storage.py
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, List, Optional, ClassVar
import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
NamedVector,
PointStruct,
Filter,
FieldCondition,
MatchValue,
)
from agentuniverse.agent.action.knowledge.embedding.embedding_manager import EmbeddingManager
from agentuniverse.agent.memory.memory_storage.memory_storage import MemoryStorage
from agentuniverse.agent.memory.message import Message
from agentuniverse.base.config.component_configer.component_configer import ComponentConfiger
DEFAULT_CONNECTION_ARGS: Dict[str, Any] = {
"host": "localhost",
"port": 6333,
"https": False,
}
class QdrantMemoryStorage(MemoryStorage):
"""Qdrant-based memory storage.
Stores chat messages as Qdrant points with a named vector and rich payload.
Attributes:
connection_args (Optional[dict]): Qdrant connection parameters. Supports either `url` or `host`/`port`.
collection_name (Optional[str]): Qdrant collection name.
        distance (Optional[str]): Distance metric, one of "COSINE", "EUCLID", "DOT" or "MANHATTAN".
embedding_model (Optional[str]): Embedding model instance key managed by `EmbeddingManager`.
"""
class Config:
arbitrary_types_allowed = True
connection_args: Optional[dict] = None
collection_name: Optional[str] = "memory"
distance: Optional[str] = "COSINE"
embedding_model: Optional[str] = None
client: Optional[QdrantClient] = None
VECTOR_NAME: ClassVar[str] = "embedding"
def _initialize_by_component_configer(self, memory_storage_config: ComponentConfiger) -> "QdrantMemoryStorage":
super()._initialize_by_component_configer(memory_storage_config)
if getattr(memory_storage_config, "connection_args", None):
self.connection_args = memory_storage_config.connection_args
else:
self.connection_args = DEFAULT_CONNECTION_ARGS
if getattr(memory_storage_config, "collection_name", None):
self.collection_name = memory_storage_config.collection_name
if getattr(memory_storage_config, "distance", None):
self.distance = memory_storage_config.distance
if getattr(memory_storage_config, "embedding_model", None):
self.embedding_model = memory_storage_config.embedding_model
return self
def _metric_from_str(self) -> Distance:
return {
"COSINE": Distance.COSINE,
"EUCLID": Distance.EUCLID,
"DOT": Distance.DOT,
"MANHATTAN": Distance.MANHATTAN,
}.get((self.distance or "COSINE").upper(), Distance.COSINE)
def _ensure_client(self) -> QdrantClient:
if self.client is not None:
return self.client
args = self.connection_args or DEFAULT_CONNECTION_ARGS
self.client = QdrantClient(**args)
return self.client
def _ensure_collection(self, dim: int) -> None:
client = self._ensure_client()
if not client.collection_exists(self.collection_name):
metric = self._metric_from_str()
client.create_collection(
collection_name=self.collection_name,
vectors_config={self.VECTOR_NAME: VectorParams(size=dim, distance=metric)},
)
@staticmethod
def _build_filter(
session_id: Optional[str], agent_id: Optional[str], source: Optional[str], type_value: Optional[Any]
) -> Optional[Filter]:
must_conditions: List[Any] = []
if session_id:
must_conditions.append(FieldCondition(key="session_id", match=MatchValue(value=session_id)))
if agent_id:
must_conditions.append(FieldCondition(key="agent_id", match=MatchValue(value=agent_id)))
if source:
must_conditions.append(FieldCondition(key="source", match=MatchValue(value=source)))
if type_value:
must_conditions.append(FieldCondition(key="type", match=MatchValue(value=type_value[0])))
if not must_conditions:
return None
return Filter(must=must_conditions)
def delete(self, session_id: str = None, agent_id: str = None, **kwargs) -> None:
client = self._ensure_client()
filt = self._build_filter(session_id=session_id, agent_id=agent_id, source=None, type_value=kwargs.get("type"))
if not filt:
return
client.delete(collection_name=self.collection_name, points_selector=filt)
def add(self, message_list: List[Message], session_id: str = None, agent_id: str = None, **kwargs) -> None:
if not message_list:
return
client = self._ensure_client()
points: List[PointStruct] = []
for message in message_list:
metadata = dict(message.metadata or {})
metadata.update({"gmt_created": datetime.now().isoformat()})
if session_id:
metadata["session_id"] = session_id
if agent_id:
metadata["agent_id"] = agent_id
if message.source:
metadata["source"] = message.source
if message.type:
metadata["type"] = message.type
vector: List[float] = []
if self.embedding_model:
try:
vector = (
EmbeddingManager().get_instance_obj(self.embedding_model).get_embeddings([str(message.content)])
)[0]
except Exception:
vector = []
if vector:
self._ensure_collection(dim=len(vector))
else:
raise ValueError("No vectors available for message. Cannot store message without embeddings.")
payload = {
"content": message.content,
**metadata,
}
try:
point_id = str(uuid.UUID(str(message.id))) if message.id else str(uuid.uuid4())
except Exception:
point_id = str(uuid.uuid4())
points.append(
PointStruct(
id=point_id,
vector={self.VECTOR_NAME: vector} if vector else {},
payload=payload,
)
)
if points:
client.upsert(collection_name=self.collection_name, points=points)
def get(
self, session_id: str = None, agent_id: str = None, top_k=10, input: str = "", source: str = None, **kwargs
) -> List[Message]:
client = self._ensure_client()
filt = self._build_filter(
session_id=session_id, agent_id=agent_id, source=source, type_value=kwargs.get("type")
)
if input:
vector: List[float] = []
if self.embedding_model:
try:
vector = (EmbeddingManager().get_instance_obj(self.embedding_model).get_embeddings([input]))[0]
except Exception:
vector = []
if vector:
results = client.query_points(
collection_name=self.collection_name,
query=vector,
using=self.VECTOR_NAME,
limit=top_k,
with_payload=True,
with_vectors=False,
query_filter=filt,
                ).points
messages = self.to_messages(results)
messages.reverse()
return messages
else:
return []
scroll_result = client.scroll(
collection_name=self.collection_name,
scroll_filter=filt,
limit=top_k,
with_payload=True,
with_vectors=False,
)
points = scroll_result[0]
messages = self.to_messages(points)
messages = sorted(messages, key=lambda msg: (msg.metadata or {}).get("gmt_created", ""))
messages.reverse()
return messages[:top_k]
def to_messages(self, results: Any) -> List[Message]:
message_list: List[Message] = []
if not results:
return message_list
try:
for item in results:
payload: Dict[str, Any] = getattr(item, "payload", None) or {}
msg = Message(
id=str(getattr(item, "id", None)),
content=payload.get("content"),
metadata={k: v for k, v in payload.items() if k not in {"content"}},
source=payload.get("source"),
type=payload.get("type", ""),
)
message_list.append(msg)
except Exception as e:
print("QdrantMemoryStorage.to_messages failed, exception= " + str(e))
return message_list
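
A minimal usage sketch for the `QdrantMemoryStorage` above, assuming a Qdrant server at the default localhost:6333 and an embedding model registered with `EmbeddingManager` under the hypothetical key "demo_embedding"; the storage is constructed directly for brevity rather than through a component configer, and the `Message` fields mirror the attributes the storage reads:

from agentuniverse.agent.memory.message import Message

storage = QdrantMemoryStorage(collection_name="chat_memory", embedding_model="demo_embedding")

# add() embeds each message, attaches session/agent/source/type metadata plus a
# gmt_created timestamp, and upserts the points (creating the collection lazily).
storage.add(
    [Message(content="The user asked about vector databases.", type="ai", source="demo_agent")],
    session_id="session-1",
    agent_id="demo_agent",
)

# With `input`, get() embeds the text and runs a filtered vector search.
relevant = storage.get(session_id="session-1", agent_id="demo_agent", top_k=5, input="vector databases")

# Without `input`, get() falls back to a filtered scroll ordered by gmt_created.
history = storage.get(session_id="session-1", agent_id="demo_agent", top_k=5)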

View File

@@ -21,14 +21,14 @@ classifiers = [
]
[tool.poetry.dependencies]
-python = "^3.10"
+python = ">=3.11,<3.12"
requests = "^2.32.0"
cffi = "^1.15.1"
flask = "^2.3.2"
werkzeug = "^3.0.3"
-langchain = "0.1.20"
-langchain-core = "0.1.52"
-langchain-community = "0.0.38"
+langchain = "^0.3.27"
+langchain-core = "^0.3.27"
+langchain-community = "^0.3.27"
openai = '1.55.3'
tiktoken = '<1.0.0'
loguru = '0.7.2'
@@ -53,10 +53,10 @@ googleapis-common-protos = "^1.63.0"
myst-parser = "^2.0.0"
qianfan = "^0.3.12"
dashscope = "^1.19.1"
-anthropic = "^0.26.0"
+anthropic = "^0.64.0"
ollama = '^0.2.1'
-langchain-anthropic = '^0.1.13'
-numpy = '^1.26.0'
+langchain-anthropic = '^0.3.19'
+numpy = '^2.3.2'
pandas = "^2.2.2"
pyarrow = "^16.1.0"
duckduckgo-search = "^6.3.5"
@@ -71,6 +71,7 @@ tomli = "^2.2"
mcp = "~=1.9.0"
opentracing = ">=2.4.0,<3.0.0"
jsonlines = "^4.0.0"
+qdrant-client = "^1.15.1"
[tool.poetry.extras]
log_ext = ["aliyun-log-python-sdk"]
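
With `qdrant-client` added to the dependencies, a quick connectivity check against a locally running Qdrant server (started separately, e.g. via the official Docker image) might look like the sketch below; the connection parameters match the DEFAULT_CONNECTION_ARGS used by the new store and memory storage:

from qdrant_client import QdrantClient

client = QdrantClient(host="localhost", port=6333, https=False)
# Lists existing collections if the server is reachable; raises a connection error otherwise.
print(client.get_collections())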