diff --git a/pymilvus/asyncio/__init__.py b/pymilvus/asyncio/__init__.py new file mode 100644 index 000000000..9b9ad2cb6 --- /dev/null +++ b/pymilvus/asyncio/__init__.py @@ -0,0 +1,2 @@ +from .orm.collection import Collection +from .orm.connections import connections, Connections diff --git a/pymilvus/asyncio/client/__init__.py b/pymilvus/asyncio/client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymilvus/asyncio/client/grpc_handler.py b/pymilvus/asyncio/client/grpc_handler.py new file mode 100644 index 000000000..1cdaf9c74 --- /dev/null +++ b/pymilvus/asyncio/client/grpc_handler.py @@ -0,0 +1,202 @@ +import asyncio +import copy + +import grpc.aio + +from ...client.grpc_handler import ( + AbstractGrpcHandler, + Status, + MilvusException, + # retry_on_rpc_failure, + check_pass_param, + get_consistency_level, + ts_utils, + Prepare, + CollectionSchema, + DescribeCollectionException, + ChunkedQueryResult, + common_pb2, + check_invalid_binary_vector, + ParamError, + milvus_types, + MutationResult, + DefaultConfigs, + DataType, + check_index_params, +) + + +class GrpcHandler(AbstractGrpcHandler[grpc.aio.Channel]): + _insecure_channel = staticmethod(grpc.aio.insecure_channel) + _secure_channel = staticmethod(grpc.aio.secure_channel) + + async def _channel_ready(self): + if self._channel is None: + raise MilvusException( + Status.CONNECT_FAILED, + 'No channel in handler, please setup grpc channel first', + ) + await self._channel.channel_ready() + + def _header_adder_interceptor(self, header, value): + raise NotImplementedError # TODO + + # TODO: @retry_on_rpc_failure() + async def create_collection(self, collection_name, fields, shards_num=2, timeout=None, **kwargs): + request = Prepare.create_collection_request(collection_name, fields, shards_num=shards_num, **kwargs) + + status = await self._stub.CreateCollection(request, timeout=timeout) + if status.error_code != 0: + raise MilvusException(status.error_code, status.reason) + + # TODO: @retry_on_rpc_failure() + async def has_collection(self, collection_name, timeout=None, **kwargs): + check_pass_param(collection_name=collection_name) + request = Prepare.describe_collection_request(collection_name) + reply = await self._stub.DescribeCollection(request, timeout=timeout) + + if reply.status.error_code == common_pb2.Success: + return True + + # TODO: Workaround for unreasonable describe collection results and error_code + if reply.status.error_code == common_pb2.UnexpectedError and "can\'t find collection" in reply.status.reason: + return False + + raise MilvusException(reply.status.error_code, reply.status.reason) + + # TODO: @retry_on_rpc_failure() + async def describe_collection(self, collection_name, timeout=None, **kwargs): + check_pass_param(collection_name=collection_name) + request = Prepare.describe_collection_request(collection_name) + response = await self._stub.DescribeCollection(request, timeout=timeout) + + status = response.status + if status.error_code != 0: + raise DescribeCollectionException(status.error_code, status.reason) + + return CollectionSchema(raw=response).dict() + + # TODO: @retry_on_rpc_failure() + async def batch_insert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + if not check_invalid_binary_vector(entities): + raise ParamError(message="Invalid binary vector data exists") + insert_param = kwargs.get('insert_param', None) + if insert_param and not isinstance(insert_param, milvus_types.RowBatch): + raise ParamError(message="The value of key 'insert_param' 
is invalid") + if not isinstance(entities, list): + raise ParamError(message="None entities, please provide valid entities.") + + collection_schema = kwargs.get("schema", None) + if not collection_schema: + collection_schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs) + + fields_info = collection_schema["fields"] + request = insert_param or Prepare.batch_insert_param(collection_name, entities, partition_name, fields_info) + response = await self._stub.Insert(request, timeout=timeout) + if response.status.error_code != 0: + raise MilvusException(response.status.error_code, response.status.reason) + m = MutationResult(response) + ts_utils.update_collection_ts(collection_name, m.timestamp) + return m + + async def _execute_search_requests(self, requests, timeout=None, *, auto_id=True, round_decimal=-1, **kwargs): + async def _raise_milvus_exception_on_error_response(awaitable_response): + response = await awaitable_response + if response.status.error_code != 0: + raise MilvusException(response.status.error_code, response.status.reason) + return response + + raws: list = await asyncio.gather(*( + _raise_milvus_exception_on_error_response( + self._stub.Search(request, timeout=timeout) + ) + for request in requests + )) + return ChunkedQueryResult(raws, auto_id, round_decimal) + + # TODO: @retry_on_rpc_failure(retry_on_deadline=False) + async def search( + self, collection_name, data, anns_field, param, limit, + expression=None, partition_names=None, output_fields=None, + round_decimal=-1, timeout=None, schema=None, **kwargs, + ): + check_pass_param( + limit=limit, + round_decimal=round_decimal, + anns_field=anns_field, + search_data=data, + partition_name_array=partition_names, + output_fields=output_fields, + travel_timestamp=kwargs.get("travel_timestamp", 0), + guarantee_timestamp=kwargs.get("guarantee_timestamp", 0) + ) + + if schema is None: + schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs) + + consistency_level = schema["consistency_level"] + # overwrite the consistency level defined when user created the collection + consistency_level = get_consistency_level(kwargs.get("consistency_level", consistency_level)) + + ts_utils.construct_guarantee_ts(consistency_level, collection_name, kwargs) + + requests = Prepare.search_requests_with_expr(collection_name, data, anns_field, param, limit, schema, + expression, partition_names, output_fields, round_decimal, + **kwargs) + + auto_id = schema["auto_id"] + return await self._execute_search_requests( + requests, timeout, round_decimal=round_decimal, auto_id=auto_id, **kwargs, + ) + + # TODO: @retry_on_rpc_failure() + async def create_index(self, collection_name, field_name, params, timeout=None, **kwargs): + # for historical reason, index_name contained in kwargs. + index_name = kwargs.pop("index_name", DefaultConfigs.IndexName) + copy_kwargs = copy.deepcopy(kwargs) + + collection_desc = await self.describe_collection(collection_name, timeout=timeout, **copy_kwargs) + + valid_field = False + for fields in collection_desc["fields"]: + if field_name != fields["name"]: + continue + valid_field = True + if fields["type"] != DataType.FLOAT_VECTOR and fields["type"] != DataType.BINARY_VECTOR: + break + # check index params on vector field. 
+ check_index_params(params) + if not valid_field: + raise MilvusException(message=f"cannot create index on non-existed field: {field_name}") + + index_param = Prepare.create_index_request(collection_name, field_name, params, index_name=index_name) + + status = await self._stub.CreateIndex(index_param, timeout=timeout) + if status.error_code != 0: + raise MilvusException(status.error_code, status.reason) + + return Status(status.error_code, status.reason) + + # TODO: @retry_on_rpc_failure() + async def load_collection(self, collection_name, replica_number=1, timeout=None, **kwargs): + check_pass_param(collection_name=collection_name, replica_number=replica_number) + _refresh = kwargs.get("_refresh", False) + _resource_groups = kwargs.get("_resource_groups") + request = Prepare.load_collection("", collection_name, replica_number, _refresh, _resource_groups) + response = await self._stub.LoadCollection(request, timeout=timeout) + if response.error_code != 0: + raise MilvusException(response.error_code, response.reason) + + # TODO: @retry_on_rpc_failure() + async def load_partitions(self, collection_name, partition_names, replica_number=1, timeout=None, **kwargs): + check_pass_param( + collection_name=collection_name, + partition_name_array=partition_names, + replica_number=replica_number) + _refresh = kwargs.get("_refresh", False) + _resource_groups = kwargs.get("_resource_groups") + request = Prepare.load_partitions("", collection_name, partition_names, replica_number, _refresh, + _resource_groups) + response = await self._stub.LoadPartitions(request, timeout=timeout) + if response.error_code != 0: + raise MilvusException(response.error_code, response.reason) diff --git a/pymilvus/asyncio/orm/__init__.py b/pymilvus/asyncio/orm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymilvus/asyncio/orm/collection.py b/pymilvus/asyncio/orm/collection.py new file mode 100644 index 000000000..d3613106b --- /dev/null +++ b/pymilvus/asyncio/orm/collection.py @@ -0,0 +1,382 @@ +import asyncio +import typing + +import pandas + +from ...orm.collection import ( + AbstractCollection, DataTypeNotMatchException, ExceptionsMessage, + SearchResult, CollectionSchema, + DEFAULT_CONSISTENCY_LEVEL, + cmp_consistency_level, + SchemaNotReadyException, + check_schema, + get_consistency_level, + MutationResult, + check_insert_data_schema, + Prepare, +) +from ..client.grpc_handler import GrpcHandler as AsyncGrpcHandler +from .connections import connections, Connections as AsyncConnections + + +class Collection(AbstractCollection[AsyncConnections]): + connections = connections + + def _init(self): + self._ready = asyncio.create_task(self._async_init()) + + # DEBUG + def __getattr__(self, attr): + if attr in ('_schema', '_schema_dict'): + raise AssertionError(f"await self._ready before accessing self.{attr}") + raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}") + + # DEBUG + def _get_connection(self): + ret = super()._get_connection() + assert isinstance(ret, AsyncGrpcHandler) + return ret + + async def _async_init(self): + schema = self._init_schema + kwargs = self._kwargs + conn = self._get_connection() + + has = await conn.has_collection(self._name, **kwargs) + if has: + resp = await conn.describe_collection(self._name, **kwargs) + consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + arg_consistency_level = kwargs.get("consistency_level", consistency_level) + if not cmp_consistency_level(consistency_level, arg_consistency_level): + raise 
SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent)
+            server_schema = CollectionSchema.construct_from_dict(resp)
+            if schema is None:
+                self._schema = server_schema
+            else:
+                if not isinstance(schema, CollectionSchema):
+                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
+                if server_schema != schema:
+                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaInconsistent)
+                self._schema = schema
+
+        else:
+            if schema is None:
+                raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name)
+            if isinstance(schema, CollectionSchema):
+                check_schema(schema)
+                consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
+                await conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs)
+                self._schema = schema
+            else:
+                raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
+
+        self._schema_dict = self._schema.to_dict()
+        self._schema_dict["consistency_level"] = consistency_level
+
+    async def load(self, partition_names=None, replica_number=1, timeout=None, **kwargs):
+        """ Load the data into memory.
+
+        Args:
+            partition_names (``List[str]``): The specified partitions to load.
+            replica_number (``int``, optional): The replica number to load, defaults to 1.
+            timeout (``float``, optional): An optional duration of time in seconds to allow for the RPCs.
+                If timeout is not set, the client keeps waiting until the server responds or an error occurs.
+            **kwargs (``dict``, optional):
+
+            * *_async* (``bool``)
+                Indicates whether to invoke asynchronously.
+
+            * *_refresh* (``bool``)
+                Whether to enable refresh mode (renew the segment list of this collection before loading).
+            * *_resource_groups* (``List[str]``)
+                Specify resource groups which can be used during loading.
+
+        Raises:
+            MilvusException: If anything goes wrong.
+
+        Examples:
+            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
+            >>> connections.connect()
+            >>> schema = CollectionSchema([
+            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
+            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
+            ... ])
+            >>> collection = Collection("test_collection_load", schema)
+            >>> collection.insert([[1, 2], [[1.0, 2.0], [3.0, 4.0]]])
+            >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
+            >>> collection.load()
+        """
+        conn = self._get_connection()
+        if partition_names is not None:
+            await conn.load_partitions(
+                self._name, partition_names, replica_number=replica_number, timeout=timeout, **kwargs,
+            )
+        else:
+            await conn.load_collection(
+                self._name, replica_number=replica_number, timeout=timeout, **kwargs,
+            )
+
+
+    async def insert(
+        self,
+        data: typing.Union[typing.List, pandas.DataFrame],
+        partition_name: str = None, timeout=None, **kwargs
+    ) -> MutationResult:
+        """ Insert data into the collection.
+
+        Args:
+            data (``list/tuple/pandas.DataFrame``): The specified data to insert.
+            partition_name (``str``): The partition name which the data will be inserted into;
+                if the partition name is not passed, the data will be inserted into the "_default" partition.
+            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
+                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
+        Returns:
+            MutationResult: contains 2 properties, `insert_count` and `primary_keys`.
+                `insert_count`: how many entities have been inserted into Milvus,
+                `primary_keys`: list of primary keys of the inserted entities
+        Raises:
+            MilvusException: If anything goes wrong.
+
+        Examples:
+            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
+            >>> import random
+            >>> connections.connect()
+            >>> schema = CollectionSchema([
+            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
+            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
+            ... ])
+            >>> collection = Collection("test_collection_insert", schema)
+            >>> data = [
+            ...     [random.randint(1, 100) for _ in range(10)],
+            ...     [[random.random() for _ in range(2)] for _ in range(10)],
+            ... ]
+            >>> res = collection.insert(data)
+            >>> res.insert_count
+            10
+        """
+        await self._ready
+        if data is None:
+            return MutationResult(data)
+        check_insert_data_schema(self._schema, data)
+        entities = Prepare.prepare_insert_data(data, self._schema)
+
+        conn = self._get_connection()
+
+        res = await conn.batch_insert(
+            self._name, entities, partition_name, timeout=timeout, schema=self._schema_dict, **kwargs,
+        )
+
+        return MutationResult(res)
+
+    async def delete(self, expr, partition_name=None, timeout=None, **kwargs):
+        """ Delete entities with an expression condition.
+
+        Args:
+            expr (``str``): The boolean expression that specifies which entities to delete.
+            partition_name (``str``, optional): Name of the partition to delete entities from.
+            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
+                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
+
+        Returns:
+            MutationResult: contains a `delete_count` property representing how many entities will be deleted.
+
+        Raises:
+            MilvusException: If anything goes wrong.
+
+        Examples:
+            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
+            >>> import random
+            >>> connections.connect()
+            >>> schema = CollectionSchema([
+            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
+            ...     FieldSchema("film_date", DataType.INT64),
+            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2),
+            ... ])
+            >>> collection = Collection("test_collection_delete", schema)
+            >>> # insert
+            >>> data = [
+            ...     [i for i in range(10)],
+            ...     [i + 2000 for i in range(10)],
+            ...     [[random.random() for _ in range(2)] for _ in range(10)],
+            ... ]
+            >>> collection.insert(data)
+            >>> res = collection.delete("film_id in [ 0, 1 ]")
+            >>> print(f"- Deleted entities: {res}")
+            - Deleted entities: [0, 1]
+        """
+
+        conn = self._get_connection()
+        res = await conn.delete(self._name, expr, partition_name, timeout=timeout, **kwargs)
+        return MutationResult(res)
+
+    async def search(self, data, anns_field, param, limit, expr=None, partition_names=None,
+                     output_fields=None, timeout=None, round_decimal=-1, **kwargs):
+        """ Conducts a vector similarity search with an optional boolean expression as filter.
+
+        Args:
+            data (``List[List[float]]``): The vectors of search data.
+                The length of data is the number of queries (nq), and the dim of every vector in data
+                must be equal to the dim of the collection's vector field.
+            anns_field (``str``): The name of the collection's vector field to search on.
+            param (``dict[str, Any]``):
+
+                The parameters of search. The following are valid keys of param.
+
+                * *nprobe*, *ef*, *search_k*, etc
+                    Corresponding search params for a certain index.
+
+                * *metric_type* (``str``)
+                    The metric type to use for the search; the value must be of type str.
+
+                * *offset* (``int``, optional)
+                    offset for pagination.
+
+                * *limit* (``int``, optional)
+                    limit for the search results and pagination.
+
+                example for param::
+
+                    {
+                        "nprobe": 128,
+                        "metric_type": "L2",
+                        "offset": 10,
+                        "limit": 10,
+                    }
+
+            limit (``int``): The max number of returned records, also known as `topk`.
+            expr (``str``): The boolean expression used to filter attributes. Defaults to None.
+
+                example for expr::
+
+                    "id_field >= 0", "id_field in [1, 2, 3, 4]"
+
+            partition_names (``List[str]``, optional): The names of partitions to search on. Defaults to None.
+            output_fields (``List[str]``, optional):
+                The names of fields to return in the search result. Only scalar fields can be returned.
+            round_decimal (``int``, optional): The specified number of decimal places of returned distance.
+                Defaults to -1, meaning the returned distance is not rounded.
+            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
+                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
+            **kwargs (``dict``): Optional search params
+
+                * *consistency_level* (``str/int``, optional)
+                    Which consistency level to use when searching in the collection.
+
+                    Options of consistency level: Strong, Bounded, Eventually, Session, Customized.
+
+                    Note: this parameter overwrites the consistency level specified when the collection
+                    was created; if no consistency level is specified here, the search uses the
+                    consistency level set at collection creation.
+
+                * *guarantee_timestamp* (``int``, optional)
+                    Instructs Milvus to see all operations performed before this timestamp.
+                    By default Milvus will search all operations performed to date.
+
+                    Note: only valid in Customized consistency level.
+
+                * *graceful_time* (``int``, optional)
+                    Search will use (current_timestamp - graceful_time) as the
+                    `guarantee_timestamp`. Defaults to 5s.
+
+                    Note: only valid in Bounded consistency level.
+
+                * *travel_timestamp* (``int``, optional)
+                    A specific timestamp at which to take the data view for the search.
+
+        Returns:
+            SearchResult:
+                Returns ``SearchResult``
+
+        .. _Metric type documentations:
+            https://milvus.io/docs/v2.2.x/metric.md
+        .. _Index documentations:
+            https://milvus.io/docs/v2.2.x/index.md
+        .. _How guarantee ts works:
+            https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md
+
+        Raises:
+            MilvusException: If anything goes wrong.
+
+        Examples:
+            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
+            >>> import random
+            >>> connections.connect()
+            >>> schema = CollectionSchema([
+            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
+            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
+            ... ])
+            >>> collection = Collection("test_collection_search", schema)
+            >>> # insert
+            >>> data = [
+            ...     [i for i in range(10)],
+            ...     [[random.random() for _ in range(2)] for _ in range(10)],
+            ... ]
+            >>> collection.insert(data)
+            >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
+            >>> collection.load()
+            >>> # search
+            >>> search_param = {
+            ...     "data": [[1.0, 1.0]],
+            ...     "anns_field": "films",
+            ...     "param": {"metric_type": "L2", "offset": 1},
+            ...     "limit": 2,
+            ...     "expr": "film_id > 0",
+            ...
} + >>> res = collection.search(**search_param) + >>> assert len(res) == 1 + >>> hits = res[0] + >>> assert len(hits) == 2 + >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") + - Total hits: 2, hits ids: [8, 5] + >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") + - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 + """ + await self._ready + if expr is not None and not isinstance(expr, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) + + conn = self._get_connection() + res = await conn.search( + self._name, data, anns_field, param, limit, expr, + partition_names, output_fields, round_decimal, timeout=timeout, + schema=self._schema_dict, **kwargs) + return SearchResult(res) + + async def create_index(self, field_name, index_params={}, timeout=None, **kwargs): + """Creates index for a specified field, with a index name. + + Args: + field_name (``str``): The name of the field to create index + index_params (``dict``): The parameters to index + * *index_type* (``str``) + "index_type" as the key, example values: "FLAT", "IVF_FLAT", etc. + + * *metric_type* (``str``) + "metric_type" as the key, examples values: "L2", "IP", "JACCARD". + + * *params* (``dict``) + "params" as the key, corresponding index params. + + timeout (``float``, optional): An optional duration of time in seconds to allow for the RPC. When timeout + is set to None, client waits until server response or error occur. + index_name (``str``): The name of index which will be created, must be unique. + If no index name is specified, the default index name will be used. + + Raises: + MilvusException: If anything goes wrong. + + Examples: + >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> connections.connect() + >>> schema = CollectionSchema([ + ... FieldSchema("film_id", DataType.INT64, is_primary=True), + ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) + ... 
]) + >>> collection = Collection("test_collection_create_index", schema) + >>> index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + >>> collection.create_index("films", index_params, index_name="idx") + Status(code=0, message='') + """ + conn = self._get_connection() + return await conn.create_index(self._name, field_name, index_params, timeout=timeout, **kwargs) + diff --git a/pymilvus/asyncio/orm/connections.py b/pymilvus/asyncio/orm/connections.py new file mode 100644 index 000000000..8bcfc7e28 --- /dev/null +++ b/pymilvus/asyncio/orm/connections.py @@ -0,0 +1,27 @@ +import copy +import typing + +from ...orm.connections import AbstractConnections +from ..client.grpc_handler import GrpcHandler as AsyncGrpcHandler + + +# pylint: disable=W0236 +class Connections(AbstractConnections[AsyncGrpcHandler, typing.Awaitable[None]]): + async def _disconnect(self, alias: str, *, remove_connection: bool): + if alias in self._connected_alias: + await self._connected_alias.pop(alias).close() + if remove_connection: + self._alias.pop(alias, None) + + async def _connect(self, alias, **kwargs): + gh = AsyncGrpcHandler(**kwargs) + + await gh._channel_ready() + kwargs.pop('password') + kwargs.pop('secure', None) + + self._connected_alias[alias] = gh + self._alias[alias] = copy.deepcopy(kwargs) + + +connections = Connections() diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 031d9b238..7cd87fa64 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -2,6 +2,7 @@ import json import copy import base64 +import typing from urllib import parse import grpc @@ -66,7 +67,14 @@ from ..decorators import retry_on_rpc_failure -class GrpcHandler: +GrpcChannelT = typing.TypeVar('GrpcChannelT', grpc.Channel, grpc.aio.Channel) + + +class AbstractGrpcHandler(typing.Generic[GrpcChannelT]): + _insecure_channel: typing.Callable[..., GrpcChannelT] + _secure_channel: typing.Callable[..., GrpcChannelT] + _channel: typing.Optional[GrpcChannelT] + def __init__(self, uri=config.GRPC_URI, host="", port="", channel=None, **kwargs): self._stub = None self._channel = channel @@ -108,25 +116,17 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def _wait_for_channel_ready(self, timeout=10): - if self._channel is not None: - try: - grpc.channel_ready_future(self._channel).result(timeout=timeout) - return - except grpc.FutureTimeoutError as e: - raise MilvusException(Status.CONNECT_FAILED, - f'Fail connecting to server on {self._address}. 
Timeout') from e - - raise MilvusException(Status.CONNECT_FAILED, 'No channel in handler, please setup grpc channel first') - def close(self): - self._channel.close() + return self._channel.close() + + def _header_adder_interceptor(self, header, value): + raise NotImplementedError("this is abstract method") def _setup_authorization_interceptor(self, user, password): if user and password: authorization = base64.b64encode(f"{user}:{password}".encode('utf-8')) key = "authorization" - self._authorization_interceptor = interceptor.header_adder_interceptor(key, authorization) + self._authorization_interceptor = self._header_adder_interceptor(key, authorization) def _setup_grpc_channel(self): """ Create a ddl grpc channel """ @@ -137,7 +137,7 @@ def _setup_grpc_channel(self): ('grpc.keepalive_time_ms', 55000), ] if not self._secure: - self._channel = grpc.insecure_channel( + self._channel = self._insecure_channel( self._address, options=opts, ) @@ -160,7 +160,7 @@ def _setup_grpc_channel(self): else: creds = grpc.ssl_channel_credentials(root_certificates=None, private_key=None, certificate_chain=None) - self._channel = grpc.secure_channel( + self._channel = self._secure_channel( self._address, creds, options=opts @@ -170,11 +170,11 @@ def _setup_grpc_channel(self): if self._authorization_interceptor: self._final_channel = grpc.intercept_channel(self._final_channel, self._authorization_interceptor) if self._log_level: - log_level_interceptor = interceptor.header_adder_interceptor("log_level", self._log_level) + log_level_interceptor = self._header_adder_interceptor("log_level", self._log_level) self._final_channel = grpc.intercept_channel(self._final_channel, log_level_interceptor) self._log_level = None if self._request_id: - request_id_interceptor = interceptor.header_adder_interceptor("client_request_id", self._request_id) + request_id_interceptor = self._header_adder_interceptor("client_request_id", self._request_id) self._final_channel = grpc.intercept_channel(self._final_channel, request_id_interceptor) self._request_id = None self._stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel) @@ -192,6 +192,31 @@ def server_address(self): """ Server network address """ return self._address + +class GrpcHandler(AbstractGrpcHandler[grpc.Channel]): + _insecure_channel = staticmethod(grpc.insecure_channel) + _secure_channel = staticmethod(grpc.secure_channel) + + def _wait_for_channel_ready(self, timeout=10): + if self._channel is None: + raise MilvusException( + Status.CONNECT_FAILED, + 'No channel in handler, please setup grpc channel first', + ) + + try: + grpc.channel_ready_future(self._channel).result(timeout=timeout) + except grpc.FutureTimeoutError as exc: + raise MilvusException( + Status.CONNECT_FAILED, + f'Fail connecting to server on {self._address}. Timeout' + ) from exc + + def _header_adder_interceptor(self, header, value): + return interceptor.header_adder_interceptor(header, value) + + #### TODO: implement methods below in asyncio.client.grpc_handler + def reset_password(self, user, old_password, new_password, timeout=None): """ reset password and then setup the grpc channel. 
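Note: the async GrpcHandler above leaves `_header_adder_interceptor` as a TODO, so the authorization, log-level and request-id headers are not wired up yet for the asyncio path. A minimal sketch of one possible implementation follows (not part of this diff). It assumes grpcio's concrete `ClientCallDetails` is the NamedTuple current releases use, and that the interceptor is passed to `grpc.aio.insecure_channel`/`grpc.aio.secure_channel` via their `interceptors` argument at channel-creation time, since `grpc.intercept_channel` in the shared `_setup_grpc_channel` only works with synchronous channels:

import grpc.aio


class _AioHeaderAdderInterceptor(grpc.aio.UnaryUnaryClientInterceptor):
    """Sketch: append one constant (header, value) metadata pair to every unary-unary RPC."""

    def __init__(self, header, value):
        self._header = header
        self._value = value

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        # grpc.aio.Metadata is a mutable sequence of (key, value) pairs
        metadata = client_call_details.metadata or grpc.aio.Metadata()
        metadata.add(self._header, self._value)
        # the concrete ClientCallDetails passed in is a NamedTuple, so _replace is available
        new_details = client_call_details._replace(metadata=metadata)
        return await continuation(new_details, request)

All Milvus RPCs used by this handler are unary-unary; a complete implementation would also provide the streaming interceptor variants and rework `_setup_grpc_channel` so the async subclass does not go through `grpc.intercept_channel`.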
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 288a7347c..00a9a33cd 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -12,10 +12,11 @@ import copy import json +import typing from typing import List import pandas -from .connections import connections +from .connections import connections, AbstractConnections, Connections as SyncConnections from .schema import ( CollectionSchema, FieldSchema, @@ -46,8 +47,11 @@ from ..client.configs import DefaultConfigs +ConnectionsT = typing.TypeVar('ConnectionsT', bound=AbstractConnections) + +class AbstractCollection(typing.Generic[ConnectionsT]): + connections: ConnectionsT -class Collection: def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", shards_num: int=2, **kwargs): """ Constructs a collection by name, schema and other parameters. @@ -91,17 +95,38 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default self._using = using self._shards_num = shards_num self._kwargs = kwargs + self._init_schema = schema + self._init() + + def _init(self): + raise NotImplementedError + + def _get_connection(self): + return self.connections._fetch_handler(self._using) + + @property + def name(self) -> str: + """str: the name of the collection. """ + return self._name + + +class Collection(AbstractCollection[SyncConnections]): + connections = connections + + def _init(self): + assert not (hasattr(self, '_schema') or hasattr(self, '_schema_dict')), "_init() must be called only once" + schema = self._init_schema + kwargs = self._kwargs conn = self._get_connection() has = conn.has_collection(self._name, **kwargs) if has: resp = conn.describe_collection(self._name, **kwargs) - s_consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) - arg_consistency_level = kwargs.get("consistency_level", s_consistency_level) - if not cmp_consistency_level(s_consistency_level, arg_consistency_level): + consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL) + arg_consistency_level = kwargs.get("consistency_level", consistency_level) + if not cmp_consistency_level(consistency_level, arg_consistency_level): raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent) server_schema = CollectionSchema.construct_from_dict(resp) - self._consistency_level = s_consistency_level if schema is None: self._schema = server_schema else: @@ -113,18 +138,17 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default else: if schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name) + raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name) if isinstance(schema, CollectionSchema): check_schema(schema) consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs) self._schema = schema - self._consistency_level = consistency_level else: raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType) self._schema_dict = self._schema.to_dict() - self._schema_dict["consistency_level"] = self._consistency_level + self._schema_dict["consistency_level"] = consistency_level def __repr__(self): _dict = { @@ -139,9 +163,6 @@ def __repr__(self): r.append(s.format(k, v)) return "".join(r) - def _get_connection(self): - return connections._fetch_handler(self._using) - @classmethod def 
construct_from_dataframe(cls, name, dataframe, **kwargs): if dataframe is None: @@ -212,11 +233,6 @@ def description(self) -> str: """str: a text description of the collection. """ return self._schema.description - @property - def name(self) -> str: - """str: the name of the collection. """ - return self._name - @property def is_empty(self) -> bool: """bool: whether the collection is empty or not.""" diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 7c178ff65..d9b3ade9c 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -14,11 +14,12 @@ import copy import re import threading +import typing from urllib import parse from typing import Tuple from ..client.check import is_legal_host, is_legal_port, is_legal_address -from ..client.grpc_handler import GrpcHandler +from ..client.grpc_handler import GrpcHandler, AbstractGrpcHandler from .default_config import DefaultConfig, ENV_CONNECTION_CONF from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException @@ -55,7 +56,11 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, **kwargs) -class Connections(metaclass=SingleInstanceMetaClass): +NoneT = typing.TypeVar('NoneT', None, typing.Awaitable[None]) +GrpcHandlerT = typing.TypeVar('GrpcHandlerT', bound=AbstractGrpcHandler) + + +class AbstractConnections(typing.Generic[GrpcHandlerT, NoneT], metaclass=SingleInstanceMetaClass): """ Class for managing all connections of milvus. Used as a singleton in this module. """ def __init__(self): @@ -66,7 +71,7 @@ def __init__(self): """ self._alias = {} - self._connected_alias = {} + self._connected_alias: typing.Dict[str, GrpcHandlerT] = {} self.add_connection(default=self._read_default_config_from_os_env()) @@ -190,6 +195,9 @@ def __generate_address(self, uri: str, host: str, port: str) -> str: return f"{host}:{port}" + def _disconnect(self, alias: str, *, remove_connection: bool) -> NoneT: + raise NotImplementedError + def disconnect(self, alias: str): """ Disconnects connection from the registry. @@ -199,8 +207,7 @@ def disconnect(self, alias: str): if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - if alias in self._connected_alias: - self._connected_alias.pop(alias).close() + return self._disconnect(alias, remove_connection=False) def remove_connection(self, alias: str): """ Removes connection from the registry. 
@@ -211,8 +218,10 @@ def remove_connection(self, alias: str): if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - self.disconnect(alias) - self._alias.pop(alias, None) + return self._disconnect(alias, remove_connection=True) + + def _connect(self, alias, **kwargs) -> NoneT: + raise NotImplementedError def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwargs): """ @@ -265,19 +274,6 @@ def connect(self, alias=DefaultConfig.DEFAULT_USING, user="", password="", **kwa if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - def connect_milvus(**kwargs): - gh = GrpcHandler(**kwargs) - - t = kwargs.get("timeout") - timeout = t if isinstance(t, int) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT - - gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') - kwargs.pop('secure', None) - - self._connected_alias[alias] = gh - self._alias[alias] = copy.deepcopy(kwargs) - def with_config(config: Tuple) -> bool: for c in config: if c != "": @@ -300,15 +296,14 @@ def with_config(config: Tuple) -> bool: if self._alias[alias].get("address") != in_addr: raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) - connect_milvus(**kwargs, user=user, password=password) - else: if alias not in self._alias: raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) - connect_alias = dict(self._alias[alias].items()) - connect_alias["user"] = user - connect_milvus(**connect_alias, password=password, **kwargs) + kwargs = dict(**self._alias[alias], **kwargs) + + kwargs["user"] = user + return self._connect(alias, **kwargs, password=password) def list_connections(self) -> list: """ List names of all connections. @@ -369,7 +364,7 @@ def has_connection(self, alias: str) -> bool: raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) return alias in self._connected_alias - def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandler: + def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandlerT: """ Retrieves a GrpcHandler by alias. """ if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) @@ -381,5 +376,26 @@ def _fetch_handler(self, alias=DefaultConfig.DEFAULT_USING) -> GrpcHandler: return conn +class Connections(AbstractConnections[GrpcHandler, None]): + def _disconnect(self, alias: str, *, remove_connection: bool): + if alias in self._connected_alias: + self._connected_alias.pop(alias).close() + if remove_connection: + self._alias.pop(alias, None) + + def _connect(self, alias, **kwargs): + gh = GrpcHandler(**kwargs) + + t = kwargs.get("timeout") + timeout = t if isinstance(t, int) else DefaultConfig.DEFAULT_CONNECT_TIMEOUT + + gh._wait_for_channel_ready(timeout=timeout) + kwargs.pop('password') + kwargs.pop('secure', None) + + self._connected_alias[alias] = gh + self._alias[alias] = copy.deepcopy(kwargs) + + # Singleton Mode in Python connections = Connections() diff --git a/tests/asyncio/test_async_collection.py b/tests/asyncio/test_async_collection.py new file mode 100644 index 000000000..b609dceec --- /dev/null +++ b/tests/asyncio/test_async_collection.py @@ -0,0 +1,43 @@ +import random +import unittest + +from pymilvus import FieldSchema, CollectionSchema, DataType +from pymilvus.asyncio import connections, Collection + + +# this test case requires a running milvus instance. 
E.g.: +# export IMAGE_REPO=milvusdb +# export IMAGE_TAG=2.1.0-latest +# docker-compose --file ci/docker/milvus/docker-compose.yml up +class TestAsyncCollections(unittest.IsolatedAsyncioTestCase): + async def test_collection_search(self): + await connections.connect() + schema = CollectionSchema([ + FieldSchema("film_id", DataType.INT64, is_primary=True), + FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) + ]) + collection = Collection("test_collection_search", schema) + # insert + data = [ + list(range(10)), + [[random.random() for _ in range(2)] for _ in range(10)], + ] + await collection.insert(data) + await collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}}) + await collection.load() + # search + search_param = { + "data": [[1.0, 1.0]], + "anns_field": "films", + "param": {"metric_type": "L2", "offset": 1}, + "limit": 2, + "expr": "film_id > 0", + } + res = await collection.search(**search_param) + assert len(res) == 1 + hits = res[0] + assert len(hits) == 2 + print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ") + # - Total hits: 2, hits ids: [8, 5] + print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ") + # - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385
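
For reference, the same flow can be driven outside of unittest. A minimal standalone sketch (not part of this diff), assuming a reachable Milvus instance as above; note that `Collection.__init__` schedules `_async_init` via `asyncio.create_task`, so the collection must be constructed inside a running event loop:

import asyncio
import random

from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.asyncio import Collection, connections


async def main():
    # connect() returns an awaitable in the asyncio implementation
    await connections.connect()
    schema = CollectionSchema([
        FieldSchema("film_id", DataType.INT64, is_primary=True),
        FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2),
    ])
    # must run inside the event loop: __init__ calls asyncio.create_task()
    collection = Collection("test_collection_search", schema)
    await collection.insert([
        list(range(10)),
        [[random.random() for _ in range(2)] for _ in range(10)],
    ])
    await collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
    await collection.load()
    res = await collection.search(
        data=[[1.0, 1.0]], anns_field="films",
        param={"metric_type": "L2"}, limit=2, expr="film_id > 0",
    )
    print(f"- Total hits: {len(res[0])}, ids: {res[0].ids}")
    await connections.disconnect("default")


asyncio.run(main())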