Skip to content

Commit

Permalink
Add Astra DB vector store implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 17, 2024
1 parent f7f96c3 commit 513c021
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 5 deletions.
1 change: 1 addition & 0 deletions dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ numpy
pypi
nbformat
semversioner
astrapy

# Library Methods
iterrows
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/verbs/text/embed/text_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def text_embed(
max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
vector_store: # The optional configuration for the vector store
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb
type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, astradb
<...>
```
"""
Expand Down
2 changes: 2 additions & 0 deletions graphrag/vector_stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

"""A package containing vector-storage implementations."""

from .astradb import AstraDB
from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult
from .lancedb import LanceDBVectorStore
from .typing import VectorStoreFactory, VectorStoreType

__all__ = [
"AstraDB",
"AzureAISearch",
"BaseVectorStore",
"LanceDBVectorStore",
Expand Down
116 changes: 116 additions & 0 deletions graphrag/vector_stores/astradb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Astra DB vector store implementation package."""

import json
from typing import Any
from typing_extensions import override

from astrapy import DataAPIClient

from graphrag.model.types import TextEmbedder

from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult


class AstraDB(BaseVectorStore):
    """The Astra DB vector storage implementation.

    Documents are stored one per Astra DB collection record using the keys
    ``_id``, ``content``, ``$vector`` and ``metadata``.  ``metadata`` holds the
    document attributes serialized as a JSON string; it is decoded again when
    search results are built.
    """

    @override
    def connect(
        self,
        *,
        token: str | None = None,
        api_endpoint: str | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Open the database handle and the collection handle.

        Args:
            token: Token forwarded to ``DataAPIClient().get_database``.
            api_endpoint: API endpoint forwarded to ``get_database``.
            namespace: Namespace forwarded to both ``get_database`` and
                ``get_collection``.
            **kwargs: Accepted for interface compatibility; not used here.

        Sets ``self.db_connection`` and ``self.document_collection``.
        """
        self.db_connection = DataAPIClient().get_database(
            api_endpoint=api_endpoint,
            token=token,
            namespace=namespace,
        )
        self.document_collection = self.db_connection.get_collection(
            self.collection_name,
            namespace=namespace,
        )

    @override
    def load_documents(
        self, documents: list[VectorStoreDocument], overwrite: bool = True
    ) -> None:
        """Insert documents into the collection.

        Args:
            documents: Documents to store. Documents whose ``vector`` is
                ``None`` are skipped.
            overwrite: When true, the existing collection is dropped first
                (this happens even if ``documents`` is empty).
        """
        if overwrite:
            self.document_collection.drop()

        if not documents:
            return

        batch = [
            {
                "content": doc.text,
                "_id": doc.id,
                "$vector": doc.vector,
                "metadata": json.dumps(doc.attributes),
            }
            for doc in documents
            if doc.vector is not None
        ]

        # Nothing insertable: without at least one vector there is no
        # dimension to create the collection with.
        if not batch:
            return

        # Take the dimension from the first document that actually has a
        # vector; documents[0] may legitimately have vector=None.
        self.db_connection.create_collection(
            name=self.collection_name,
            dimension=len(batch[0]["$vector"]),
            check_exists=False,
        )

        self.document_collection.insert_many(batch)

    @override
    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
        """Build, store and return a filter restricting searches to ``include_ids``.

        An empty or ``None`` id list yields an empty filter (no restriction).
        """
        if not include_ids:
            self.query_filter = {}
        else:
            self.query_filter = {"_id": {"$in": include_ids}}
        return self.query_filter

    @staticmethod
    def _decode_metadata(raw: Any) -> Any:
        """Turn the stored ``metadata`` field back into document attributes.

        ``load_documents`` writes attributes with ``json.dumps``, so a string
        value is decoded with ``json.loads``.  Non-string values are passed
        through unchanged; a missing/``null`` value becomes an empty dict.
        """
        decoded = json.loads(raw) if isinstance(raw, str) else raw
        return {} if decoded is None else decoded

    @override
    def similarity_search_by_vector(
        self, query_embedding: list[float], k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Return up to ``k`` documents sorted by similarity to ``query_embedding``.

        The filter most recently set by ``filter_by_id`` (if any) is applied.
        Each result's score is the ``$similarity`` value returned with the record.
        """
        response = self.document_collection.find(
            filter=self.query_filter or {},
            projection={
                "_id": True,
                "content": True,
                "metadata": True,
                "$vector": True,
            },
            limit=k,
            include_similarity=True,
            sort={"$vector": query_embedding},
        )
        return [
            VectorStoreSearchResult(
                document=VectorStoreDocument(
                    id=doc["_id"],
                    text=doc["content"],
                    vector=doc["$vector"],
                    # Stored as a JSON string; decode so callers get the
                    # same structure they passed to load_documents.
                    attributes=self._decode_metadata(doc.get("metadata")),
                ),
                score=doc["$similarity"],
            )
            for doc in response
        ]

    @override
    def similarity_search_by_text(
        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
    ) -> list[VectorStoreSearchResult]:
        """Embed ``text`` with ``text_embedder`` and search by the resulting vector.

        Returns an empty list when the embedder yields no (falsy) embedding.
        """
        query_embedding = text_embedder(text)
        if query_embedding:
            return self.similarity_search_by_vector(
                query_embedding=query_embedding, k=k
            )
        return []



8 changes: 5 additions & 3 deletions graphrag/vector_stores/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from enum import Enum
from typing import ClassVar

from .azure_ai_search import AzureAISearch
from .lancedb import LanceDBVectorStore
from . import BaseVectorStore, AstraDB, AzureAISearch, LanceDBVectorStore


class VectorStoreType(str, Enum):
"""The supported vector store types."""

AstraDB = "astradb"
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"

Expand All @@ -30,9 +30,11 @@ def register(cls, vector_store_type: str, vector_store: type):
@classmethod
def get_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
) -> LanceDBVectorStore | AzureAISearch:
) -> BaseVectorStore:
"""Get the vector store type from a string."""
match vector_store_type:
case VectorStoreType.AstraDB:
return AstraDB(**kwargs)
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
Expand Down
Loading

0 comments on commit 513c021

Please sign in to comment.