[PyTorch] torch.listconstruct causing issue for other ops #1926

Open
YifanShenSZ opened this issue Jul 27, 2023 · 10 comments
Labels
bug Unexpected behaviour that should be corrected (type)

Comments

@YifanShenSZ
Collaborator

YifanShenSZ commented Jul 27, 2023

This is the root cause of many issues: when a symbolic shape is involved in torch.listconstruct, instead of returning a Core ML tensor we simply return the list as is

def _array_construct(context, node, array_type):

    ...

    else:
        # If at least one input to the construct op is non-const, collect
        # the inputs and add them directly to the context. Ops that use this
        # node's output will take the list directly as input.
        context.add(array_type(inputs), node.name)

Appendix 1: Issues Sharing the Same Root Cause

Appendix 2: Ops Impacted by the Root Cause

  • torch.GroupNorm
  • torch.pad
  • torch.index_put
@YifanShenSZ YifanShenSZ added the bug Unexpected behaviour that should be corrected (type) label Jul 27, 2023
@YifanShenSZ
Collaborator Author

In the proposed fix for issue #1303, we try to gather the symbols:

    else:
        # Create the new_shape and input_shape to support dynamic sizes.
        xshape = mb.shape(x=x)
        for i, v in enumerate(x.shape[2:]):
            if is_symbolic(v):
                si = mb.gather(x=xshape, indices=i+2, axis=0)
                new_shape.append(si)
                input_shape[i+2] = si
            else:
                new_shape.append(v)

@YifanShenSZ
Collaborator Author

While reproducing issue #1921, torch.listconstruct was found to be the culprit:

(Pdb) print(context.torch_graph)
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
  %2 = constant[value=2]()
  %3 = size[](%x, %2)
  %pad_length = numtotensor[](%3)
  %5 = int[](%pad_length)
  %6 = int[](%pad_length)
  %7 = constant[value=0]()
  %8 = constant[value=0]()
  %9 = listconstruct[](%7, %8, %6, %5)
  %10 = constant[value=constant]()
  %11 = constant[]()
  %12 = pad[](%x, %9, %10, %11)
return (%12)
(Pdb) print(node)
  %12 = pad[](%x, %9, %10, %11)

The pad argument, i.e. variable %9, is the output of torch.listconstruct and consists of 2 consts and 2 symbols.

@YifanShenSZ YifanShenSZ changed the title [PyTorch] torch.listconstruct causing issue for many ops [PyTorch] torch.listconstruct causing issue for other ops Jul 27, 2023
@xorange

xorange commented Nov 8, 2023

Yes, for the #1921 case (for op 'pad'):

# pseudo code
@register_torch_op
def listconstruct():
    if constant shape:        # no bug case
        return mb.const(val=[static shapes, ...])

    else:                              # the #1921 case
        return [mix of static shapes and dynamic shapes, ...]    # which op 'pad' then fails to parse

If we want to fix it following the solution in #1303 (#1922), we basically want:

# pseudo code
def listconstruct():
    if constant shape:        # no bug case
        return mb.const(val=[static shapes, ...])

    else if_match_#1921_case():
        static_shapes = [...]    # extract static shapes from inputs
        sliced_dynamic_shapes = mb.slice_by_size(x=, begin=[], size=)    # extract dynamic shape syms from inputs
        return mb.concat(values=[static_shapes, sliced_dynamic_shapes])    # mb.concat() the static and dynamic parts of the shape

To be specific, in the #1921 case:

>>> context.torch_graph
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
  %2 = constant[value=2]()
  %3 = size[](%x, %2)
  %pad_length = numtotensor[](%3)
  %5 = int[](%pad_length)
  %6 = int[](%pad_length)
  %7 = constant[value=0]()
  %8 = constant[value=0]()
  %9 = listconstruct[](%7, %8, %6, %5)
  %10 = constant[value=constant]()
  %11 = constant[]()
  %12 = pad[](%x, %9, %10, %11)
return (%12)
>>> node.name               # I pdb in def _array_construct()
'9'

>>> node.inputs
['7', '8', '6', '5']
>>> context['7'].val        # static shapes
0
>>> context['8'].val        # static shapes
0
>>> context['6'].op         # dynamic shapes
  %6: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="6")

>>> context['5'].op         # dynamic shapes
  %5: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="5")

If we hard-code the solution for it:

def _array_construct(context, node, array_type):

    ...

    else:
        # Replaces: context.add(array_type(inputs), node.name)

        # static part: %7 and %8 are compile-time constants (both 0)
        static_7_8 = [context['7'].val, context['8'].val]
        # dynamic part: slice the symbolic dim (axis 2) out of x's runtime shape
        sliced_dynamic_6 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        sliced_dynamic_5 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        # concat the static and dynamic parts into one rank-1 tensor and register it as %9
        context.add(mb.concat(values=[static_7_8, sliced_dynamic_6, sliced_dynamic_5], axis=0), node.name)

and modify the pad op registration so that it supports padding with symbolic values:

@register_torch_op(torch_alias=['constant_pad_nd'])
def pad(context, node):

    ...

    if pad.val is not None:
        ...

    else:
        missing_dims = (x.rank * 2 - pad.shape[0]) // 2
        pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0)
        pad = mb.reverse(x=pad, axes=[0])

#1921 is confirmed fixed (with this temporary patch).

@xorange

xorange commented Nov 8, 2023

However, I have some doubts before proposing a general fix for it.

In the above case, the padding value is hard-coded:

  • from nodes '7' and '8' (they are known to be static),
  • and from input x directly (because we know it is the root cause).

My question:

How can we trace from node '9'.inputs all the way back to input x's symbolic value, using the data structures at hand:

  • context,
  • node, and
  • <coremltools.converters.mil.mil.var.Var object> (context['9'].inputs[2], for example)?
  • Or is there any other information available?

Edit: I've noticed context.torch_graph.nodes and it answers my question above. WIP...

@YifanShenSZ
Collaborator Author

YifanShenSZ commented Nov 13, 2023

Hi @xorange, thanks for looking into this issue! About relating a symbol to the input symbols: you should be able to simply compare whether the symbols are the same, since we propagate symbols using sympy.
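
For example, a minimal sketch of what that comparison could look like (the helper name and call site are assumptions for illustration, not existing converter code):

from coremltools.converters.mil.mil.types.symbolic import is_symbolic

def _comes_from_input_dim(var, input_var, dim):
    # True if `var` carries the same sympy symbol as dimension `dim` of `input_var`.
    # Relies on the converter propagating symbols into Var.sym_val.
    sym_dim = input_var.shape[dim]
    return (
        is_symbolic(sym_dim)
        and var.sym_val is not None
        and var.sym_val == sym_dim
    )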

As for the fix, I have a couple of thoughts that might be easier:

  1. pad-specific fix: In pad, given the constructed list, is it possible to use something like mb.stack, mb.concat, or mb.gather to construct a tensor? (See the sketch after this list.)
  2. general fix: The ultimate problem is that if we change the _array_construct output signature, it breaks backward compatibility 😞 All functions that rely on a "list of symbols" rather than a tensor would break.
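
A minimal sketch of what idea 1 could look like (the helper name is made up, and this assumes pad receives a plain Python list mixing ints and scalar MIL Vars):

from coremltools.converters.mil import Builder as mb

def _pad_list_to_tensor(pad_list):
    # Stitch a mixed list of Python ints and scalar Vars back into a rank-1
    # tensor, so downstream code can treat the padding as a single Var.
    elems = []
    for e in pad_list:
        if isinstance(e, int):
            elems.append([e])                             # const-promoted to a length-1 tensor
        else:
            elems.append(mb.expand_dims(x=e, axes=[0]))   # scalar Var -> rank-1 Var
    return mb.concat(values=elems, axis=0)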

@xorange

xorange commented Nov 14, 2023

Quoting from #2050:

  1. This fix adds another branch that only targets the gather op, and what it gathers from is not a name in context.torch_graph.nodes, i.e. it gathers from the net inputs.

Based on #2037, and another net structure on my end that shares a similar root cause, it is clear that targeting only the gather op is not enough to provide a generalized fix.

For example,

>>> context.torch_graph
graph(
    %x.1 : Tensor(1, 3, RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is0"), RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is1"), 'None'),
    ...
)
%input.1 = _convolution[](%x.1, %model.features.conv1.0.weight, %8, %31, %32, %33, %12, %34, %11, %12, %12, %14, %14)

or

>>> context.torch_graph
graph(
    %x.1 : Tensor(1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is0"), 'None'),
    %x_mask : Tensor(1, 1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is1"), 'None'),
    ...
)
%9 = embedding[](%emb.weight, %x.1, %7, %6, %6)
%x.3 = mul[](%9, %10)
%x.5 = transpose[](%x.3, %12, %13)
%x.7 = mul[](%x.5, %x_mask)
%input.1 = mul[](%x.7, %x_mask)

@xorange

xorange commented Nov 14, 2023

Hi @xorange, thanks for looking into this issue! About relating a symbol to the input symbols: you should be able to simply compare whether the symbols are the same, since we propagate symbols using sympy.

Thanks for the reply! I'll look into it.

As for the fix, I have a couple of thoughts that might be easier:

  1. pad-specific fix: In pad, given the constructed list, is it possible to use something like mb.stack, mb.concat, or mb.gather to construct a tensor?

Yes, this should be a cleaner way for pad.

  2. general fix: The ultimate problem is that if we change the _array_construct output signature, it breaks backward compatibility 😞 All functions that rely on a "list of symbols" rather than a tensor would break.

I think we both agree that a generalized fix is what we want here, because I've already come across several cases where pad is not the culprit.

Could you share some of the functions that rely on a "list of symbols" so I can design around them? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).

@YifanShenSZ
Collaborator Author

Could you share some of the functions that rely on a "list of symbols" so I can design around them? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).

Unfortunately I cannot tell off the top of my head 😞 We could try modifying _array_construct, then run pytest --pyargs coremltools.converters.mil.frontend.torch to see what breaks.
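
For reference, a minimal way to drive that experiment from a Python session (just wrapping the pytest invocation above; nothing beyond standard pytest is assumed):

import pytest

# Runs the torch frontend test suite; a non-zero exit code means some
# conversions regressed after the _array_construct change.
exit_code = pytest.main(["--pyargs", "coremltools.converters.mil.frontend.torch", "-q"])
print("torch frontend tests exit code:", int(exit_code))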

@kdonbekci

Any progress on this issue?

@xorange

xorange commented Apr 25, 2024

Any progress on this issue?

None on my end. I'm having trouble balancing work and spare time for this, and digesting the whole design takes a lot. No progress will come from me before 2024 Q4 at the earliest, sorry.
