diff --git a/examples/iterator.py b/examples/iterator.py index 973ef65ae..f9fa496d5 100644 --- a/examples/iterator.py +++ b/examples/iterator.py @@ -20,7 +20,7 @@ LIMIT = 5 NUM_ENTITIES = 1000 DIM = 8 -CLEAR_EXIST = False +CLEAR_EXIST = True # Create a logger for the main script log = logging.getLogger(__name__) diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 73c2467b9..8cc974a0a 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -2,7 +2,7 @@ import logging from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from pymilvus.client import entity_helper, utils from pymilvus.client.abstract import Hits, LoopBase @@ -60,6 +60,18 @@ def fall_back_to_latest_session_ts(): return mkts_from_datetime(d, milliseconds=1000.0) +def assert_info(condition: bool, message: str): + if not condition: + raise MilvusException(message) + + +def io_operation(io_func: Callable[[Any], None], message: str): + try: + io_func() + except OSError as ose: + raise MilvusException(message=message) from ose + + def extend_batch_size(batch_size: int, next_param: dict, to_extend_batch_size: bool) -> int: extend_rate = 1 if to_extend_batch_size: @@ -112,6 +124,9 @@ def __init__( self.__seek_to_offset() def __seek_to_offset(self): + # read pk cursor from cp file, no need to seek offset + if self._next_id is not None: + return offset = self._kwargs.get(OFFSET, 0) if offset > 0: seek_params = self._kwargs.copy() @@ -140,14 +155,26 @@ def __init_cp_file_handler(self) -> bool: ) from ose return mode == "r+" - def __save_cp(self): + def __save_mvcc_ts(self): + assert_info( + self._cp_file_handler is not None, + "Must init cp file handler before saving session_ts", + ) + self._cp_file_handler.writelines(str(self._session_ts) + "\n") + + def __save_pk_cursor(self): if self._need_save_cp: - self._cp_file_handler.seek(0) - self._cp_file_handler.truncate() - self._cp_file_handler.writelines(str(self._session_ts) + "\n") - if self._next_id is not None: - self._cp_file_handler.writelines(self._next_id) + if self._buffer_cursor_lines_number >= 100: + self._cp_file_handler.seek(0) + self._cp_file_handler.truncate() + log.info( + "cursor lines in cp file has exceeded 100 lines, truncate the file and rewrite" + ) + self._cp_file_handler.writelines(str(self._session_ts) + "\n") + assert_info(self._next_id is not None, "next_id should not be None when saving cp") + self._cp_file_handler.writelines(str(self._next_id) + "\n") self._cp_file_handler.flush() + self._buffer_cursor_lines_number = 0 def __check_set_reduce_stop_for_best(self): if self._kwargs.get(REDUCE_STOP_FOR_BEST, True): @@ -198,6 +225,7 @@ def __setup_ts_by_request(self): self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts def __set_up_ts_cp(self): + self._buffer_cursor_lines_number = 0 self._cp_file_path = self._kwargs.get(ITERATOR_SESSION_CP_FILE, None) # no input cp_file, set up mvccTs by query request if self._cp_file_path is None: @@ -208,29 +236,28 @@ def __set_up_ts_cp(self): if not self.__init_cp_file_handler(): # input cp file is empty, set up mvccTs by query request self.__setup_ts_by_request() - self.__save_cp() + io_operation(self.__save_mvcc_ts, "Failed to save mvcc ts") else: - # input cp file is not emtpy, init mvccTs by reading cp file - file_size = Path(self._cp_file_path).stat().st_size - if file_size > 1024: - raise ParamError( - message="input cp file is too large, exceeding 1kb, " - "this may be a incorrect configuration" - ) - lines = self._cp_file_handler.readlines() - line_count = len(lines) - if line_count > 2 or line_count < 1: - raise ParamError( - message=f"line number:{len(lines)} of input cp file is wrong, " - f"which should be one or two" - ) try: + # input cp file is not emtpy, init mvccTs by reading cp file + lines = self._cp_file_handler.readlines() + line_count = len(lines) + if line_count < 2: + raise ParamError( + message=f"input cp file:{self._cp_file_path} should contain " + f"at least two lines, but only:{line_count} lines" + ) self._session_ts = int(lines[0]) self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts + if line_count > 1: + self._buffer_cursor_lines_number = line_count - 1 + self._next_id = lines[self._buffer_cursor_lines_number].strip() + except OSError as ose: + raise MilvusException( + message=f"Failed to read cp info from file:{self._cp_file_path}" + ) from ose except ValueError as e: raise ParamError(message=f"cannot parse input cp session_ts:{lines[0]}") from e - if line_count == 2: - self._next_id = lines[1].strip() def __maybe_cache(self, result: List): if len(result) < 2 * self._kwargs[BATCH_SIZE]: @@ -268,7 +295,7 @@ def next(self): ret = self.__check_reached_limit(ret) self.__update_cursor(ret) - self.__save_cp() + io_operation(self.__save_pk_cursor, "failed to save pk cursor") self._returned_count += len(ret) return ret @@ -316,7 +343,15 @@ def close(self) -> None: # release cache in use iterator_cache.release_cache(self._cache_id_in_use) if self._cp_file_handler is not None: - self._cp_file_handler.close() + + def inner_close(): + self._cp_file_handler.close() + Path(self._cp_file_path).unlink() + log.info(f"removed cp file:{self._cp_file_path} for query iterator") + + io_operation( + inner_close, f"failed to clear cp file:{self._cp_file_path} for query iterator" + ) def metrics_positive_related(metrics: str) -> bool: