1. Add `update_document` and `upsert_document` methods (with async counterparts on the base class) to the knowledge base store.

2. Change the default document ID from a random UUID4 to a deterministic UUID5 derived from the document text.
This commit is contained in:
heji
2024-04-30 19:49:22 +08:00
parent fa66a997ad
commit f3db89c635
4 changed files with 62 additions and 3 deletions

View File

@@ -99,6 +99,32 @@ class ChromaStore(Store):
ids=[document.id]
)
def upsert_document(self, documents: List[Document], **kwargs):
    """Upsert documents into the store, one ``collection.upsert`` call per document.

    When a document carries no embedding and ``self.embedding_model`` is
    configured, the embedding is computed from the document text first.

    Args:
        documents (List[Document]): Documents to write. Each document's
            ``id``, ``text``, ``metadata`` and ``embedding`` are forwarded
            to the collection.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    for document in documents:
        embedding = document.embedding
        # Treat None and an empty sequence alike: calling len() on None
        # would raise TypeError before the None guard below could apply.
        if self.embedding_model is not None and (embedding is None or len(embedding) == 0):
            embedding = self.embedding_model.get_embeddings([document.text])[0]
        has_embedding = embedding is not None and len(embedding) > 0
        self.collection.upsert(
            documents=[document.text],
            metadatas=[document.metadata],
            # Pass None instead of [[]] when no vector is available.
            # NOTE(review): per the Chroma docs the collection then embeds
            # the text with its own embedding function - confirm.
            embeddings=[embedding] if has_embedding else None,
            ids=[document.id]
        )
def update_document(self, documents: List[Document], **kwargs):
    """Update documents in the store, one ``collection.update`` call per document.

    When a document carries no embedding and ``self.embedding_model`` is
    configured, the embedding is computed from the document text first.

    Args:
        documents (List[Document]): Documents to write. Each document's
            ``id``, ``text``, ``metadata`` and ``embedding`` are forwarded
            to the collection.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    for document in documents:
        embedding = document.embedding
        # Treat None and an empty sequence alike: calling len() on None
        # would raise TypeError before the None guard below could apply.
        if self.embedding_model is not None and (embedding is None or len(embedding) == 0):
            embedding = self.embedding_model.get_embeddings([document.text])[0]
        has_embedding = embedding is not None and len(embedding) > 0
        self.collection.update(
            documents=[document.text],
            metadatas=[document.metadata],
            # Pass None instead of [[]] when no vector is available.
            # NOTE(review): per the Chroma docs the collection then embeds
            # the text with its own embedding function - confirm.
            embeddings=[embedding] if has_embedding else None,
            ids=[document.id]
        )
@staticmethod
def to_documents(query_result: QueryResult) -> List[Document]:
"""Convert the query results of ChromaDB to the AgentUniverse(AU) document format."""

View File

@@ -8,7 +8,7 @@ import uuid
from typing import Dict, Any, Optional, List
from langchain_core.documents.base import Document as LCDocument
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class Document(BaseModel):
@@ -21,11 +21,18 @@ class Document(BaseModel):
embedding (List[float]): Embedding data associated with the document
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
id: str = None
text: Optional[str] = ""
metadata: Optional[Dict[str, Any]] = None
embedding: List[float] = Field(default_factory=list)
@model_validator(mode='before')
@classmethod
def create_id(cls, values):
    """Fill in a deterministic ``id`` derived from the document text.

    Runs before field validation. When the input has no truthy ``id``, the
    id is set to the UUID5 of the text under ``uuid.NAMESPACE_URL``, so
    identical text always yields the same id (and every document with
    empty or missing text gets the same id).

    Args:
        values: Raw input passed to the model; normally a dict of field values.

    Returns:
        The input unchanged when it is not a dict or already has an id,
        otherwise a shallow copy of it with ``id`` added.
    """
    # A 'before' validator can receive arbitrary input (for example a model
    # instance); leave anything that is not a dict for Pydantic to handle.
    if not isinstance(values, dict):
        return values
    if values.get('id'):
        return values
    # `text` is Optional, and uuid5 rejects None, so fall back to ''.
    text = values.get('text') or ''
    # Copy so the caller's dict (e.g. one given to model_validate) is not mutated.
    values = dict(values)
    values['id'] = str(uuid.uuid5(uuid.NAMESPACE_URL, text))
    return values
def as_langchain(self) -> LCDocument:
"""Convert to LangChain document format."""
metadata = self.metadata or {}

View File

@@ -87,3 +87,19 @@ class Store(BaseModel):
async def async_delete_document(self, document_id: str, **kwargs):
    """Asynchronously delete the document identified by ``document_id``.

    This base implementation is a placeholder only.

    Args:
        document_id (str): Id of the document to delete.
        **kwargs: Extra options; unused in this base implementation.

    Raises:
        NotImplementedError: Always, when the coroutine is awaited.
    """
    raise NotImplementedError
def upsert_document(self, documents: List[Document], **kwargs):
    """Upsert the given documents into the store.

    This base implementation is a placeholder only.

    Args:
        documents (List[Document]): Documents to upsert.
        **kwargs: Extra options; unused in this base implementation.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
async def async_upsert_document(self, documents: List[Document], **kwargs):
    """Asynchronously upsert the given documents into the store.

    This base implementation is a placeholder only.

    Args:
        documents (List[Document]): Documents to upsert.
        **kwargs: Extra options; unused in this base implementation.

    Raises:
        NotImplementedError: Always, when the coroutine is awaited.
    """
    raise NotImplementedError
def update_document(self, documents: List[Document], **kwargs):
    """Update the given documents in the store.

    This base implementation is a placeholder only.

    Args:
        documents (List[Document]): Documents to update.
        **kwargs: Extra options; unused in this base implementation.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
async def async_update_document(self, documents: List[Document], **kwargs):
    """Asynchronously update the given documents in the store.

    This base implementation is a placeholder only.

    Args:
        documents (List[Document]): Documents to update.
        **kwargs: Extra options; unused in this base implementation.

    Raises:
        NotImplementedError: Always, when the coroutine is awaited.
    """
    raise NotImplementedError

View File

@@ -27,6 +27,16 @@ class KnowledgeTest(unittest.TestCase):
embedding_model_name='text-embedding-ada-002'))
self.knowledge = Knowledge(**init_params)
def test_store_update_documents(self) -> None:
    """Call ``update_document`` on the knowledge store with two sample documents."""
    target_store = self.knowledge.store
    sample_documents = [
        Document(text='This is an iPhone', metadata={'type': 'Electronic products'}),
        Document(text='This is a Tesla.', metadata={'type': 'Industrial products'}),
    ]
    target_store.update_document(sample_documents)
def test_store_upsert_documents(self) -> None:
    """Call ``upsert_document`` on the knowledge store with two sample documents."""
    target_store = self.knowledge.store
    sample_documents = [
        Document(text='This is an iPhone', metadata={'type': 'Cell phone'}),
        Document(text='This is a Tesla.', metadata={'type': 'Car'}),
    ]
    target_store.upsert_document(sample_documents)
def test_store_insert_documents(self) -> None:
store = self.knowledge.store
store.insert_documents([Document(text='This is a document about engineer'),
@@ -34,7 +44,7 @@ class KnowledgeTest(unittest.TestCase):
def test_query(self) -> None:
    """Run a top-1 similarity query against the knowledge store and print the result."""
    store = self.knowledge.store
    # A single query object is built; a second assignment to `query` would
    # only overwrite the first before it is ever used.
    query = Query(query_str='Which one is a cell phone?', similarity_top_k=1)
    res = store.query(query)
    print(res)