Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
belkka committed Feb 6, 2023
1 parent 46167a1 commit 7bb3eee
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 27 deletions.
72 changes: 72 additions & 0 deletions pymilvus/asyncio/client/grpc_handler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import asyncio

import grpc.aio

from ...client.grpc_handler import (
AbstractGrpcHandler,
Status,
MilvusException,
# retry_on_rpc_failure,
check_pass_param,
get_consistency_level,
ts_utils,
Prepare,
CollectionSchema,
DescribeCollectionException,
ChunkedQueryResult,
)


Expand All @@ -21,3 +31,65 @@ async def _channel_ready(self):

# Placeholder: building a header-adding interceptor is not implemented in this
# asyncio handler yet, so any caller reaching this method gets NotImplementedError.
def _header_adder_interceptor(self, header, value):
raise NotImplementedError # TODO

# TODO: @retry_on_rpc_failure()
async def describe_collection(self, collection_name, timeout=None, **kwargs):
    """Fetch the description of a collection from the server.

    Args:
        collection_name: Name of the collection to describe. It is validated
            with ``check_pass_param`` before the request is built.
        timeout: Forwarded as the ``timeout`` of the ``DescribeCollection`` RPC.
        **kwargs: Accepted for signature compatibility; not used by this method.

    Returns:
        The value of ``CollectionSchema(raw=response).dict()`` for the
        server response.

    Raises:
        DescribeCollectionException: If the response status carries a
            non-zero error code.
    """
    check_pass_param(collection_name=collection_name)

    rpc_request = Prepare.describe_collection_request(collection_name)
    reply = await self._stub.DescribeCollection(rpc_request, timeout=timeout)

    rpc_status = reply.status
    if rpc_status.error_code == 0:
        return CollectionSchema(raw=reply).dict()

    # Any non-zero error code is surfaced to the caller as an exception.
    raise DescribeCollectionException(rpc_status.error_code, rpc_status.reason)

async def _execute_search_requests(self, requests, timeout=None, *, auto_id=True, round_decimal=-1, **kwargs):
    """Send every search request concurrently and wrap the collected responses.

    Args:
        requests: Iterable of prepared search requests; one ``Search`` RPC is
            issued per element.
        timeout: Forwarded as the ``timeout`` of each ``Search`` RPC.
        auto_id: Passed through to ``ChunkedQueryResult``.
        round_decimal: Passed through to ``ChunkedQueryResult``.
        **kwargs: Accepted for signature compatibility; not used by this method.

    Returns:
        ``ChunkedQueryResult(raws, auto_id, round_decimal)`` where ``raws`` is
        the list of responses in the same order as ``requests``.

    Raises:
        MilvusException: If any response has a non-zero status error code.
    """

    async def _checked(pending_call):
        # Await one RPC and turn an error status into an exception.
        reply = await pending_call
        if reply.status.error_code != 0:
            raise MilvusException(reply.status.error_code, reply.status.reason)
        return reply

    # Start all RPCs first, then wait for them together so they overlap.
    pending = [
        _checked(self._stub.Search(single_request, timeout=timeout))
        for single_request in requests
    ]
    raws: list = await asyncio.gather(*pending)
    return ChunkedQueryResult(raws, auto_id, round_decimal)

