Skip to content

Commit

Permalink
A couple of bug fixes (#1945)
Browse files Browse the repository at this point in the history
Fixes a couple of bugs that show up in GPT2 optimization.
  • Loading branch information
gramalingam authored Nov 14, 2024
1 parent 5a35958 commit d81480b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions onnxscript/optimizer/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl

# Identify call-stack for node, used to generate unique names.
call_stack = self.node_context.get(node, [])
call_stack.append(call_site_id)
new_call_stack = [*call_stack, call_site_id]

cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, call_stack)
cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

# iterate over the nodes in the function, creating a copy of each node
# and replacing inputs with the corresponding values in the value map.
Expand Down
7 changes: 4 additions & 3 deletions onnxscript/rewriter/collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def _check_if_redundant_slice(
axes_const = axes.const_value
steps_const = steps.const_value

if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
return False

# Check if the values are scalar
if starts_const.numpy().size != 1: # type: ignore[union-attr]
logger.info("The value 'start' is not a scalar.")
Expand All @@ -42,9 +46,6 @@ def _check_if_redundant_slice(
logger.info("The value 'step' is not a scalar.")
return False

if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
return False
if steps_const.numpy().item() != 1:
logger.info("The value 'step' is not 1.")
return False
Expand Down

0 comments on commit d81480b

Please sign in to comment.