
SNOW-1787415 Fix dataframe ast to register udtf for replay execution using client created udtf in same client session #2620

Open
wants to merge 12 commits into
base: ls-SNOW-1491199-merge-phase0-server-side
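Note (illustrative, not part of this PR's diff): the change targets the case where a UDTF is registered in a client session and then referenced from a dataframe in that same session, so the dataframe AST must carry the registered object's name for replay execution. Below is a minimal sketch of that scenario using the public Snowpark API; the connection parameters and handler class are placeholders, not taken from this PR.

```python
# Hypothetical scenario this PR addresses: a UDTF created by the client and then
# called from a dataframe in the same session. The AST for the query needs to
# reference the registered UDTF by name so a server-side replay can execute it.
from snowflake.snowpark import Session
from snowflake.snowpark.functions import lit
from snowflake.snowpark.types import IntegerType, StructField, StructType


class GeneratorUDTF:
    def process(self, n: int):
        for i in range(n):
            yield (i,)


# connection_parameters is a placeholder dict of Snowflake credentials.
session = Session.builder.configs(connection_parameters).create()

generator_udtf = session.udtf.register(
    GeneratorUDTF,
    output_schema=StructType([StructField("number", IntegerType())]),
    input_types=[IntegerType()],
)

# The table-function call's AST must record the UDTF's registered object name
# (SpCallable.object_name in ast.proto) so replay re-uses the object above.
session.table_function(generator_udtf(lit(3))).show()
```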
12 changes: 10 additions & 2 deletions src/snowflake/snowpark/_internal/ast_utils.py
@@ -953,6 +953,7 @@ def build_proto_from_callable(
expr_builder: proto.SpCallable,
func: Union[Callable, Tuple[str, str]],
ast_batch: Optional[AstBatch] = None,
Review comment (Contributor): nit: Could do _registered_object_name here as well.
object_name: Optional[Union[str, Iterable[str]]] = None,
):
"""Registers a python callable (i.e., a function or lambda) to the AstBatch and encodes it as SpCallable protobuf."""

@@ -977,6 +978,9 @@ def build_proto_from_callable(
# Use the actual function name. Note: We do not support different scopes yet, need to be careful with this then.
expr_builder.name = func.__name__

if object_name is not None:
build_sp_table_name(expr_builder.object_name, object_name)


def build_udf(
ast: proto.Udf,
@@ -1166,7 +1170,8 @@ def build_udtf(
comment: Optional[str] = None,
statement_params: Optional[Dict[str, str]] = None,
is_permanent: bool = False,
session=None,
session: "snowflake.snowpark.session.Session" = None,
udtf_name: Optional[Union[str, Iterable[str]]] = None,
**kwargs,
):
"""Helper function to encode UDTF parameters (used in both regular and mock UDFRegistration)."""
@@ -1176,7 +1181,10 @@
_set_fn_name(name, ast)

build_proto_from_callable(
ast.handler, handler, session._ast_batch if session is not None else None
ast.handler,
handler,
session._ast_batch if session is not None else None,
udtf_name,
)

if output_schema is not None:
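Note (illustrative): the ast_utils.py change threads the registered UDTF name from build_udtf into build_proto_from_callable, which records it on the SpCallable message. Below is a self-contained sketch of that plumbing; the stub classes stand in for the generated protobuf types and for build_sp_table_name, and are assumptions rather than the real implementations.

```python
# Simplified stand-ins for the protobuf messages; the real SpCallable / SpTableName
# types are generated from ast.proto and are not reproduced here.
from dataclasses import dataclass, field
from typing import Callable, Iterable, Optional, Union


@dataclass
class SpTableNameStub:
    parts: list = field(default_factory=list)


@dataclass
class SpCallableStub:
    name: str = ""
    object_name: SpTableNameStub = field(default_factory=SpTableNameStub)


def build_table_name(msg: SpTableNameStub, name: Union[str, Iterable[str]]) -> None:
    # Accept either a flat name or an iterable of identifier parts.
    msg.parts = [name] if isinstance(name, str) else list(name)


def encode_callable(
    msg: SpCallableStub,
    func: Callable,
    object_name: Optional[Union[str, Iterable[str]]] = None,
) -> None:
    # Record the Python-visible name and, if the callable is already registered
    # as a database object, also record that object's name for replay.
    msg.name = func.__name__
    if object_name is not None:
        build_table_name(msg.object_name, object_name)


def my_handler():
    pass


callable_msg = SpCallableStub()
encode_callable(callable_msg, my_handler, object_name=["MY_DB", "MY_SCHEMA", "MY_UDTF"])
print(callable_msg)
```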
7 changes: 4 additions & 3 deletions src/snowflake/snowpark/_internal/proto/ast.proto
@@ -534,22 +534,23 @@ message SpDataframeSchema_Struct {
message SpCallable {
int64 id = 1;
string name = 2;
SpTableName object_name = 3;
}

// sp-type.ir:104
// sp-type.ir:106
message SpPivotValue {
oneof sealed_value {
SpPivotValue_Dataframe sp_pivot_value__dataframe = 1;
SpPivotValue_Expr sp_pivot_value__expr = 2;
}
}

// sp-type.ir:105
// sp-type.ir:107
message SpPivotValue_Expr {
Expr v = 1;
}

// sp-type.ir:106
// sp-type.ir:108
message SpPivotValue_Dataframe {
SpDataframeRef v = 1;
}
1 change: 1 addition & 0 deletions src/snowflake/snowpark/_internal/udf_utils.py
@@ -100,6 +100,7 @@
"_from_pandas_udf_function",
"input_names", # for pandas_udtf
"max_batch_size", # for pandas_udtf
"_registered_object_name", # db object name if already registered
}


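Note (illustrative): the set extended above acts as an allow-list of internal keyword arguments. A rough sketch of that validation pattern follows; the function name and error type are illustrative, not the actual udf_utils code.

```python
# Illustrative allow-list check for internal kwargs; the real validation in
# udf_utils.py may differ in wording and error type.
ALLOWED_INTERNAL_KWARGS = {
    "_from_pandas_udf_function",
    "input_names",
    "max_batch_size",
    "_registered_object_name",  # db object name if already registered
}


def check_kwargs(**kwargs) -> None:
    unknown = set(kwargs) - ALLOWED_INTERNAL_KWARGS
    if unknown:
        raise TypeError(f"Unexpected keyword arguments: {sorted(unknown)}")


check_kwargs(_registered_object_name="MY_DB.MY_SCHEMA.MY_UDTF")  # passes
```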
44 changes: 28 additions & 16 deletions src/snowflake/snowpark/dataframe.py
@@ -1359,31 +1359,24 @@ def select(
if not exprs:
raise ValueError("The input of select() cannot be empty")

# AST.
stmt = _ast_stmt
ast = None

if _emit_ast and _ast_stmt is None:
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_select__columns, stmt)
self._set_ast_ref(ast.df)
ast.variadic = is_variadic

names = []
table_func = None
join_plan = None

ast_cols = []

for e in exprs:
if isinstance(e, Column):
names.append(e._named())
if _emit_ast and ast:
ast.cols.append(e._ast)
if _emit_ast and _ast_stmt is None:
ast_cols.append(e._ast)

elif isinstance(e, str):
col_expr_ast = None
if ast:
col_expr_ast = ast.cols.add() if ast else proto.Expr()
if _emit_ast and _ast_stmt is None:
col_expr_ast = proto.Expr()
fill_ast_for_column(col_expr_ast, e, None)
ast_cols.append(col_expr_ast)

col = Column(e, _ast=col_expr_ast)
names.append(col._named())
@@ -1395,9 +1388,11 @@
f"Called '{table_func.user_visible_name}' and '{e.user_visible_name}'."
)
table_func = e
if _emit_ast and ast:
if _emit_ast and _ast_stmt is None:
add_intermediate_stmt(self._session._ast_batch, table_func)
build_indirect_table_fn_apply(ast.cols.add(), table_func)
ast_col = proto.Expr()
build_indirect_table_fn_apply(ast_col, table_func)
ast_cols.append(ast_col)

func_expr = _create_table_function_expression(func=table_func)

@@ -1446,6 +1441,23 @@
"The input of select() must be Column, column name, TableFunctionCall, or a list of them"
)

# AST.
stmt = _ast_stmt
ast = None

# Note: it is intentional that the column expressions are AST-serialized earlier (ast_cols) to ensure any
# AST IDs created precede the AST ID of the select statement, so they are deserialized in dependency order.
if _emit_ast and _ast_stmt is None:
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_select__columns, stmt)
self._set_ast_ref(ast.df)
ast.variadic = is_variadic

# Add columns after the statement to ensure any dependent columns have lower ast id.
for ast_col in ast_cols:
if ast_col is not None:
ast.cols.add().CopyFrom(ast_col)

if self._select_statement:
if join_plan:
return self._with_plan(
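Note (illustrative): the reordering in select() rests on one invariant: every AST entity a statement depends on must receive its id before the statement itself, so deserialization replays dependencies first. A toy sketch of that id ordering follows; the Node class and counter are stand-ins, not the real AstBatch.

```python
# Toy model of dependency-ordered id assignment: children get ids before the
# parent statement that references them, mirroring the ast_cols reordering above.
import itertools

_next_id = itertools.count(1)


class Node:
    def __init__(self, label, deps=()):
        self.id = next(_next_id)
        self.label = label
        self.deps = list(deps)


# Serialize the column expressions first...
cols = [Node("col_a"), Node("table_fn_call")]
# ...then the select statement, so its id is strictly greater than its dependencies'.
select_stmt = Node("select", deps=cols)

assert all(dep.id < select_stmt.id for dep in select_stmt.deps)
print([(n.label, n.id) for n in cols + [select_stmt]])
```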
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/functions.py
@@ -8758,7 +8758,7 @@ def udtf(
else:
udtf_registration_method = session.udtf.register

if handler is None:
if handler is None and kwargs.get("_registered_object_name") is None:
return functools.partial(
udtf_registration_method,
output_schema=output_schema,
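Note (illustrative): the `handler is None` branch is the standard decorator-factory pattern: calling udtf(...) with only options returns a functools.partial that completes registration once the handler class arrives, and a pre-registered object name now bypasses that path. Below is a generic sketch of the pattern, not the actual functions.udtf body.

```python
# Generic decorator-factory pattern behind `@udtf(...)`: with no handler the call
# returns a functools.partial; supplying a pre-registered object name would skip this.
import functools


def register(handler=None, *, output_schema=None, **kwargs):
    if handler is None and kwargs.get("_registered_object_name") is None:
        # Used as a decorator factory: @register(output_schema=...)
        return functools.partial(register, output_schema=output_schema, **kwargs)
    print(f"registering {handler.__name__} with schema {output_schema}")
    return handler


@register(output_schema=["number"])
class GeneratorUDTF:
    def process(self, n):
        yield from ((i,) for i in range(n))
```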
15 changes: 12 additions & 3 deletions src/snowflake/snowpark/mock/_nop_plan.py
@@ -43,7 +43,12 @@
from snowflake.snowpark.mock._select_statement import MockSelectable

# from snowflake.snowpark.session import Session
from snowflake.snowpark.types import MapType, PandasDataFrameType, _NumericType
from snowflake.snowpark.types import (
IntegerType,
MapType,
PandasDataFrameType,
_NumericType,
)


def resolve_attributes(
@@ -82,8 +87,12 @@ def resolve_attributes(
attributes = [
Attribute(
attr.name,
source_attributes[attr_name].datatype,
source_attributes[attr_name].nullable,
source_attributes[attr_name].datatype
if attr_name in source_attributes
else IntegerType(),
source_attributes[attr_name].nullable
if attr_name in source_attributes
else True,
)
if isinstance(attr, UnresolvedAttribute)
else attr
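Note (illustrative): the resolve_attributes change is a lookup-with-fallback: unresolved attribute names missing from the source attributes default to IntegerType and nullable=True. A simplified sketch of that pattern follows; the Attr dataclass and string type names are stand-ins for the real Attribute and DataType classes.

```python
# Lookup-with-fallback sketch: unresolved attributes not found in the source
# mapping get a default (IntegerType-like) type and are treated as nullable.
from dataclasses import dataclass


@dataclass
class Attr:
    name: str
    datatype: str
    nullable: bool


source_attributes = {"A": Attr("A", "StringType", False)}
unresolved_names = ["A", "B"]  # "B" has no known source attribute

resolved = [
    Attr(
        name,
        source_attributes[name].datatype if name in source_attributes else "IntegerType",
        source_attributes[name].nullable if name in source_attributes else True,
    )
    for name in unresolved_names
]
print(resolved)
```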
89 changes: 53 additions & 36 deletions src/snowflake/snowpark/mock/_udtf.py
@@ -71,36 +71,20 @@ def _do_register_udtf(
_emit_ast: bool = True,
**kwargs,
) -> UserDefinedTableFunction:

# Capture original parameters.
ast = None
if _emit_ast:
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.udtf, stmt)

build_udtf(
ast,
if kwargs.get("_registered_object_name") is not None:
ast, ast_id = None, None
if _emit_ast:
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.udtf, stmt)
ast_id = stmt.var_id.bitfield1

return MockUserDefinedTableFunction(
handler,
output_schema=output_schema,
input_types=input_types,
name=name,
stage_location=stage_location,
imports=imports,
packages=packages,
replace=replace,
if_not_exists=if_not_exists,
parallel=parallel,
max_batch_size=max_batch_size,
strict=strict,
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
comment=comment,
statement_params=statement_params,
is_permanent=is_permanent,
session=self._session,
**kwargs,
output_schema,
input_types,
kwargs["_registered_object_name"],
_ast=ast,
_ast_id=ast_id,
)

if isinstance(output_schema, StructType):
@@ -140,17 +124,50 @@ def _do_register_udtf(
output_schema=output_schema,
)

# Capture original parameters.
ast, ast_id = None, None
if _emit_ast:
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.udtf, stmt)
ast_id = stmt.var_id.bitfield1

build_udtf(
ast,
handler,
output_schema=output_schema,
input_types=input_types,
name=name,
stage_location=stage_location,
imports=imports,
packages=packages,
replace=replace,
if_not_exists=if_not_exists,
parallel=parallel,
max_batch_size=max_batch_size,
strict=strict,
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
comment=comment,
statement_params=statement_params,
is_permanent=is_permanent,
session=self._session,
udtf_name=udtf_name,
**kwargs,
)

foo = MockUserDefinedTableFunction(
handler, output_schema, input_types, udtf_name, packages=packages
handler,
output_schema,
input_types,
udtf_name,
packages=packages,
_ast=ast,
_ast_id=ast_id,
)

# Add to registry so MockPlan can execute.
self._registry[udtf_name] = foo

if _emit_ast:
foo._ast = ast
foo._ast_id = (
stmt.var_id.bitfield1
) # Reference UDTF by its assign/statement id.

return foo
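Note (illustrative): the mock registration above follows a register-or-reuse flow: when _registered_object_name is supplied, it skips re-creating the UDTF and returns a handle bound to the existing name, while still emitting the AST assign statement and recording its id. Below is a condensed sketch of that control flow with stub classes, not the real MockUserDefinedTableFunction.

```python
# Register-or-reuse sketch: if the object already exists in this session, bind to
# its name instead of re-registering, but still record an AST statement id.
import itertools

_stmt_ids = itertools.count(1)
_registry = {}


class FakeUDTF:
    def __init__(self, handler, name, ast_id=None):
        self.handler = handler
        self.name = name
        self._ast_id = ast_id


def do_register_udtf(handler, name, _registered_object_name=None, emit_ast=True):
    ast_id = next(_stmt_ids) if emit_ast else None
    if _registered_object_name is not None:
        # Reuse the object created earlier in this client session.
        return FakeUDTF(handler, _registered_object_name, ast_id)
    udtf = FakeUDTF(handler, name, ast_id)
    _registry[name] = udtf  # so the mock plan can execute it later
    return udtf


first = do_register_udtf(lambda: None, "MY_UDTF")
replayed = do_register_udtf(lambda: None, None, _registered_object_name="MY_UDTF")
assert replayed.name == first.name
```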
18 changes: 9 additions & 9 deletions src/snowflake/snowpark/table_function.py
@@ -114,25 +114,25 @@ def over(
expr = with_src_position(ast.sp_table_fn_call_over)
expr.lhs.CopyFrom(self._ast)
if partition_by is not None:
if isinstance(partition_by, Iterable):
if isinstance(partition_by, (str, Column)):
build_expr_from_snowpark_column_or_col_name(
expr.partition_by.add(), partition_by
)
else:
for partition_clause in partition_by:
build_expr_from_snowpark_column_or_col_name(
expr.partition_by.add(), partition_clause
)
else:
if order_by is not None:
if isinstance(order_by, (str, Column)):
build_expr_from_snowpark_column_or_col_name(
expr.partition_by.add(), partition_by
expr.order_by.add(), order_by
)
if order_by is not None:
if isinstance(order_by, Iterable):
else:
for order_clause in order_by:
build_expr_from_snowpark_column_or_col_name(
expr.order_by.add(), order_clause
)
else:
build_expr_from_snowpark_column_or_col_name(
expr.order_by.add(), order_by
)

new_table_function = TableFunctionCall(
self.name, *self.arguments, _ast=ast, **self.named_arguments
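Note (illustrative): the over() fix swaps the branch order because a plain str is itself an Iterable, so checking Iterable first would iterate a column name character by character. Below is a small sketch of the corrected normalization with a stand-in Column class.

```python
# Normalizing a "str | Column | Iterable" argument: check for the scalar cases
# first, because a plain str is itself an Iterable and would otherwise be
# iterated character by character.
from typing import Iterable, List, Union


class Column:  # stand-in for snowflake.snowpark.Column
    def __init__(self, name: str):
        self.name = name


def normalize(cols: Union[str, Column, Iterable[Union[str, Column]]]) -> List[str]:
    if isinstance(cols, (str, Column)):
        return [cols.name if isinstance(cols, Column) else cols]
    return [c.name if isinstance(c, Column) else c for c in cols]


assert normalize("a") == ["a"]                      # not split into characters
assert normalize(Column("b")) == ["b"]
assert normalize(["a", Column("b")]) == ["a", "b"]
```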