import coremltools as ct
import torch, torchvision
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
import requests
from PIL import Image
import numpy as np
from typing import Dict, Tuple, Optional

# Image conversion tools:
class PILToTensor(torch.nn.Module):
    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.pil_to_tensor(image)
        return image, target


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        image = F.convert_image_dtype(image, self.dtype)
        return image, target


# Load the torchvision model
detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

# Get a sample image
toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
example_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image))
example_image_pt = example_image_pt.unsqueeze(0)

# Run the sample through the model to demonstrate the model works
y = detector_model(example_image_pt)

# Make an adaptor to convert the model outputs to a tuple
class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """This adapter is only here to unbox the first output."""

    def __init__(self, model, w=2):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        return result[0]['boxes'], result[0]['labels'], result[0]['scores']


adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

# Trace and convert the model using coremltools
model_to_trace = adapted_detector_model
with torch.inference_mode():
    out = model_to_trace(example_image_pt)
    traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()

detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=example_image_pt.shape)])
detector_mlmodel.save("segmenter.mlmodel")
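The repro passes `example_image_pt.shape` straight to `ct.ImageType`, so it helps to know what that shape is. As the script uses them, `T.PILToTensor` returns a uint8 `(C, H, W)` tensor, `T.ConvertImageDtype(torch.float)` rescales it to floats in [0, 1], and `unsqueeze(0)` adds a batch axis. The sketch below mirrors those steps with NumPy on a plain `(H, W, C)` array, so the shape can be checked without torch. The helper `to_batched_float_chw` is hypothetical and is not part of torchvision or coremltools.

```python
import numpy as np


def to_batched_float_chw(image_hwc: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mimic PILToTensor + ConvertImageDtype(float) + unsqueeze(0)."""
    if image_hwc.dtype != np.uint8 or image_hwc.ndim != 3:
        raise ValueError("expected a uint8 array of shape (H, W, C)")
    chw = image_hwc.transpose(2, 0, 1)           # HWC -> CHW, the layout PILToTensor produces
    scaled = chw.astype(np.float32) / 255.0      # uint8 [0, 255] -> float32 [0.0, 1.0]
    return scaled[np.newaxis, ...]               # add the batch axis -> (1, C, H, W)


example = np.full((4, 6, 3), 255, dtype=np.uint8)
batch = to_batched_float_chw(example)
print(batch.shape, batch.max())  # (1, 3, 4, 6) 1.0
```

The resulting `(1, C, H, W)` layout is the shape that ends up in `ct.ImageType(shape=...)` in the repro.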
Thank you for providing the detailed steps for reproducing it! It seems like a bug in `mb.scatter_along_axis`'s value inference when the input has a symbolic shape.
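As a hedged illustration only, here is a toy model of why a shape check can fail once a dimension is symbolic. It is not coremltools' actual implementation. `Sym`, `strict_check`, `symbol_aware_check`, and the example shapes are all hypothetical, and the error text is modeled loosely on the reported message. An exact-equality check rejects a concrete size such as `256` against a symbol such as `is452`, even though the symbol might resolve to 256 at runtime. A symbol-aware check only rejects mismatches it can prove, meaning two different concrete integers.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass(frozen=True)
class Sym:
    """Hypothetical stand-in for a symbolic dimension such as `is452`."""
    name: str

    def __str__(self) -> str:
        return self.name


Dim = Union[int, Sym]


def strict_check(data: Sequence[Dim], indices: Sequence[Dim], axis: int) -> None:
    # Every dim except `axis` must compare equal; an int never equals a Sym, so this fails.
    for i, (d, k) in enumerate(zip(data, indices)):
        if i != axis:
            assert d == k, f"type_inference: axis={axis}, i={i}: {d} != {k}"


def symbol_aware_check(data: Sequence[Dim], indices: Sequence[Dim], axis: int) -> None:
    # Only two concrete ints can be proven different at conversion time;
    # a symbolic dim may still match at runtime, so it is accepted here.
    for i, (d, k) in enumerate(zip(data, indices)):
        if i != axis and isinstance(d, int) and isinstance(k, int):
            assert d == k, f"axis={axis}, i={i}: {d} != {k}"


data_shape = (Sym("is451"), 256)             # hypothetical shapes, for illustration only
indices_shape = (Sym("is450"), Sym("is452"))
symbol_aware_check(data_shape, indices_shape, axis=0)  # passes
try:
    strict_check(data_shape, indices_shape, axis=0)
except AssertionError as err:
    print(err)  # type_inference: axis=0, i=1: 256 != is452
```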
junpeiz added the triaged label (status: reviewed and examined, release has been assigned if applicable) on Mar 8, 2023.
When converting a traced torchvision model after applying `roi_align` from #1509, the conversion fails with:

AssertionError: type_inference: axis=0, i=1: 256 != is452
Stack Trace
Steps To Reproduce
System environment:
- coremltools version: 6.2
- OS: Linux foohostname 4.19.0-23-cloud-amd64 #1 SMP Debian 4.19.269-1 (2022-12-20) x86_64 GNU/Linux

Please advise. Thank you!