
Commit

Infer start and end date from filters
cevian committed Oct 12, 2023
1 parent 694f862 commit 593f878
Showing 4 changed files with 145 additions and 10 deletions.
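
In short, when a metadata `filter` dict contains the reserved keys `__start_date` and/or `__end_date` (as `datetime` objects or ISO-8601 strings), the query builder now lifts them into a `UUIDTimeRange` automatically; a new `infer_filters` flag, defaulting to `True` on both the sync and async clients, controls this. A minimal usage sketch based on the tests added below — the service URL and table name are placeholders, and it assumes a table created with a `time_partition_interval` and already populated:

```python
import asyncio
from datetime import datetime, timedelta

from timescale_vector import client

async def main() -> None:
    # Placeholder connection details -- not from this commit.
    vec = client.Async("postgres://user:password@host:5432/tsdb",
                       "data_table", 2,
                       time_partition_interval=timedelta(days=7))

    specific_datetime = datetime(2023, 1, 3)

    # Pre-existing explicit form: pass a UUIDTimeRange.
    rec = await vec.search(
        [1.0, 2.0], limit=4,
        uuid_time_filter=client.UUIDTimeRange(
            specific_datetime - timedelta(days=7),
            specific_datetime + timedelta(days=7)))

    # New in this commit: the same window inferred from reserved filter keys;
    # values may be datetime objects or ISO-8601 strings.
    rec = await vec.search(
        [1.0, 2.0], limit=4,
        filter={"__start_date": specific_datetime - timedelta(days=7),
                "__end_date": str(specific_datetime + timedelta(days=7))})

asyncio.run(main())
```

Both calls target the same time window; the second form is what this commit enables.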
72 changes: 67 additions & 5 deletions nbs/00_vector.ipynb
@@ -490,7 +490,8 @@
" num_dimensions: int,\n",
" distance_type: str,\n",
" id_type: str,\n",
" time_partition_interval: Optional[timedelta]) -> None:\n",
" time_partition_interval: Optional[timedelta],\n",
" infer_filters: bool) -> None:\n",
" \"\"\"\n",
" Initializes a base Vector object to generate queries for vector clients.\n",
"\n",
@@ -522,6 +523,7 @@
"\n",
" self.id_type = id_type.lower()\n",
" self.time_partition_interval = time_partition_interval\n",
" self.infer_filters = infer_filters\n",
"\n",
" def _quote_ident(self, ident):\n",
" \"\"\"\n",
@@ -713,6 +715,36 @@
" raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n",
"\n",
" return (where, params)\n",
" \n",
" def _parse_datetime(self, input_datetime):\n",
" \"\"\"\n",
" Parse a datetime object or string representation of a datetime.\n",
"\n",
" Args:\n",
" input_datetime (datetime or str): Input datetime or string.\n",
"\n",
" Returns:\n",
" datetime: Parsed datetime object.\n",
"\n",
" Raises:\n",
" ValueError: If the input cannot be parsed as a datetime.\n",
" \"\"\"\n",
" if input_datetime is None:\n",
" return None\n",
" \n",
" if isinstance(input_datetime, datetime):\n",
" # If input is already a datetime object, return it as is\n",
" return input_datetime\n",
"\n",
" if isinstance(input_datetime, str):\n",
" try:\n",
" # Attempt to parse the input string into a datetime\n",
" return datetime.fromisoformat(input_datetime)\n",
" except ValueError:\n",
" raise ValueError(\"Invalid datetime string format\")\n",
"\n",
" raise ValueError(\"Input must be a datetime object or string\")\n",
"\n",
"\n",
" def search_query(\n",
" self, \n",
@@ -739,6 +771,20 @@
" distance = \"-1.0\"\n",
" order_by_clause = \"\"\n",
"\n",
" if self.infer_filters:\n",
" if uuid_time_filter is None and isinstance(filter, dict):\n",
" if \"__start_date\" in filter or \"__end_date\" in filter:\n",
" start_date = self._parse_datetime(filter.get(\"__start_date\"))\n",
" end_date = self._parse_datetime(filter.get(\"__end_date\"))\n",
" \n",
" uuid_time_filter = UUIDTimeRange(start_date, end_date)\n",
" \n",
" if start_date is not None:\n",
" del filter[\"__start_date\"]\n",
" if end_date is not None:\n",
" del filter[\"__end_date\"]\n",
"\n",
"\n",
" where_clauses = []\n",
" if filter is not None:\n",
" (where_filter, params) = self._where_clause_for_filter(params, filter)\n",
@@ -836,7 +882,8 @@
" distance_type: str = 'cosine',\n",
" id_type='UUID',\n",
" time_partition_interval: Optional[timedelta] = None,\n",
" max_db_connections: Optional[int] = None\n",
" max_db_connections: Optional[int] = None,\n",
" infer_filters: bool = True,\n",
" ) -> None:\n",
" \"\"\"\n",
" Initializes a async client for storing vector data.\n",
@@ -855,7 +902,7 @@
" The type of the id column. Can be either 'UUID' or 'TEXT'.\n",
" \"\"\"\n",
" self.builder = QueryBuilder(\n",
" table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n",
" table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n",
" self.service_url = service_url\n",
" self.pool = None\n",
" self.max_db_connections = max_db_connections\n",
@@ -1444,8 +1491,20 @@
"assert not await vec.table_is_empty()\n",
"rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime+timedelta(days=7))})\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7))})\n",
"assert len(rec) == 2\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__end_date\": str(specific_datetime+timedelta(days=7))})\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n",
"assert len(rec) == 0\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime-timedelta(days=2)})\n",
"assert len(rec) == 0\n",
"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime-timedelta(days=2))})\n",
"assert len(rec) == 0\n",
"rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n",
"assert len(rec) == 2\n",
"rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date=specific_datetime, time_delta=timedelta(days=7)))\n",
@@ -1500,7 +1559,8 @@
" distance_type: str = 'cosine',\n",
" id_type='UUID',\n",
" time_partition_interval: Optional[timedelta] = None,\n",
" max_db_connections: Optional[int] = None\n",
" max_db_connections: Optional[int] = None,\n",
" infer_filters: bool = True,\n",
" ) -> None:\n",
" \"\"\"\n",
" Initializes a sync client for storing vector data.\n",
@@ -1519,7 +1579,7 @@
" The type of the primary id column. Can be either 'UUID' or 'TEXT'.\n",
" \"\"\"\n",
" self.builder = QueryBuilder(\n",
" table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n",
" table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n",
" self.service_url = service_url\n",
" self.pool = None\n",
" self.max_db_connections = max_db_connections\n",
@@ -2147,6 +2207,8 @@
"assert not vec.table_is_empty()\n",
"rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n",
"assert len(rec) == 1\n",
"rec = vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n",
"assert len(rec) == 1\n",
"rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n",
"assert len(rec) == 0\n",
"rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n",
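If the automatic lifting is not wanted — for example, if `__start_date` should be treated like any other metadata key — the new flag can be disabled at construction time. A sketch with the same placeholder connection details:

```python
from datetime import timedelta
from timescale_vector import client

# infer_filters=False skips the inference step added in this commit: reserved
# __start_date/__end_date keys stay in the metadata filter instead of being
# converted into a UUIDTimeRange. The flag exists on both client.Sync and
# client.Async.
vec = client.Sync("postgres://user:password@host:5432/tsdb", "data_table", 2,
                  time_partition_interval=timedelta(days=7),
                  infer_filters=False)
```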
23 changes: 23 additions & 0 deletions nbs/tsv_python_getting_started_tutorial.ipynb
@@ -159,6 +159,29 @@
"Each partition will consist of data for the specified length of time. We'll use 7 days for simplicity, but you can pick whatever value make sense for your use case -- for example if you query recent vectors frequently you might want to use a smaller time delta like 1 day, or if you query vectors over a decade long time period then you might want to use a larger time delta like 6 months or 1 year."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"import asyncpg"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"con = await asyncpg.connect(TIMESCALE_SERVICE_URL)\n",
"await con.execute(\"DROP TABLE IF EXISTS commit_history;\")\n",
"await con.execute(\"DROP EXTENSION IF EXISTS vector CASCADE\")\n",
"await con.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
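As an aside, the partitioning guidance in the tutorial cell above maps directly onto the `time_partition_interval` constructor argument. An illustrative sketch — `commit_history` and `TIMESCALE_SERVICE_URL` come from the tutorial, while the URL value and the 1536-dimension embedding size are assumptions:

```python
from datetime import timedelta
from timescale_vector import client

TIMESCALE_SERVICE_URL = "postgres://user:password@host:5432/tsdb"  # placeholder

# Recent data queried frequently: small partitions.
vec_recent = client.Sync(TIMESCALE_SERVICE_URL, "commit_history", 1536,
                         time_partition_interval=timedelta(days=1))

# A decade of history queried broadly: large partitions.
vec_archive = client.Sync(TIMESCALE_SERVICE_URL, "commit_history", 1536,
                          time_partition_interval=timedelta(days=365))
```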
2 changes: 2 additions & 0 deletions timescale_vector/_modidx.py
@@ -78,6 +78,8 @@
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._get_embedding_index_name': ( 'vector.html#querybuilder._get_embedding_index_name',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._parse_datetime': ( 'vector.html#querybuilder._parse_datetime',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._quote_ident': ( 'vector.html#querybuilder._quote_ident',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._where_clause_for_filter': ( 'vector.html#querybuilder._where_clause_for_filter',
58 changes: 53 additions & 5 deletions timescale_vector/client.py
Expand Up @@ -379,7 +379,8 @@ def __init__(
num_dimensions: int,
distance_type: str,
id_type: str,
time_partition_interval: Optional[timedelta]) -> None:
time_partition_interval: Optional[timedelta],
infer_filters: bool) -> None:
"""
Initializes a base Vector object to generate queries for vector clients.
@@ -411,6 +412,7 @@

self.id_type = id_type.lower()
self.time_partition_interval = time_partition_interval
self.infer_filters = infer_filters

def _quote_ident(self, ident):
"""
@@ -602,6 +604,36 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
raise ValueError("Unknown filter type: {filter_type}".format(filter_type=type(filter)))

return (where, params)

def _parse_datetime(self, input_datetime):
"""
Parse a datetime object or string representation of a datetime.
Args:
input_datetime (datetime or str): Input datetime or string.
Returns:
datetime: Parsed datetime object.
Raises:
ValueError: If the input cannot be parsed as a datetime.
"""
if input_datetime is None:
return None

if isinstance(input_datetime, datetime):
# If input is already a datetime object, return it as is
return input_datetime

if isinstance(input_datetime, str):
try:
# Attempt to parse the input string into a datetime
return datetime.fromisoformat(input_datetime)
except ValueError:
raise ValueError("Invalid datetime string format")

raise ValueError("Input must be a datetime object or string")


def search_query(
self,
@@ -628,6 +660,20 @@ def search_query(
distance = "-1.0"
order_by_clause = ""

if self.infer_filters:
if uuid_time_filter is None and isinstance(filter, dict):
if "__start_date" in filter or "__end_date" in filter:
start_date = self._parse_datetime(filter.get("__start_date"))
end_date = self._parse_datetime(filter.get("__end_date"))

uuid_time_filter = UUIDTimeRange(start_date, end_date)

if start_date is not None:
del filter["__start_date"]
if end_date is not None:
del filter["__end_date"]


where_clauses = []
if filter is not None:
(where_filter, params) = self._where_clause_for_filter(params, filter)
@@ -671,7 +717,8 @@ def __init__(
distance_type: str = 'cosine',
id_type='UUID',
time_partition_interval: Optional[timedelta] = None,
max_db_connections: Optional[int] = None
max_db_connections: Optional[int] = None,
infer_filters: bool = True,
) -> None:
"""
Initializes an async client for storing vector data.
@@ -690,7 +737,7 @@ def __init__(
The type of the id column. Can be either 'UUID' or 'TEXT'.
"""
self.builder = QueryBuilder(
table_name, num_dimensions, distance_type, id_type, time_partition_interval)
table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)
self.service_url = service_url
self.pool = None
self.max_db_connections = max_db_connections
@@ -933,7 +980,8 @@ def __init__(
distance_type: str = 'cosine',
id_type='UUID',
time_partition_interval: Optional[timedelta] = None,
max_db_connections: Optional[int] = None
max_db_connections: Optional[int] = None,
infer_filters: bool = True,
) -> None:
"""
Initializes a sync client for storing vector data.
@@ -952,7 +1000,7 @@ def __init__(
The type of the primary id column. Can be either 'UUID' or 'TEXT'.
"""
self.builder = QueryBuilder(
table_name, num_dimensions, distance_type, id_type, time_partition_interval)
table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)
self.service_url = service_url
self.pool = None
self.max_db_connections = max_db_connections
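Note that inferred values pass through the new `_parse_datetime` helper, which returns `datetime` objects as-is and parses strings with `datetime.fromisoformat`; anything else raises `ValueError`. A standalone illustration of what that parser accepts (it mirrors the helper's logic rather than calling the private method):

```python
from datetime import datetime

# Accepted: ISO-8601-style strings, including the space-separated form
# produced by str(datetime), as used in the notebook tests above.
datetime.fromisoformat("2023-10-12")
datetime.fromisoformat("2023-10-12 08:30:00")
datetime.fromisoformat("2023-10-12T08:30:00+00:00")

# Rejected: free-form dates raise ValueError, which _parse_datetime re-raises
# as "Invalid datetime string format".
try:
    datetime.fromisoformat("Oct 12, 2023")
except ValueError:
    print("not an ISO-8601 datetime")
```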