# TODO: @retry_on_rpc_failure(retry_on_deadline=False)
async def search(
    self, collection_name, data, anns_field, param, limit,
    expression=None, partition_names=None, output_fields=None,
    round_decimal=-1, timeout=None, schema=None, **kwargs,
):
    """Validate the arguments, build the search requests and execute them.

    Args:
        collection_name: Name of the collection to search.
        data: Search data, validated as ``search_data``.
        anns_field: Name of the field to search on.
        param: Search parameters handed to ``Prepare.search_requests_with_expr``.
        limit: Maximum number of results, validated by ``check_pass_param``.
        expression: Optional filter expression.
        partition_names: Optional partition names to restrict the search to.
        output_fields: Optional names of fields to return.
        round_decimal: Rounding setting, forwarded to request preparation and
            to the result wrapper.
        timeout: Timeout forwarded to the RPCs issued by this call.
        schema: Collection description dict. When ``None`` it is fetched with
            ``describe_collection``. Must provide the ``"consistency_level"``
            and ``"auto_id"`` keys.
        **kwargs: Extra options. ``travel_timestamp`` and
            ``guarantee_timestamp`` are validated (defaulting to 0), and
            ``consistency_level`` overrides the level stored in ``schema``.

    Returns:
        The result of ``_execute_search_requests`` for the prepared requests.
    """
    check_pass_param(
        limit=limit,
        round_decimal=round_decimal,
        anns_field=anns_field,
        search_data=data,
        partition_name_array=partition_names,
        output_fields=output_fields,
        travel_timestamp=kwargs.get("travel_timestamp", 0),
        guarantee_timestamp=kwargs.get("guarantee_timestamp", 0)
    )

    if schema is None:
        schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)

    # A per-call "consistency_level" overrides the one defined when the
    # collection was created.
    requested_level = kwargs.get("consistency_level", schema["consistency_level"])
    effective_level = get_consistency_level(requested_level)

    # kwargs is handed over as a dict object, not unpacked.
    # NOTE(review): presumably the guarantee timestamp is written into it - confirm.
    ts_utils.construct_guarantee_ts(effective_level, collection_name, kwargs)

    search_requests = Prepare.search_requests_with_expr(
        collection_name, data, anns_field, param, limit, schema,
        expression, partition_names, output_fields, round_decimal,
        **kwargs,
    )

    return await self._execute_search_requests(
        search_requests, timeout,
        round_decimal=round_decimal, auto_id=schema["auto_id"], **kwargs,
    )
196 changes: 196 additions & 0 deletions pymilvus/asyncio/orm/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import asyncio

from ...orm.collection import (
AbstractCollection, DataTypeNotMatchException, ExceptionsMessage,
SearchResult, CollectionSchema,
DEFAULT_CONSISTENCY_LEVEL,
cmp_consistency_level,
SchemaNotReadyException,
check_schema,
get_consistency_level,
)
from .connections import connections, Connections as AsyncConnections


