-
Notifications
You must be signed in to change notification settings - Fork 328
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
308 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import asyncio | ||
|
||
from ...orm.collection import ( | ||
AbstractCollection, DataTypeNotMatchException, ExceptionsMessage, | ||
SearchResult, CollectionSchema, | ||
DEFAULT_CONSISTENCY_LEVEL, | ||
cmp_consistency_level, | ||
SchemaNotReadyException, | ||
check_schema, | ||
get_consistency_level, | ||
) | ||
from .connections import connections, Connections as AsyncConnections | ||
|
||
|
||
class Collection(AbstractCollection[AsyncConnections]):
    """Asyncio flavour of ``Collection``.

    Server-side initialisation (existence check, schema comparison or
    collection creation) runs in a background task stored in ``self._ready``.
    Anything that needs ``self._schema`` or ``self._schema_dict`` must
    ``await self._ready`` first; reading either attribute before the task has
    stored a value raises ``AssertionError``.
    """

    connections = connections

    def _init(self):
        """Schedule the asynchronous initialisation.

        NOTE(review): presumably invoked by the ``AbstractCollection``
        constructor -- confirm against the base class.

        ``asyncio.create_task`` requires a running event loop, so instances
        must be created from inside a coroutine (otherwise ``RuntimeError``).
        """
        self._ready = asyncio.create_task(self._async_init())

    # ``property`` is a data descriptor: without a setter, the assignments
    # performed in ``_async_init`` would raise ``AttributeError``, and the
    # getter would keep shadowing any instance attribute.  The values are
    # therefore kept in the instance ``__dict__`` behind explicit setters.
    @property
    def _schema(self):
        """Schema of the collection; only available once ``_ready`` completed."""
        try:
            return self.__dict__["_schema"]
        except KeyError:
            raise AssertionError("await self._ready before accessing self._schema") from None

    @_schema.setter
    def _schema(self, value):
        self.__dict__["_schema"] = value

    @property
    def _schema_dict(self):
        """Dict form of the schema; only available once ``_ready`` completed."""
        try:
            return self.__dict__["_schema_dict"]
        except KeyError:
            raise AssertionError("await self._ready before accessing self._schema_dict") from None

    @_schema_dict.setter
    def _schema_dict(self, value):
        self.__dict__["_schema_dict"] = value

    async def _async_init(self):
        """Resolve the collection schema against the server.

        If the collection already exists, its description is fetched and
        checked against the consistency level and schema given by the caller
        (if any).  Otherwise the collection is created from the given schema.
        On success ``self._schema`` and ``self._schema_dict`` are populated.

        Raises:
            SchemaNotReadyException: If the requested consistency level or
                schema disagrees with the existing collection, if the schema
                is not a ``CollectionSchema``, or if the collection does not
                exist and no schema was given.
        """
        schema = self._init_schema
        kwargs = self._kwargs
        conn = self._get_connection()

        has = await conn.has_collection(self._name, **kwargs)
        if has:
            resp = await conn.describe_collection(self._name, **kwargs)
            consistency_level = resp.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)
            # Without an explicit argument the server's level is compared
            # with itself, so the check below trivially passes.
            arg_consistency_level = kwargs.get("consistency_level", consistency_level)
            if not cmp_consistency_level(consistency_level, arg_consistency_level):
                raise SchemaNotReadyException(message=ExceptionsMessage.ConsistencyLevelInconsistent)

            server_schema = CollectionSchema.construct_from_dict(resp)
            if schema is None:
                self._schema = server_schema
            else:
                if not isinstance(schema, CollectionSchema):
                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)
                if server_schema != schema:
                    raise SchemaNotReadyException(message=ExceptionsMessage.SchemaInconsistent)
                self._schema = schema
        else:
            if schema is None:
                raise SchemaNotReadyException(
                    message=ExceptionsMessage.CollectionNotExistNoSchema % self._name)
            if not isinstance(schema, CollectionSchema):
                raise SchemaNotReadyException(message=ExceptionsMessage.SchemaType)

            check_schema(schema)
            consistency_level = get_consistency_level(
                kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
            await conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs)
            self._schema = schema

        schema_dict = self._schema.to_dict()
        schema_dict["consistency_level"] = consistency_level
        self._schema_dict = schema_dict

    async def search(self, data, anns_field, param, limit, expr=None, partition_names=None,
                     output_fields=None, timeout=None, round_decimal=-1, **kwargs):
        """ Conducts a vector similarity search with an optional boolean expression as filter.

        This is a coroutine: it first waits for the collection initialisation
        (``self._ready``) and then awaits the search on the connection.

        Args:
            data (``List[List[float]]``): The vectors of search data.
                the length of data is number of query (nq), and the dim of every vector in data must be equal to
                the vector field's of collection.
            anns_field (``str``): The name of the vector field used to search of collection.
            param (``dict[str, Any]``):

                The parameters of search. The followings are valid keys of param.

                * *nprobe*, *ef*, *search_k*, etc
                    Corresponding search params for a certain index.

                * *metric_type* (``str``)
                    similar metric types, the value must be of type str.

                * *offset* (``int``, optional)
                    offset for pagination.

                * *limit* (``int``, optional)
                    limit for the search results and pagination.

                example for param::

                    {
                        "nprobe": 128,
                        "metric_type": "L2",
                        "offset": 10,
                        "limit": 10,
                    }

            limit (``int``): The max number of returned record, also known as `topk`.
            expr (``str``): The boolean expression used to filter attribute. Default to None.

                example for expr::

                    "id_field >= 0", "id_field in [1, 2, 3, 4]"

            partition_names (``List[str]``, optional): The names of partitions to search on. Default to None.
            output_fields (``List[str]``, optional):
                The name of fields to return in the search result.  Can only get scalar fields.
            round_decimal (``int``, optional): The specified number of decimal places of returned distance.
                Defaults to -1 means no round to returned distance.
            timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
            **kwargs (``dict``): Optional search params

                *  *consistency_level* (``str/int``, optional)
                    Which consistency level to use when searching in the collection.

                    Options of consistency level: Strong, Bounded, Eventually, Session, Customized.

                    Note: this parameter will overwrite the same parameter specified when user created the collection,
                    if no consistency level was specified, search will use the consistency level when you create the
                    collection.

                * *guarantee_timestamp* (``int``, optional)
                    Instructs Milvus to see all operations performed before this timestamp.
                    By default Milvus will search all operations performed to date.

                    Note: only valid in Customized consistency level.

                * *graceful_time* (``int``, optional)
                    Search will use the (current_timestamp - the graceful_time) as the
                    `guarantee_timestamp`. By default with 5s.

                    Note: only valid in Bounded consistency level

                * *travel_timestamp* (``int``, optional)
                    A specific timestamp to get results based on a data view at.

        Returns:
            SearchResult:
                Returns ``SearchResult``

        .. _Metric type documentations:
            https://milvus.io/docs/v2.2.x/metric.md
        .. _Index documentations:
            https://milvus.io/docs/v2.2.x/index.md
        .. _How guarantee ts works:
            https://github.com/milvus-io/milvus/blob/master/docs/developer_guides/how-guarantee-ts-works.md

        Raises:
            DataTypeNotMatchException: If ``expr`` is neither ``None`` nor a ``str``.
            MilvusException: If anything goes wrong

        Examples:
            The snippet must run inside a coroutine; ``search`` has to be awaited.

            >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
            >>> import random
            >>> connections.connect()
            >>> schema = CollectionSchema([
            ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
            ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
            ... ])
            >>> collection = Collection("test_collection_search", schema)
            >>> # insert
            >>> data = [
            ...     [i for i in range(10)],
            ...     [[random.random() for _ in range(2)] for _ in range(10)],
            ... ]
            >>> collection.insert(data)
            >>> collection.create_index("films", {"index_type": "FLAT", "metric_type": "L2", "params": {}})
            >>> collection.load()
            >>> # search
            >>> search_param = {
            ...     "data": [[1.0, 1.0]],
            ...     "anns_field": "films",
            ...     "param": {"metric_type": "L2", "offset": 1},
            ...     "limit": 2,
            ...     "expr": "film_id > 0",
            ... }
            >>> res = await collection.search(**search_param)
            >>> assert len(res) == 1
            >>> hits = res[0]
            >>> assert len(hits) == 2
            >>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
            - Total hits: 2, hits ids: [8, 5]
            >>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
            - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385
        """
        # The schema dict passed below only exists after initialisation.
        await self._ready
        if expr is not None and not isinstance(expr, str):
            raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

        conn = self._get_connection()
        res = await conn.search(
            self._name, data, anns_field, param, limit, expr,
            partition_names, output_fields, round_decimal, timeout=timeout,
            schema=self._schema_dict, **kwargs)
        return SearchResult(res)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.