ValueError: Torch var stride_width.1 not found in context #1788

Open
ivyas21 opened this issue Mar 2, 2023 · 14 comments
Labels
bug Unexpected behaviour that should be corrected (type) PyTorch (traced) triaged Reviewed and examined, release has been assigned if applicable (status)

Comments

@ivyas21

ivyas21 commented Mar 2, 2023

When converting a traced torchvision model, an expected input to a mul operation is not found: ValueError: Torch var stride_width.1 not found in context

Stack Trace

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_16536/1090456604.py in <module>
      5     traced_model = torch.jit.trace(model_to_trace, example_image).eval()
      6 
----> 7 detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
      8 detector_mlmodel.save("segmenter.mlmodel")

/opt/conda/lib/python3.7/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
    454         package_dir=package_dir,
    455         debug=debug,
--> 456         specification_version=specification_version,
    457     )
    458 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    185         See `coremltools.converters.convert`
    186     """
--> 187     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    188 
    189 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    214                             convert_to,
    215                             registry,
--> 216                             **kwargs
    217                          )
    218 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    279     frontend_converter = frontend_converter_type()
    280 
--> 281     prog = frontend_converter(model, **kwargs)
    282 
    283     if convert_to.lower() != "neuralnetwork":

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    107         from .frontend.torch import load
    108 
--> 109         return load(*args, **kwargs)
    110 
    111 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
     55     inputs = _convert_to_torch_inputtype(inputs)
     56     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57     return _perform_torch_convert(converter, debug)
     58 
     59 

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
     94 def _perform_torch_convert(converter, debug):
     95     try:
---> 96         prog = converter.convert()
     97     except RuntimeError as e:
     98         if debug and "convert function" in str(e):

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    279 
    280             # Add the rest of the operations
--> 281             convert_nodes(self.context, self.graph)
    282 
    283             graph_outputs = [self.context[name] for name in self.graph.outputs]

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     87 
     88         context.prepare_for_conversion(node)
---> 89         add_op(context, node)
     90 
     91         # We've generated all the outputs the graph needs, terminate conversion.

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in mul(context, node)
   1388 @register_torch_op
   1389 def mul(context, node):
-> 1390     inputs = _get_inputs(context, node, expected=2)
   1391     x, y = promote_input_dtypes(inputs)
   1392     res = mb.mul(x=x, y=y, name=node.name)

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in _get_inputs(context, node, expected, min_expected)
    187     value of @expected.
    188     """
--> 189     inputs = [context[name] for name in node.inputs]
    190     if expected is not None:
    191         expected = [expected] if not isinstance(expected, (list, tuple)) else expected

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in <listcomp>(.0)
    187     value of @expected.
    188     """
--> 189     inputs = [context[name] for name in node.inputs]
    190     if expected is not None:
    191         expected = [expected] if not isinstance(expected, (list, tuple)) else expected

/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in __getitem__(self, torch_name)
     87                 return self._current_graph[idx][torch_name]
     88         raise ValueError(
---> 89             "Torch var {} not found in context {}".format(torch_name, self.name)
     90         )
     91 

ValueError: Torch var stride_width.1 not found in context 

Steps To Reproduce

import coremltools as ct
import torch, torchvision

detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT")
detector_model = detector_model.eval()

class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""
    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']

adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

model_to_trace = adapted_detector_model
with torch.inference_mode():
    example_image = torch.rand(1,3,224,224)
    out = model_to_trace(example_image)
    traced_model = torch.jit.trace(model_to_trace, example_image).eval()
    
detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
detector_mlmodel.save("segmenter.mlmodel")

System environment:

  • coremltools version: 6.2
  • OS: Linux (Linux foohostname 4.19.0-23-cloud-amd64 #1 SMP Debian 4.19.269-1 (2022-12-20) x86_64 GNU/Linux)
  • Any other relevant version information (e.g. PyTorch or TensorFlow version):
    • Python: 3.7
    • PyTorch: 1.12.1+cu102
    • Other libraries installed as dependencies of coremltools:
Requirement already satisfied: coremltools==6.2 in /opt/conda/lib/python3.7/site-packages (6.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (4.64.1)
Requirement already satisfied: protobuf<=4.0.0,>=3.1.0 in /home/jupyter/.local/lib/python3.7/site-packages (from coremltools==6.2) (3.20.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (21.3)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.21.6)
Requirement already satisfied: sympy in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.10.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->coremltools==6.2) (3.0.9)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.7/site-packages (from sympy->coremltools==6.2) (1.2.1)

Please advise. Thank you!

@ivyas21 ivyas21 added the bug Unexpected behaviour that should be corrected (type) label Mar 2, 2023
@TobyRoseman
Collaborator

The out variable from your code contains three tensors, all of which are empty. The PyTorch documentation for that model uses a different input shape.

Are you sure your PyTorch model is valid for the input shape you are using?
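
For reference, a minimal sketch of how one might confirm the trace input actually produces detections (same model as in the report; random noise will usually yield empty outputs, so a real image is a safer tracing input):

import torch, torchvision

detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT").eval()

with torch.inference_mode():
    # Random noise rarely triggers any detections, so tracing on it
    # can bake in degenerate (empty-output) branches.
    out = detector_model(torch.rand(1, 3, 224, 224))

# Empty 'boxes' of shape (0, 4) means nothing was detected in the input.
for det in out:
    print({k: tuple(v.shape) for k, v in det.items()})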

@ivyas21
Author

ivyas21 commented Mar 2, 2023

Thank you @TobyRoseman for reaching out.

Are you sure your PyTorch model is valid for the input shape you are using?

I'm using a valid input shape. Here is the output: [{'boxes': tensor([], size=(0, 4), grad_fn=), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=)}]

That's what I'm unboxing with result[0]['boxes'], result[0]['labels'], result[0]['scores'].

@TobyRoseman
Collaborator

Right, as I said, all of the output tensors are empty. Also, the fact that the first tensor has shape (0, 4) seems wrong.

Do you have an example_image value where the model's outputs are non-empty tensors?

@ivyas21
Author

ivyas21 commented Mar 2, 2023

Sure, I've included an example image here. Thank you @TobyRoseman.

import coremltools as ct
import torch, torchvision

detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

import requests
from PIL import Image
import numpy as np
import transforms as T

toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image)[0])[0]
example_image_pt = example_image_pt.unsqueeze(0)

y = detector_model(example_image_pt)
print(y)

class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""
    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']

adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

model_to_trace = adapted_detector_model
with torch.inference_mode():
    #-example_image = torch.rand(1,3,224,224)
    out = model_to_trace(example_image_pt)
    traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
    
detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
detector_mlmodel.save("segmenter.mlmodel")

@TobyRoseman
Collaborator

----> 1 import transforms as T

ModuleNotFoundError: No module named 'transforms'

Where is this package coming from? I'm certainly familiar with the transformers package but not the transforms package.

This same model is also causing a different error in #1790

@ivyas21
Author

ivyas21 commented Mar 2, 2023

----> 1 import transforms as T

ModuleNotFoundError: No module named 'transforms'

Sorry, I forgot to add that. You need to install it: pip install transformers (https://pypi.org/project/transformers/)

@TobyRoseman
Collaborator

I have the transformers package installed. Your code is importing transforms not transformers.

There is no PILToTensor or ConvertImageDtype under the transformers package.

transforms looks like a pretty obscure (i.e. untrusted) package. I'm not going to install it.

@ivyas21
Author

ivyas21 commented Mar 3, 2023

@TobyRoseman, okay, I refined the code to remove transforms. No additional library is needed:

import coremltools as ct
import torch, torchvision
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
import requests
from PIL import Image
import numpy as np
from typing import Dict, Tuple, Optional

# Image conversion tools:
class PILToTensor(torch.nn.Module):
    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.pil_to_tensor(image)
        return image, target

class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.convert_image_dtype(image, self.dtype)
        return image, target

# Load the torchvision model
detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

# Get a sample image
toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image))
example_image_pt = example_image_pt.unsqueeze(0)

# Run the sample through the model to demonstrate the model works
y = detector_model(example_image_pt)

# Make an adaptor to convert the model outputs to a tuple
class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""
    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']

adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

# Trace and convert the model using coremltools
model_to_trace = adapted_detector_model
with torch.inference_mode():
    out = model_to_trace(example_image_pt)
    traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
    
detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=example_image_pt.shape)])
detector_mlmodel.save("segmenter.mlmodel")

@TobyRoseman TobyRoseman added the triaged Reviewed and examined, release has been assigned if applicable (status) label Mar 3, 2023
@junpeiz
Collaborator

junpeiz commented Mar 6, 2023

I ran into similar issues when converting FasterRCNN before.

This stride_width.1 issue comes from the generate_tensor_assignment_ops IR pass not handling some special fill_ ops in torch. In my experience, downgrading to torch 1.11.0 with torchvision 0.12.0 makes the issue go away.

@ivyas21 Could you try torch 1.11.0 with torchvision 0.12.0 to see if it works?
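
For reference, a minimal sketch for verifying the pinned workaround environment (the version pins come from the suggestion above, installed e.g. via pip install torch==1.11.0 torchvision==0.12.0):

import torch, torchvision

# Confirm the downgraded versions suggested above before re-running the conversion.
assert torch.__version__.startswith("1.11"), torch.__version__
assert torchvision.__version__.startswith("0.12"), torchvision.__version__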

@ivyas21
Author

ivyas21 commented Mar 6, 2023

@junpeiz, thank you. Let me try it. Hope it works as you suggest!

@ivyas21
Author

ivyas21 commented Mar 6, 2023

@junpeiz, I just downgraded torch and torchvision as you suggested. I'm getting this error now:

RuntimeError: PyTorch convert function for op 'torchvision::roi_align' not implemented.

@junpeiz
Collaborator

junpeiz commented Mar 6, 2023

@ivyas21 Good, now we can confirm that the stride_width.1 issue was introduced by a newer version of PyTorch. That will help us fix it in the future.

The torchvision::roi_align not implemented error is expected, as coremltools currently doesn't support that op. Feel free to write your own composite op by following https://coremltools.readme.io/docs/composite-operators; it also seems other users have proposed a solution in #1509 (the sketch below shows the general registration pattern).
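
For reference, the registration pattern from that doc looks roughly like this (a sketch using the doc's simple selu example; a working torchvision::roi_align composite would need a considerably more involved MIL body):

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

# Composite-op pattern: decompose an unsupported torch op into existing
# MIL builder ops at conversion time.
@register_torch_op
def selu(context, node):
    x = _get_inputs(context, node, expected=1)[0]
    x = mb.elu(x=x, alpha=1.6732632423543772)
    x = mb.mul(x=x, y=1.0507009873554805, name=node.name)
    context.add(x)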

Let's keep this thread focused on the stride_width.1 issue. For torchvision::roi_align, feel free to open another issue as a feature request.

@ivyas21
Author

ivyas21 commented Mar 6, 2023

@junpeiz, sounds good. Thank you for clarifying. I'll open a new issue for torchvision::roi_align. Many of the latest libraries, like CLIP, require upgrading torch & torchvision to the latest versions. Hopefully stride_width.1 will be fixed soon.

@fukatani
Contributor

@junpeiz
This issue seems to be solved.

I tried with main (e0f8918) and torch==1.13.1, and got:

  File "/Users/ryosukefukatani/work/coremltools/faster.py", line 28, in <module>
    detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/_converters_entry.py", line 542, in convert
    main_pipeline=pass_pipeline,
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 217, in _mil_convert
    **kwargs
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 63, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 110, in _perform_torch_convert
    raise e
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 102, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 439, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/ryosukefukatani/work/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 89, in convert_nodes
    f"PyTorch convert function for op '{node.kind}' not implemented."
RuntimeError: PyTorch convert function for op 'torchvision::roi_align' not implemented.