class Collection(AbstractCollection[AsyncConnections]):
    """Asyncio variant of the ORM collection.

    Construction schedules ``_async_init`` as a task stored in ``self._ready``;
    methods that need the schema await that task first. Because ``_init`` calls
    ``asyncio.create_task``, an instance has to be created while an event loop
    is running.
    """

    connections = connections

    def _init(self):
        # Initialisation needs awaits, so it runs as a background task that
        # other coroutines synchronise on through ``self._ready``.
        self._ready = asyncio.create_task(self._async_init())

    # ``_schema`` and ``_schema_dict`` are guarded properties: reading them
    # before ``_async_init`` has stored a value raises AssertionError. The
    # setters are required because a property without a setter rejects the
    # assignments made in ``_async_init`` with AttributeError. Values live in
    # the instance ``__dict__`` under the same name; the property takes
    # precedence on normal attribute lookup, so the guard stays effective.

    @property
    def _schema(self):
        try:
            return self.__dict__["_schema"]
        except KeyError:
            raise AssertionError("await self._ready before accessing self._schema") from None

    @_schema.setter
    def _schema(self, value):
        self.__dict__["_schema"] = value

    @property
    def _schema_dict(self):
        try:
            return self.__dict__["_schema_dict"]
        except KeyError:
            raise AssertionError("await self._ready before accessing self._schema_dict") from None

    @_schema_dict.setter
    def _schema_dict(self, value):
        self.__dict__["_schema_dict"] = value

    async def _async_init(self):
        """Resolve the collection schema, creating the collection if needed.

        If the collection exists, its server-side description is compared with
        the consistency level and schema given to the constructor. Otherwise
        the collection is created from the given schema. On success
        ``self._schema`` and ``self._schema_dict`` are populated.

        Raises:
            SchemaNotReadyException: On a consistency-level mismatch, a schema
                that is not a ``CollectionSchema``, a schema that differs from
                the server's, or a missing schema for a collection that does
                not exist.
        """
        schema = self._init_schema
        kwargs = self._kwargs
        conn = self._get_connection()

        has = await conn.has_collection(self._name, **kwargs)
        if has:
            resp = await conn.describe_collection(self._name, **kwargs)
            consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)
            arg_consistency_level = kwargs.get("consistency_level", consistency_level)
            if not cmp_consistency_level(consistency_level, arg_consistency_level):
                raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent)
            server_schema = CollectionSchema.construct_from_dict(resp)
            if schema is None:
                self._schema = server_schema
            else:
                if not isinstance(schema, CollectionSchema):
                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
                if server_schema != schema:
                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaInconsistent)
                self._schema = schema
        else:
            if schema is None:
                raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name)
            if not isinstance(schema, CollectionSchema):
                raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
            check_schema(schema)
            consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
            await conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs)
            self._schema = schema

        # Cache the dict form, extended with the consistency level, so that
        # search() can hand it to the connection as ``schema``.
        schema_dict = self._schema.to_dict()
        schema_dict["consistency_level"] = consistency_level
        self._schema_dict = schema_dict

    async def search(self, data, anns_field, param, limit, expr=None, partition_names=None,
                     output_fields=None, timeout=None, round_decimal=-1, **kwargs):
        """ Conducts a vector similarity search with an optional boolean expression as filter.

        This is a coroutine: it first awaits ``self._ready`` so that the schema
        is available, then delegates to the connection's ``search``.

        Args:
            data (``List[List[float]]``): The vectors of search data.
                The length of data is the number of queries (nq), and the dim of every vector in data
                must be equal to that of the collection's vector field.
            anns_field (``str``): The name of the vector field of the collection to search on.
            param (``dict[str, Any]``):
                The parameters of search. The following are valid keys of param.

                * *nprobe*, *ef*, *search_k*, etc
                    Corresponding search params for a certain index.
                * *metric_type* (``str``)
                    similarity metric type, the value must be of type str.
                * *offset* (``int``, optional)
                    offset for pagination.
                * *limit* (``int``, optional)
                    limit for the search results and pagination.

                example for param::

                    {
                        "nprobe": 128,
                        "metric_type": "L2",
                        "offset": 10,
                        "limit": 10,
                    }

            limit (``int``): The max number of returned records, also known as `topk`.
            expr (``str``): The boolean expression used to filter attributes. Defaults to None.

                example for expr::

                    "id_field >= 0", "id_field in [1, 2, 3, 4]"

            partition_names (``List[str]``, optional): The names of partitions to search on. Defaults to None.
            output_fields (``List[str]``, optional):
                The names of fields to return in the search result. Can only get scalar fields.
            round_decimal (``int``, optional): The specified number of decimal places of returned distance.
                Defaults to -1, which means the returned distance is not rounded.
            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
            **kwargs (``dict``): Optional search params

                * *consistency_level* (``str/int``, optional)
                    Which consistency level to use when searching in the collection.

                    Options of consistency level: Strong, Bounded, Eventually, Session, Customized.

                    Note: this parameter overrides the one specified when the collection was created;
                    if no consistency level is specified, search uses the consistency level of the
                    collection.

                * *guarantee_timestamp* (``int``, optional)
                    Instructs Milvus to see all operations performed before this timestamp.
                    By default Milvus will search all operations performed to date.

                    Note: only valid in Customized consistency level.

                * *graceful_time* (``int``, optional)
                    Search will use the (current_timestamp - the graceful_time) as the
                    `guarantee_timestamp`. By default with 5s.

                    Note: only valid in Bounded consistency level

                * *travel_timestamp* (``int``, optional)
                    A specific timestamp to get results based on a data view at.

        Returns:
            SearchResult:
                Returns ``SearchResult``

        .. _Metric type documentations:
            https://milvus.io/docs/v2.2.x/metric.md
        .. _Index documentations:
            https://milvus.io/docs/v2.2.x/index.md
        .. _How guarantee ts works:
            https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md

        Raises:
            DataTypeNotMatchException: If ``expr`` is given and is not a ``str``.
            MilvusException: If anything goes wrong

        Examples:
            The snippet must run inside a coroutine, because ``search`` has to be awaited.
            NOTE(review): the import and the setup calls (connect, insert, create_index, load)
            are carried over from the synchronous ``Collection`` documentation; confirm the
            asyncio equivalents before relying on them.

            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
            >>> import random
            >>> connections.connect()
            >>> schema = CollectionSchema([
            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
            ... ])
            >>> collection = Collection("test_collection_search", schema)
            >>> # insert
            >>> data = [
            ...     [i for i in range(10)],
            ...     [[random.random() for _ in range(2)] for _ in range(10)],
            ... ]
            >>> collection.insert(data)
            >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
            >>> collection.load()
            >>> # search
            >>> search_param = {
            ...     "data": [[1.0, 1.0]],
            ...     "anns_field": "films",
            ...     "param": {"metric_type": "L2", "offset": 1},
            ...     "limit": 2,
            ...     "expr": "film_id > 0",
            ... }
            >>> res = await collection.search(**search_param)
            >>> assert len(res) == 1
            >>> hits = res[0]
            >>> assert len(hits) == 2
            >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
            - Total hits: 2, hits ids: [8, 5]
            >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
            - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385
        """
        # Wait for _async_init to finish; an exception raised during
        # initialisation propagates from this await.
        await self._ready
        if expr is not None and not isinstance(expr, str):
            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

        conn = self._get_connection()
        res = await conn.search(
            self._name, data, anns_field, param, limit, expr,
            partition_names, output_fields, round_decimal, timeout=timeout,
            schema=self._schema_dict, **kwargs)
        return SearchResult(res)
