Skip to content

Commit

Permalink
fix error
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yzou committed Nov 15, 2024
1 parent f6ab17f commit 2add71b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

- Improve np.where with scalar x value by eliminating unnecessary join and temp table creation.
- Improve get_dummies performance by flattening the pivot with join.
- Improve align performance when aligning on the row position column by removing unnecessary window functions.



Expand Down
28 changes: 12 additions & 16 deletions src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,9 +1443,15 @@ def align(
left = self.ensure_row_position_column()
right = right.ensure_row_position_column()

# whether the alignment is performed on the row position column of each dataframe.
# In other words, this indicates whether the alignment is applied on a unique column
# of each dataframe. Optimizations can be applied based on this information.
align_on_row_position_column = left_on_cols == [
left.row_position_snowflake_quoted_identifier
] and right_on_cols == [right.row_position_snowflake_quoted_identifier]
# If the alignment is applied on the unique row position column, and the method is
# not "coalesce", the align operation can be mapped directly to a join. No extra filtering
# is needed.
direct_join_map = align_on_row_position_column and how != "coalesce"

# perform outer join
Expand Down Expand Up @@ -1531,9 +1537,12 @@ def align(
right_count = coalesce(max_(right_row_pos).over() + 1, lit(0))
eq_row_pos_count = sum_(iff(left_row_pos == right_row_pos, 1, 0)).over()

# align_on_row_position_column = False
ordering_columns = joined_ordered_frame.ordering_columns
if align_on_row_position_column:
# When the alignment is applied on the row position column, there is no need to do
# filtering based on the column matching. Since the columns aligned on have unique values,
# if they match, the join already gives the result. If they do not match, since the column
# values are unique, there will be no duplicated rows to filter.
align_filter = None
else:
# 'col_matching_expr' represents if left_on_cols is an exact match with right_on_cols.
Expand Down Expand Up @@ -1580,7 +1589,6 @@ def align(
] + ordering_columns

align_filter = not_(col_matching_column) | (left_row_pos == right_row_pos)
# filter_expression = not_(col_matching_column) | (left_row_pos == right_row_pos)

joined_ordered_frame = joined_ordered_frame.select(
joined_ordered_frame.projected_column_snowflake_quoted_identifiers
Expand Down Expand Up @@ -1662,11 +1670,6 @@ def align(
right_row_pos.is_not_null(), # right join
left_row_pos.is_not_null(), # left join
)
# filter_expression = filter_expression & iff(
# left_count_column == 0,
# right_row_pos.is_not_null(), # right join
# left_row_pos.is_not_null(), # left join
# )
from snowflake.snowpark.modin.plugin._internal.utils import (
unquote_name_if_quoted,
)
Expand Down Expand Up @@ -1699,22 +1702,18 @@ def align(
select_list.append(identifier)
elif how == "left":
join_filter = left_row_pos.is_not_null()
# filter_expression = filter_expression & left_row_pos.is_not_null()
select_list = result_projected_column_snowflake_quoted_identifiers
elif how == "inner":
join_filter = left_row_pos.is_not_null() & right_row_pos.is_not_null()
# filter_expression = (
# filter_expression
# & left_row_pos.is_not_null()
# & right_row_pos.is_not_null()
# )
select_list = result_projected_column_snowflake_quoted_identifiers
elif how == "outer":
select_list = result_projected_column_snowflake_quoted_identifiers
else:
raise ValueError(
f"how={how} is not valid argument for ordered_dataframe.align."
)

# apply all filters to the joined_ordered_frame
if (align_filter is not None) and (join_filter is not None):
joined_ordered_frame = joined_ordered_frame.filter(
align_filter & join_filter
Expand All @@ -1724,9 +1723,6 @@ def align(
elif join_filter is not None:
joined_ordered_frame = joined_ordered_frame.filter(join_filter)

# joined_ordered_frame = joined_ordered_frame.filter(filter_expression).sort(
# ordering_columns
# )
joined_ordered_frame.sort(ordering_columns)

# call select to make sure only the result_projected_column_snowflake_quoted_identifiers are projected
Expand Down

0 comments on commit 2add71b

Please sign in to comment.