From 9437b68bb945a6e98a0d9a75f1d37ea5da96f6d4 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 23 Sep 2024 16:37:53 +0200 Subject: [PATCH] Add Cassandra vector store implementation --- dictionary.txt | 1 + graphrag/index/verbs/text/embed/text_embed.py | 2 +- graphrag/vector_stores/__init__.py | 2 + graphrag/vector_stores/cassandra.py | 122 ++++++++++++++++++ graphrag/vector_stores/typing.py | 6 +- poetry.lock | 80 +++++++++++- pyproject.toml | 2 + 7 files changed, 212 insertions(+), 3 deletions(-) create mode 100644 graphrag/vector_stores/cassandra.py diff --git a/dictionary.txt b/dictionary.txt index 824d6faa98..5a1da92525 100644 --- a/dictionary.txt +++ b/dictionary.txt @@ -63,6 +63,7 @@ numpy pypi nbformat semversioner +cassio # Library Methods iterrows diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py index 76ac97d76f..0431991db5 100644 --- a/graphrag/index/verbs/text/embed/text_embed.py +++ b/graphrag/index/verbs/text/embed/text_embed.py @@ -75,7 +75,7 @@ async def text_embed( max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai vector_store: # The optional configuration for the vector store - type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb + type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, cassandra <...> ``` """ diff --git a/graphrag/vector_stores/__init__.py b/graphrag/vector_stores/__init__.py index d4c11760aa..6c3fd4ca56 100644 --- a/graphrag/vector_stores/__init__.py +++ b/graphrag/vector_stores/__init__.py @@ -5,12 +5,14 @@ from .azure_ai_search import AzureAISearch from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult +from .cassandra import CassandraVectorStore from .lancedb import LanceDBVectorStore from .typing import VectorStoreFactory, 
VectorStoreType __all__ = [ "AzureAISearch", "BaseVectorStore", + "CassandraVectorStore", "LanceDBVectorStore", "VectorStoreDocument", "VectorStoreFactory", 
diff --git a/graphrag/vector_stores/cassandra.py b/graphrag/vector_stores/cassandra.py
new file mode 100644
index 0000000000..34d163edde
--- /dev/null
+++ b/graphrag/vector_stores/cassandra.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Apache Cassandra vector store implementation package."""
+
+from typing import Any
+
+import cassio
+from cassandra.cluster import Session
+from cassio.table import MetadataVectorCassandraTable
+from typing_extensions import override
+
+from graphrag.model.types import TextEmbedder
+
+from .base import (
+    DEFAULT_VECTOR_SIZE,
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+
+
+class CassandraVectorStore(BaseVectorStore):
+    """The Apache Cassandra vector storage implementation."""
+
+    @override
+    def connect(
+        self,
+        *,
+        session: Session | None = None,
+        keyspace: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Connect to the Apache Cassandra database.
+
+        Parameters
+        ----------
+        session :
+            The Cassandra session. If not provided, it is resolved from cassio.
+        keyspace :
+            The Cassandra keyspace. If not provided, it is resolved from cassio.
+        **kwargs :
+            Accepted for interface compatibility; not used by this store.
+        """
+        self.db_connection = cassio.config.check_resolve_session(session)
+        self.keyspace = cassio.config.check_resolve_keyspace(keyspace)
+
+    def _open_table(self, dimension: int) -> MetadataVectorCassandraTable:
+        """Create the cassio table handle for this collection.
+
+        The session and keyspace resolved by ``connect`` are passed explicitly
+        so the table uses the same connection as the rest of the store instead
+        of re-resolving cassio's global configuration.
+        """
+        return MetadataVectorCassandraTable(
+            session=self.db_connection,
+            keyspace=getattr(self, "keyspace", None),
+            table=self.collection_name,
+            vector_dimension=dimension,
+            primary_key_type="TEXT",
+        )
+
+    @override
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the collection table.
+
+        Parameters
+        ----------
+        documents :
+            The documents to store. Documents without a vector are skipped.
+        overwrite :
+            If True, the existing table is dropped before loading.
+        """
+        if overwrite:
+            # CQL cannot bind identifiers, so the table name is interpolated.
+            self.db_connection.execute(
+                f"DROP TABLE IF EXISTS {self.keyspace}.{self.collection_name};"
+            )
+            # Any previous handle now refers to a dropped table.
+            self.document_collection = None
+
+        if not documents:
+            return
+
+        if not self.document_collection:
+            # Size the vector column from the first document with a vector.
+            dimension = next(
+                (len(doc.vector) for doc in documents if doc.vector),
+                DEFAULT_VECTOR_SIZE,
+            )
+            self.document_collection = self._open_table(dimension)
+
+        # Issue every write before waiting on any of them.
+        futures = [
+            self.document_collection.put_async(
+                row_id=doc.id,
+                body_blob=doc.text,
+                vector=doc.vector,
+                metadata=doc.attributes,
+            )
+            for doc in documents
+            if doc.vector
+        ]
+
+        for future in futures:
+            future.result()
+
+    @override
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Raise NotImplementedError: ID filtering is not supported."""
+        msg = "Cassandra vector store doesn't support filtering by IDs."
+        raise NotImplementedError(msg)
+
+    @override
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Return the k nearest documents to the embedding (cosine metric).
+
+        Extra keyword arguments are forwarded to cassio's ``metric_ann_search``.
+        """
+        if not self.document_collection:
+            # Query-only usage: no documents were loaded through this
+            # instance, so open the table sized to the query embedding.
+            self.document_collection = self._open_table(len(query_embedding))
+
+        response = self.document_collection.metric_ann_search(
+            vector=query_embedding,
+            n=k,
+            metric="cos",
+            **kwargs,
+        )
+
+        return [
+            VectorStoreSearchResult(
+                document=VectorStoreDocument(
+                    id=doc["row_id"],
+                    text=doc["body_blob"],
+                    vector=doc["vector"],
+                    attributes=doc["metadata"],
+                ),
+                score=doc["distance"],
+            )
+            for doc in response
+        ]
+
+    @override
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Embed the text and search by the resulting vector.
+
+        Returns an empty list when the embedder yields no embedding.
+        """
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(
+                query_embedding=query_embedding, k=k, **kwargs
+            )
+        return []
diff --git a/graphrag/vector_stores/typing.py b/graphrag/vector_stores/typing.py index 0b5a5cd195..adfe83103a 100644 --- a/graphrag/vector_stores/typing.py +++ b/graphrag/vector_stores/typing.py @@ -6,6 +6,7 @@ from enum import Enum from typing import ClassVar +from . 
import BaseVectorStore, CassandraVectorStore from .azure_ai_search import AzureAISearch from .lancedb import LanceDBVectorStore @@ -15,6 +16,7 @@ class VectorStoreType(str, Enum): LanceDB = "lancedb" AzureAISearch = "azure_ai_search" + Cassandra = "cassandra" class VectorStoreFactory: @@ -30,13 +32,15 @@ def register(cls, vector_store_type: str, vector_store: type): @classmethod def get_vector_store( cls, vector_store_type: VectorStoreType | str, kwargs: dict - ) -> LanceDBVectorStore | AzureAISearch: + ) -> BaseVectorStore: """Get the vector store type from a string.""" match vector_store_type: case VectorStoreType.LanceDB: return LanceDBVectorStore(**kwargs) case VectorStoreType.AzureAISearch: return AzureAISearch(**kwargs) + case VectorStoreType.Cassandra: + return CassandraVectorStore(**kwargs) case _: if vector_store_type in cls.vector_store_types: return cls.vector_store_types[vector_store_type](**kwargs) diff --git a/poetry.lock b/poetry.lock index 042d2c17ec..6244571778 100644 --- a/poetry.lock +++ b/poetry.lock @@ -391,6 +391,69 @@ files = [ {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, ] +[[package]] +name = "cassandra-driver" +version = "3.29.2" +description = "DataStax Driver for Apache Cassandra" +optional = false +python-versions = "*" +files = [ + {file = "cassandra-driver-3.29.2.tar.gz", hash = "sha256:c4310a7d0457f51a63fb019d8ef501588c491141362b53097fbc62fa06559b7c"}, + {file = "cassandra_driver-3.29.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:957208093ff2353230d0d83edf8c8e8582e4f2999d9a33292be6558fec943562"}, + {file = "cassandra_driver-3.29.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d70353b6d9d6e01e2b261efccfe90ce0aa6f416588e6e626ca2ed0aff6b540cf"}, + {file = "cassandra_driver-3.29.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06ad489e4df2cc7f41d3aca8bd8ddeb8071c4fb98240ed07f1dcd9b5180fd879"}, + {file = 
"cassandra_driver-3.29.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7f1dfa33c3d93350057d6dc163bb92748b6e6a164c408c75cf2c59be0a203b7"}, + {file = "cassandra_driver-3.29.2-cp310-cp310-win32.whl", hash = "sha256:f9df1e6ae4201eb2eae899cb0649d46b3eb0843f075199b51360bc9d59679a31"}, + {file = "cassandra_driver-3.29.2-cp310-cp310-win_amd64.whl", hash = "sha256:c4a005bc0b4fd8b5716ad931e1cc788dbd45967b0bcbdc3dfde33c7f9fde40d4"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e31cee01a6fc8cf7f32e443fa0031bdc75eed46126831b7a807ab167b4dc1316"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:52edc6d4bd7d07b10dc08b7f044dbc2ebe24ad7009c23a65e0916faed1a34065"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb3a9f24fc84324d426a69dc35df66de550833072a4d9a4d63d72fda8fcaecb9"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e89de04809d02bb1d5d03c0946a7baaaf85e93d7e6414885b4ea2616efe9de0"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-win32.whl", hash = "sha256:7104e5043e9cc98136d7fafe2418cbc448dacb4e1866fe38ff5be76f227437ef"}, + {file = "cassandra_driver-3.29.2-cp311-cp311-win_amd64.whl", hash = "sha256:69aa53f1bdb23487765faa92eef57366637878eafc412f46af999e722353b22f"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a1e994a82b2e6ab022c5aec24e03ad49fca5f3d47e566a145de34eb0e768473a"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2039201ae5d9b7c7ce0930af7138d2637ca16a4c7aaae2fbdd4355fbaf3003c5"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8067fad22e76e250c3846507d804f90b53e943bba442fa1b26583bcac692aaf1"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:ee0ebe8eb4fb007d8001ffcd1c3828b74defeb01075d8a1f1116ae9c60f75541"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-win32.whl", hash = "sha256:83dc9399cdabe482fd3095ca54ec227212d8c491b563a7276f6c100e30ee856c"}, + {file = "cassandra_driver-3.29.2-cp312-cp312-win_amd64.whl", hash = "sha256:6c74610f56a4c53863a5d44a2af9c6c3405da19d51966fabd85d7f927d5c6abc"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c86b0a796ff67d66de7df5f85243832a4dc853217f6a3eade84694f6f4fae151"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c53700b0d1f8c1d777eaa9e9fb6d17839d9a83f27a61649e0cbaa15d9d3df34b"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d348c769aa6c37919e7d6247e8cf09c23d387b7834a340408bd7d611f174d80"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8c496318e3c136cf12ab21e1598fee4b48ea1c71746ea8cc9d32e4dcd09cb93"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-win32.whl", hash = "sha256:d180183451bec81c15e0441fa37a63dc52c6489e860e832cadd854373b423141"}, + {file = "cassandra_driver-3.29.2-cp38-cp38-win_amd64.whl", hash = "sha256:a66b20c421d8fb21f18bd0ac713de6f09c5c25b6ab3d6043c3779b9c012d7c98"}, + {file = "cassandra_driver-3.29.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:70d4d0dce373943308ad461a425fc70a23d0f524859367b8c6fc292400f39954"}, + {file = "cassandra_driver-3.29.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b86427fab4d5a96e91ad82bb9338d4101ae4d3758ba96c356e0198da3de4d350"}, + {file = "cassandra_driver-3.29.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c25b42e1a99f377a933d79ae93ea27601e337a5abb7bb843a0e951cf1b3836f7"}, + {file = "cassandra_driver-3.29.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e36437288d6cd6f6c74b8ee5997692126e24adc2da3d031dc11c7dfea8bc220"}, + {file = 
"cassandra_driver-3.29.2-cp39-cp39-win32.whl", hash = "sha256:e967c1341a651f03bdc466f3835d72d3c0a0648b562035e6d780fa0b796c02f6"}, + {file = "cassandra_driver-3.29.2-cp39-cp39-win_amd64.whl", hash = "sha256:c5a9aab2367e8aad48ae853847a5a8985749ac5f102676de2c119b33fef13b42"}, +] + +[package.dependencies] +geomet = ">=0.1,<0.3" + +[package.extras] +cle = ["cryptography (>=35.0)"] +graph = ["gremlinpython (==3.4.6)"] + +[[package]] +name = "cassio" +version = "0.1.9" +description = "A framework-agnostic Python library to seamlessly integrate Apache Cassandra(R) with ML/LLM/genAI workloads." +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "cassio-0.1.9-py3-none-any.whl", hash = "sha256:0139d44d5bbd475df77806366c845465f6b08181c0e98ad9acec9f4047d6ab53"}, + {file = "cassio-0.1.9.tar.gz", hash = "sha256:5c3e5d15769396a98f0f260aead6a2c6e707ab1a13fe94f24341d5ef6bdddd6a"}, +] + +[package.dependencies] +cassandra-driver = ">=3.28.0,<4.0.0" +numpy = ">=1.0" +requests = ">=2.31.0,<3.0.0" + [[package]] name = "certifi" version = "2024.8.30" @@ -1196,6 +1259,21 @@ docs = ["POT", "Pyro4", "Pyro4 (>=4.27)", "annoy", "matplotlib", "memory-profile test = ["POT", "pytest", "pytest-cov", "testfixtures", "visdom (>=0.1.8,!=0.1.8.7)"] test-win = ["POT", "pytest", "pytest-cov", "testfixtures"] +[[package]] +name = "geomet" +version = "0.2.1.post1" +description = "GeoJSON <-> WKT/WKB conversion utilities" +optional = false +python-versions = ">2.6, !=3.3.*, <4" +files = [ + {file = "geomet-0.2.1.post1-py3-none-any.whl", hash = "sha256:a41a1e336b381416d6cbed7f1745c848e91defaa4d4c1bdc1312732e46ffad2b"}, + {file = "geomet-0.2.1.post1.tar.gz", hash = "sha256:91d754f7c298cbfcabd3befdb69c641c27fe75e808b27aa55028605761d17e95"}, +] + +[package.dependencies] +click = "*" +six = "*" + [[package]] name = "graspologic" version = "3.4.1" @@ -4809,4 +4887,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = 
"daefb001881f6f6fb97fac78381652dfdf99def5c0d7a04b7cf65de4ae3959a3" +content-hash = "e19aaaa99890a5336ec84db246280bcef611d32b8127f0ecf5873cb12feb0ef9" diff --git a/pyproject.toml b/pyproject.toml index ec32220d04..3323dcd043 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ datashaper = "^0.0.49" azure-search-documents = "^11.4.0" lancedb = "^0.13.0" + # Async IO aiolimiter = "^1.1.0" aiofiles = "^24.1.0" @@ -87,6 +88,7 @@ azure-identity = "^1.17.1" json-repair = "^0.28.4" future = "^1.0.0" # Needed until graspologic fixes their dependency +cassio = "^0.1.9" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0"