For the given IR:
```mlir
module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,224,224],f32>, %arg2: !torch.vtensor<[?,?,?,?,?,?],f32>, %arg3: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %130 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_1867> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %380 = torch.operator "onnx.Shape"(%arg3) : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[4],si64>
    %381 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__35> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %382 = torch.operator "onnx.Gather"(%380, %381) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
    %383 = torch.operator "onnx.Shape"(%arg3) : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[4],si64>
    %384 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__36> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %385 = torch.operator "onnx.Gather"(%383, %384) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
    %389 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__38> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %390 = torch.operator "onnx.Div"(%382, %389) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
    %393 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__39> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %394 = torch.operator "onnx.Div"(%385, %393) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
    %408 = torch.operator "onnx.Mul"(%390, %394) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
    %411 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__45> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %412 = torch.operator "onnx.Unsqueeze"(%408, %411) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64>
    %413 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__46> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %415 = torch.operator "onnx.Concat"(%411, %412, %130, %413) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64>
    %416 = torch.operator "onnx.Reshape"(%arg2, %415) : (!torch.vtensor<[?,?,?,?,?,?],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
    return %416 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__Concat_1867: "0x08000000FFFFFFFFFFFFFFFF",
      __35: "0x080000000100000000000000",
      __36: "0x080000000200000000000000",
      __38: "0x080000000E00000000000000",
      __39: "0x080000000E00000000000000",
      __45: "0x080000000000000000000000",
      __46: "0x080000000000000000000000"
    }
  }
#-}
```
Compilation fails with the error `failed to legalize operation 'torch.aten.or.bool'`.

Command:

```shell
iree-compile --iree-hal-target-backends=llvm-cpu model.torch_onnx.mlir
```
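To make the reproducer easier to follow, here is a small Python sketch that decodes the `dense_resource` constants and replays the shape arithmetic feeding `onnx.Reshape`. It is an illustration only, not part of the reproducer or of IREE/torch-mlir. It assumes each blob's hex string is a 4-byte little-endian alignment header followed by little-endian `si64` data, it treats `onnx.Div` on non-negative `si64` values as floor division, and the input shapes in the usage lines are hypothetical. The helper names `decode_si64`, `target_shape`, and `onnx_reshape` are made up for this sketch.

```python
import struct
from math import prod

# Resource blobs copied verbatim from the dialect_resources section of the IR.
BLOBS = {
    "_onnx__Concat_1867": "0x08000000FFFFFFFFFFFFFFFF",
    "__35": "0x080000000100000000000000",
    "__36": "0x080000000200000000000000",
    "__38": "0x080000000E00000000000000",
    "__39": "0x080000000E00000000000000",
    "__45": "0x080000000000000000000000",
    "__46": "0x080000000000000000000000",
}


def decode_si64(blob: str) -> list[int]:
    """Decode a hex blob as little-endian int64 values.

    Assumption: the first 4 bytes are the blob's alignment header, not data.
    """
    raw = bytes.fromhex(blob[2:])  # drop the "0x" prefix
    data = raw[4:]                 # skip the assumed alignment header
    return list(struct.unpack(f"<{len(data) // 8}q", data))


def target_shape(arg3_shape: tuple[int, ...]) -> list[int]:
    """Replay %380..%415: build the shape operand passed to onnx.Reshape."""
    c = {name: decode_si64(blob) for name, blob in BLOBS.items()}
    d1 = arg3_shape[c["__35"][0]]  # onnx.Gather(shape(%arg3), 1)
    d2 = arg3_shape[c["__36"][0]]  # onnx.Gather(shape(%arg3), 2)
    # onnx.Div by 14 on each dim (floor division for non-negative ints), then onnx.Mul.
    merged = (d1 // c["__38"][0]) * (d2 // c["__39"][0])
    # onnx.Concat(%411, Unsqueeze(%408), %130, %413) along axis 0.
    return c["__45"] + [merged] + c["_onnx__Concat_1867"] + c["__46"]


def onnx_reshape(in_shape: tuple[int, ...], shape: list[int]) -> list[int]:
    """Output shape of ONNX Reshape with the default allowzero = 0.

    A 0 copies the input dimension at the same index; a single -1 is inferred.
    """
    out = [in_shape[i] if s == 0 else s for i, s in enumerate(shape)]
    assert out.count(-1) <= 1, "at most one -1 is allowed"
    if -1 in out:
        known = prod(s for s in out if s != -1)
        out[out.index(-1)] = prod(in_shape) // known
    return out


if __name__ == "__main__":
    # Hypothetical %arg3 shape; the IR only reads dims 1 and 2.
    shape = target_shape((1, 224, 224, 3))
    print(shape)  # [0, 256, -1, 0]
    # Hypothetical 6-D %arg2 shape.
    print(onnx_reshape((2, 16, 16, 4, 5, 6), shape))  # [2, 256, 30, 4]
```

With these assumed shapes, the reshape target is `[0, 256, -1, 0]`, so the IR exercises both ONNX special values at once: `0` copies the matching input dimension and `-1` is inferred from the remaining element count.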
The error goes away when commit llvm/torch-mlir@140cad5 is reverted in IREE.
CC: @zjgarvey @pdhirajkumarprasad
This is a great reproducer, thanks @pdhirajkumarprasad. I'll take a look.
A fix is up for review: llvm/torch-mlir#3823.
It is currently blocked by a CI issue introduced by a different patch.