50 changes: 33 additions & 17 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

import copy
import json
import typing
from typing import List
import pandas

from .connections import connections
from .connections import connections, AbstractConnections, Connections as SyncConnections
from .schema import (
CollectionSchema,
FieldSchema,
Expand Down Expand Up @@ -46,8 +47,11 @@
from ..client.configs import DefaultConfigs


ConnectionsT = typing.TypeVar('ConnectionsT', bound=AbstractConnections)

class AbstractCollection(typing.Generic[ConnectionsT]):
connections: ConnectionsT

class Collection:
def __init__(self, name: str, schema: CollectionSchema=None, using: str="default", shards_num: int=2, **kwargs):
""" Constructs a collection by name, schema and other parameters.
Expand Down Expand Up @@ -91,17 +95,38 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default
self._using = using
self._shards_num = shards_num
self._kwargs = kwargs
self._init_schema = schema
self._init()

# Initialisation hook called at the end of __init__, after the constructor
# arguments have been stored on the instance. The base class provides no
# implementation; subclasses override it.
def _init(self):
raise NotImplementedError

# Look up the handler for this collection's connection alias (self._using)
# through the class-level `connections` object, so that each subclass uses the
# registry it declares.
def _get_connection(self):
return self.connections._fetch_handler(self._using)

# Public accessor exposing the stored collection name (self._name).
@property
def name(self) -> str:
"""str: the name of the collection. """
return self._name


class Collection(AbstractCollection[SyncConnections]):
connections = connections

def _init(self):
assert not (hasattr(self, '_schema') or hasattr(self, '_schema_dict')), "_init() must be called only once"
schema = self._init_schema
kwargs = self._kwargs
conn = self._get_connection()

has = conn.has_collection(self._name, **kwargs)
if has:
resp = conn.describe_collection(self._name, **kwargs)
s_consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)
arg_consistency_level = kwargs.get("consistency_level", s_consistency_level)
if not cmp_consistency_level(s_consistency_level, arg_consistency_level):
consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)
arg_consistency_level = kwargs.get("consistency_level", consistency_level)
if not cmp_consistency_level(consistency_level, arg_consistency_level):
raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent)
server_schema = CollectionSchema.construct_from_dict(resp)
self._consistency_level = s_consistency_level
if schema is None:
self._schema = server_schema
else:
Expand All @@ -113,18 +138,17 @@ def __init__(self, name: str, schema: CollectionSchema=None, using: str="default

else:
if schema is None:
raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name)
raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % self._name)
if isinstance(schema, CollectionSchema):
check_schema(schema)
consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs)
self._schema = schema
self._consistency_level = consistency_level
else:
raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)

self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level
self._schema_dict["consistency_level"] = consistency_level

def __repr__(self):
_dict = {
Expand All @@ -139,9 +163,6 @@ def __repr__(self):
r.append(s.format(k, v))
return "".join(r)

def _get_connection(self):
return connections._fetch_handler(self._using)

@classmethod
def construct_from_dataframe(cls, name, dataframe, **kwargs):
if dataframe is None:
Expand Down Expand Up @@ -212,11 +233,6 @@ def description(self) -> str:
"""str: a text description of the collection. """
return self._schema.description

@property
def name(self) -> str:
"""str: the name of the collection. """
return self._name

@property
def is_empty(self) -> bool:
"""bool: whether the collection is empty or not."""
Expand Down
Loading

0 comments on commit 7bb3eee

Please sign in to comment.