From 7bb3eee8171e46007f483e752159bfbef179babb Mon Sep 17 00:00:00 2001 From: Michael Ulianich Date: Mon, 6 Feb 2023 18:45:56 +0200 Subject: [PATCH] draft --- pymilvus/asyncio/client/grpc_handler.py | 72 +++++++++ pymilvus/asyncio/orm/collection.py | 196 ++++++++++++++++++++++++ pymilvus/orm/collection.py | 50 ++++-- tests/test_collection.py | 17 +- 4 files changed, 308 insertions(+), 27 deletions(-) create mode 100644 pymilvus/asyncio/orm/collection.py diff --git a/pymilvus/asyncio/client/grpc_handler.py b/pymilvus/asyncio/client/grpc_handler.py index 8336a6e04..6a52fbf2d 100644 --- a/pymilvus/asyncio/client/grpc_handler.py +++ b/pymilvus/asyncio/client/grpc_handler.py @@ -1,9 +1,19 @@ +import asyncio + import grpc.aio from ...client.grpc_handler import ( AbstractGrpcHandler, Status, MilvusException, + # retry_on_rpc_failure, + check_pass_param, + get_consistency_level, + ts_utils, + Prepare, + CollectionSchema, + DescribeCollectionException, + ChunkedQueryResult, ) @@ -21,3 +31,65 @@ async def _channel_ready(self): def _header_adder_interceptor(self, header, value): raise NotImplementedError # TODO + + # TODO: @retry_on_rpc_failure() + async def describe_collection(self, collection_name, timeout=None, **kwargs): + check_pass_param(collection_name=collection_name) + request = Prepare.describe_collection_request(collection_name) + response = await self._stub.DescribeCollection(request, timeout=timeout) + + status = response.status + if status.error_code != 0: + raise DescribeCollectionException(status.error_code, status.reason) + + return CollectionSchema(raw=response).dict() + + async def _execute_search_requests(self, requests, timeout=None, *, auto_id=True, round_decimal=-1, **kwargs): + async def _raise_milvus_exception_on_error_response(awaitable_response): + response = await awaitable_response + if response.status.error_code != 0: + raise MilvusException(response.status.error_code, response.status.reason) + return response + + raws: list = await 
asyncio.gather(*( + _raise_milvus_exception_on_error_response( + self._stub.Search(request, timeout=timeout) + ) + for request in requests + )) + return ChunkedQueryResult(raws, auto_id, round_decimal) + + # TODO: @retry_on_rpc_failure(retry_on_deadline=False) + async def search( + self, collection_name, data, anns_field, param, limit, + expression=None, partition_names=None, output_fields=None, + round_decimal=-1, timeout=None, schema=None, **kwargs, + ): + check_pass_param( + limit=limit, + round_decimal=round_decimal, + anns_field=anns_field, + search_data=data, + partition_name_array=partition_names, + output_fields=output_fields, + travel_timestamp=kwargs.get("travel_timestamp", 0), + guarantee_timestamp=kwargs.get("guarantee_timestamp", 0) + ) + + if schema is None: + schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs) + + consistency_level = schema["consistency_level"] + # overwrite the consistency level defined when user created the collection + consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level)) + + ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs) + + requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, schema, + expression, partition_names, output_fields, round_decimal, + **kwargs) + + auto_id = schema["auto_id"] + return await self._execute_search_requests( + requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs, + ) diff --git a/pymilvus/asyncio/orm/collection.py b/pymilvus/asyncio/orm/collection.py new file mode 100644 index 000000000..b3c32657a --- /dev/null +++ b/pymilvus/asyncio/orm/collection.py @@ -0,0 +1,196 @@ +import asyncio + +from ...orm.collection import ( + AbstractCollection, DataTypeNotMatchException, ExceptionsMessage, + SearchResult, CollectionSchema, + DEFAULT_CONSISTENCY_LEVEL, + cmp_consistency_level, + SchemaNotReadyException, + check_schema, + get_consistency_level, +) 
+from .connections import connections, Connections as AsyncConnections + + +class Collection(AbstractCollection[AsyncConnections]): + connections = connections + + def _init(self): + self._ready = asyncio.create_task(self._async_init()) + + @property + def _schema(self): + raise AssertionError("await self._ready before accessing self._schema") + + @property + def _schema_dict(self): + raise AssertionError("await self._ready before accessing self._schema_dict") + + async def _async_init(self): + schema = self._init_schema + kwargs = self._kwargs + conn = self._get_connection() + + has = await conn.has_collection(self._name, **kwargs) + if has: + resp = await conn.describe_collection(self._name, **kwargs) + consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + arg_consistency_level = kwargs.get("consistency_level", consistency_level) + if not cmp_consistency_level(consistency_level, arg_consistency_level): + raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent) + server_schema = CollectionSchema.construct_from_dict(resp) + if schema is None: + self._schema = server_schema + else: + if not isinstance(schema, CollectionSchema): + raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) + if server_schema != schema: + raise SchemaNotReadyException(message=ExceptionsMessage.SchemaInconsistent) + self._schema = schema + + else: + if schema is None: + raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name) + if isinstance(schema, CollectionSchema): + check_schema(schema) + consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) + await conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs) + self._schema = schema + else: + raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) + + self._schema_dict = self._schema.to_dict() + self._schema_dict["consistency_level"] = 
consistency_level + + async def search(self, data, anns_field, param, limit, expr=None, partition_names=None, + output_fields=None, timeout=None, round_decimal=-1, **kwargs): + """ Conducts a vector similarity search with an optional boolean expression as filter. + + Args: + data (``List[List[float]]``): The vectors of search data. + the length of data is number of query (nq), and the dim of every vector in data must be equal to + the vector field's of collection. + anns_field (``str``): The name of the vector field used to search of collection. + param (``dict[str, Any]``): + + The parameters of search. The followings are valid keys of param. + + * *nprobe*, *ef*, *search_k*, etc + Corresponding search params for a certain index. + + * *metric_type* (``str``) + similar metricy types, the value must be of type str. + + * *offset* (``int``, optional) + offset for pagination. + + * *limit* (``int``, optional) + limit for the search results and pagination. + + example for param:: + + { + "nprobe": 128, + "metric_type": "L2", + "offset": 10, + "limit": 10, + } + + limit (``int``): The max number of returned record, also known as `topk`. + expr (``str``): The boolean expression used to filter attribute. Default to None. + + example for expr:: + + "id_field >= 0", "id_field in [1, 2, 3, 4]" + + partition_names (``List[str]``, optional): The names of partitions to search on. Default to None. + output_fields (``List[str]``, optional): + The name of fields to return in the search result. Can only get scalar fields. + round_decimal (``int``, optional): The specified number of decimal places of returned distance. + Defaults to -1 means no round to returned distance. + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. + If timeout is set to None, the client keeps waiting until the server responds or an error occurs. 
+ **kwargs (``dict``): Optional search params + + * *consistency_level* (``str/int``, optional) + Which consistency level to use when searching in the collection. + + Options of consistency level: Strong, Bounded, Eventually, Session, Customized. + + Note: this parameter will overwrite the same parameter specified when user created the collection, + if no consistency level was specified, search will use the consistency level when you create the + collection. + + * *guarantee_timestamp* (``int``, optional) + Instructs Milvus to see all operations performed before this timestamp. + By default Milvus will search all operations performed to date. + + Note: only valid in Customized consistency level. + + * *graceful_time* (``int``, optional) + Search will use the (current_timestamp - the graceful_time) as the + `guarantee_timestamp`. Defaults to 5s. + + Note: only valid in Bounded consistency level. + + * *travel_timestamp* (``int``, optional) + A specific timestamp to get results based on a data view at. + + Returns: + SearchResult: + Returns ``SearchResult`` + + .. _Metric type documentations: + https://milvus.io/docs/v2.2.x/metric.md + .. _Index documentations: + https://milvus.io/docs/v2.2.x/index.md + .. _How guarantee ts works: + https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md + + Raises: + MilvusException: If anything goes wrong + + Examples: + >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> import random + >>> connections.connect() + >>> schema = CollectionSchema([ + ... FieldSchema("film_id", DataType.INT64, is_primary=True), + ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) + ... ]) + >>> collection = Collection("test_collection_search", schema) + >>> # insert + >>> data = [ + ... [i for i in range(10)], + ... [[random.random() for _ in range(2)] for _ in range(10)], + ...
] + >>> collection.insert(data) + >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + >>> collection.load() + >>> # search + >>> search_param = { + ... "data": [[1.0, 1.0]], + ... "anns_field": "films", + ... "param": {"metric_type": "L2", "offset": 1}, + ... "limit": 2, + ... "expr": "film_id > 0", + ... } + >>> res = collection.search(**search_param) + >>> assert len(res) == 1 + >>> hits = res[0] + >>> assert len(hits) == 2 + >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") + - Total hits: 2, hits ids: [8, 5] + >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") + - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 + """ + await self._ready + if expr is not None and not isinstance(expr, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + + conn = self._get_connection() + res = await conn.search( + self._name, data, anns_field, param, limit, expr, + partition_names, output_fields, round_decimal, timeout=timeout, + schema=self._schema_dict, **kwargs) + return SearchResult(res) diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 288a7347c..00a9a33cd 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -12,10 +12,11 @@ import copy import json +import typing from typing import List import pandas -from .connections import connections +from .connections import connections, AbstractConnections, Connections as SyncConnections from .schema import ( CollectionSchema, FieldSchema, @@ -46,8 +47,11 @@ from ..client.configs import DefaultConfigs +ConnectionsT = typing.TypeVar('ConnectionsT', bound=AbstractConnections) + +class AbstractCollection(typing.Generic[ConnectionsT]): + connections: ConnectionsT -class Collection: def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", shards_num: int=2, **kwargs): """ Constructs a collection by 
name, schema and other parameters. @@ -91,17 +95,38 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default self._using = using self._shards_num = shards_num self._kwargs = kwargs + self._init_schema = schema + self._init() + + def _init(self): + raise NotImplementedError + + def _get_connection(self): + return self.connections._fetch_handler(self._using) + + @property + def name(self) -> str: + """str: the name of the collection. """ + return self._name + + +class Collection(AbstractCollection[SyncConnections]): + connections = connections + + def _init(self): + assert not (hasattr(self, '_schema') or hasattr(self, '_schema_dict')), "_init() must be called only once" + schema = self._init_schema + kwargs = self._kwargs conn = self._get_connection() has = conn.has_collection(self._name, **kwargs) if has: resp = conn.describe_collection(self._name, **kwargs) - s_consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) - arg_consistency_level = kwargs.get("consistency_level", s_consistency_level) - if not cmp_consistency_level(s_consistency_level, arg_consistency_level): + consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + arg_consistency_level = kwargs.get("consistency_level", consistency_level) + if not cmp_consistency_level(consistency_level, arg_consistency_level): raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent) server_schema = CollectionSchema.construct_from_dict(resp) - self._consistency_level = s_consistency_level if schema is None: self._schema = server_schema else: @@ -113,18 +138,17 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default else: if schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name) + raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name) if isinstance(schema, CollectionSchema): check_schema(schema) 
consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs) self._schema = schema - self._consistency_level = consistency_level else: raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) self._schema_dict = self._schema.to_dict() - self._schema_dict["consistency_level"] = self._consistency_level + self._schema_dict["consistency_level"] = consistency_level def __repr__(self): _dict = { @@ -139,9 +163,6 @@ def __repr__(self): r.append(s.format(k, v)) return "".join(r) - def _get_connection(self): - return connections._fetch_handler(self._using) - @classmethod def construct_from_dataframe(cls, name, dataframe, **kwargs): if dataframe is None: @@ -212,11 +233,6 @@ def description(self) -> str: """str: a text description of the collection. """ return self._schema.description - @property - def name(self) -> str: - """str: the name of the collection. """ - return self._name - @property def is_empty(self) -> bool: """bool: whether the collection is empty or not.""" diff --git a/tests/test_collection.py b/tests/test_collection.py index d0fca7bcd..cb2be4d45 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -10,14 +10,13 @@ class TestCollections: - - # @pytest.fixture(scope="function",) - # def collection(self): - # name = gen_collection_name() - # schema = gen_schema() - # yield Collection(name, schema=schema) - # if connections.get_connection().has_collection(name): - # connections.get_connection().drop_collection(name) + @pytest.fixture(scope="function") + def collection(self): + name = gen_collection_name() + schema = gen_schema() + yield Collection(name, schema=schema) + if connections.get_connection().has_collection(name): + connections.get_connection().drop_collection(name) def test_collection_by_DataFrame(self): from pymilvus import Collection @@ -54,11 +53,9 @@ def test_collection_by_DataFrame(self): with 
mock.patch(f"{prefix}.close", return_value=None): connections.disconnect("default") - @pytest.mark.xfail def test_constructor(self, collection): assert type(collection) is Collection - @pytest.mark.xfail def test_construct_from_dataframe(self): assert type(Collection.construct_from_dataframe(gen_collection_name(), gen_pd_data(default_nb), primary_field="int64")[0]) is Collection