
When converting a traced torchvision model: AssertionError: type_inference: axis=0, i=1: 256 != is452 #1795

Open
ivyas21 opened this issue Mar 7, 2023 · 1 comment
Labels
bug (Unexpected behaviour that should be corrected), triaged (Reviewed and examined, release has been assigned if applicable)

Comments


ivyas21 commented Mar 7, 2023

When converting a traced torchvision model, after applying roi_align from #1509, conversion fails with AssertionError: type_inference: axis=0, i=1: 256 != is452.

Stack Trace

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_31355/3386583322.py in <module>
      5     traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
      6 
----> 7 detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
      8 detector_mlmodel.save("segmenter.mlmodel")

/opt/conda/lib/python3.7/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
    454         package_dir=package_dir,
    455         debug=debug,
--> 456         specification_version=specification_version,
    457     )
    458 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    185         See `coremltools.converters.convert`
    186     """
--> 187     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    188 
    189 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    214                             convert_to,
    215                             registry,
--> 216                             **kwargs
    217                          )
    218 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    279     frontend_converter = frontend_converter_type()
    280 
--> 281     prog = frontend_converter(model, **kwargs)
    282 
    283     if convert_to.lower() != "neuralnetwork":

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    107         from .frontend.torch import load
    108 
--> 109         return load(*args, **kwargs)
    110 
    111 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
     55     inputs = _convert_to_torch_inputtype(inputs)
     56     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57     return _perform_torch_convert(converter, debug)
     58 
     59 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
     94 def _perform_torch_convert(converter, debug):
     95     try:
---> 96         prog = converter.convert()
     97     except RuntimeError as e:
     98         if debug and "convert function" in str(e):

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    279 
    280             # Add the rest of the operations
--> 281             convert_nodes(self.context, self.graph)
    282 
    283             graph_outputs = [self.context[name] for name in self.graph.outputs]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     87 
     88         context.prepare_for_conversion(node)
---> 89         add_op(context, node)
     90 
     91         # We've generated all the outputs the graph needs, terminate conversion.

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in scatter(context, node)
   5228         mode = 'update'
   5229 
-> 5230     _scatter(context, inputs, mode, node.name)
   5231 
   5232 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in _scatter(context, inputs, mode, name)
   5209     if types.is_scalar(updates.sym_type):
   5210         updates = mb.fill(shape=indices.shape, value=updates.val, name=name)
-> 5211     result = mb.scatter_along_axis(data=data, indices=indices, updates=updates,axis=axis, mode=mode, name=name)
   5212     context.add(result)
   5213 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
    174                     op_cls_to_add = op_reg[op_type]
    175 
--> 176                 return cls._add_op(op_cls_to_add, **kwargs)
    177 
    178             setattr(Builder, op_type, add_op)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    180         curr_block()._insert_op_before(new_op, before_op=before_op)
    181         new_op.build_nested_blocks()
--> 182         new_op.type_value_inference()
    183         if len(new_op.outputs) == 1:
    184             return new_op.outputs[0]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/operation.py in type_value_inference(self, overwrite_output)
    251         existing _output_vars
    252         """
--> 253         output_types = self.type_inference()
    254         if not isinstance(output_types, tuple):
    255             output_types = (output_types,)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py in type_inference(self)
    431         for i in range(self.data.rank):
    432             if i != axis:
--> 433                 assert self.data.shape[i] == self.indices.shape[i], f'type_inference: axis={axis}, i={i}: {self.data.shape[i]} != {self.indices.shape[i]}'
    434 
    435         return self.data.sym_type

AssertionError: type_inference: axis=0, i=1: 256 != is452
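
For what it's worth, here is a minimal standalone illustration (my own sketch using sympy symbols, not actual coremltools code) of why the assertion fires: the per-axis check in scatter_along_axis compares a concrete dimension against a symbolic one with ==, which evaluates to False instead of the symbolic dimension being skipped.

# Sketch only: mimics the per-axis shape check in scatter_along_axis.type_inference.
# The symbolic dimension name "is452" is taken from the error above.
from sympy import Symbol

data_shape = (2, 256)                 # concrete shape of `data`
indices_shape = (2, Symbol("is452"))  # `indices` picked up a symbolic dim during tracing
axis = 0

for i in range(len(data_shape)):
    if i != axis:
        # 256 == Symbol("is452") evaluates to False, so the assertion fails
        assert data_shape[i] == indices_shape[i], \
            f"type_inference: axis={axis}, i={i}: {data_shape[i]} != {indices_shape[i]}"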

Steps To Reproduce

import coremltools as ct
import torch, torchvision
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
import requests
from PIL import Image
import numpy as np
from typing import Dict, Tuple, Optional

# Image conversion helpers (defined for completeness; the script below uses the
# equivalent torchvision transforms T.PILToTensor / T.ConvertImageDtype directly):
class PILToTensor(torch.nn.Module):
    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.pil_to_tensor(image)
        return image, target

class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.convert_image_dtype(image, self.dtype)
        return image, target

# Load the torchvision model
detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

# Get a sample image
toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image))
example_image_pt = example_image_pt.unsqueeze(0)

# Run the sample through the model to demonstrate the model works
y = detector_model(example_image_pt)

# Make an adaptor to convert the model outputs to a tuple
class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""
    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']

adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

# Trace and convert the model using coremltools
model_to_trace = adapted_detector_model
with torch.inference_mode():
    out = model_to_trace(example_image_pt)
    traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
    
detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=example_image_pt.shape)])
detector_mlmodel.save("segmenter.mlmodel")

System environment:

  • coremltools version: 6.2
  • OS: Linux (Linux foohostname 4.19.0-23-cloud-amd64 #1 SMP Debian 4.19.269-1 (2022-12-20) x86_64 GNU/Linux)
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • Python: 3.7
    • PyTorch: 1.11.1+cu102
    • Other libraries installed as dependencies of coremltools:
Requirement already satisfied: coremltools==6.2 in /opt/conda/lib/python3.7/site-packages (6.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (4.64.1)
Requirement already satisfied: protobuf<=4.0.0,>=3.1.0 in /home/jupyter/.local/lib/python3.7/site-packages (from coremltools==6.2) (3.20.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (21.3)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.21.6)
Requirement already satisfied: sympy in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.10.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->coremltools==6.2) (3.0.9)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.7/site-packages (from sympy->coremltools==6.2) (1.2.1)

Please advise. Thank you!

ivyas21 added the bug (Unexpected behaviour that should be corrected) label on Mar 7, 2023
junpeiz (Collaborator) commented Mar 8, 2023

Thank you for providing the detailed steps to reproduce it! This looks like a bug in mb.scatter_along_axis's type inference when an input has a symbolic shape.
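
For anyone who wants to poke at this without the full torchvision model, here is a rough MIL-level sketch that should hit the same assertion (untested, and it assumes the internal builder APIs mb.program / mb.TensorSpec / get_new_symbol; it is not a confirmed reproduction):

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import get_new_symbol, types

sym = get_new_symbol()  # stands in for the symbolic box count produced by tracing

# `data` has a concrete second dim (256) while `indices`/`updates` have a symbolic one,
# so type_inference compares 256 against the symbol and the assertion should fire.
@mb.program(input_specs=[
    mb.TensorSpec(shape=(4, 256), dtype=types.fp32),
    mb.TensorSpec(shape=(4, sym), dtype=types.int32),
    mb.TensorSpec(shape=(4, sym), dtype=types.fp32),
])
def prog(data, indices, updates):
    return mb.scatter_along_axis(
        data=data, indices=indices, updates=updates, axis=0, mode="update"
    )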

junpeiz added the triaged (Reviewed and examined, release has been assigned if applicable) label on Mar 8, 2023