Skip to content

Commit

Permalink
Trt compiler fixes (#8064)
Browse files Browse the repository at this point in the history
Fixes #8061.

### Description

Post-merge fixes for trt_compile()

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: binliunls <[email protected]>
  • Loading branch information
6 people authored Sep 4, 2024
1 parent befb5f6 commit aea46ff
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions monai/networks/trt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs):
self._build_and_save(model, build_args)
# This will reassign input_names from the engine
self._load_engine()
assert self.engine is not None
except Exception as e:
if self.fallback:
self.logger.info(f"Failed to build engine: {e}")
Expand Down Expand Up @@ -403,8 +404,10 @@ def _onnx_to_trt(self, onnx_path):

build_args = self.build_args.copy()
build_args["tf32"] = self.precision != "fp32"
build_args["fp16"] = self.precision == "fp16"
build_args["bf16"] = self.precision == "bf16"
if self.precision == "fp16":
build_args["fp16"] = True
elif self.precision == "bf16":
build_args["bf16"] = True

self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
Expand Down Expand Up @@ -502,6 +505,7 @@ def trt_compile(
) -> torch.nn.Module:
"""
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
Args:
model: module to patch with TrtCompiler object.
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trt_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from monai.handlers import TrtHandler
from monai.networks import trt_compile
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
from monai.utils import optional_import
from monai.utils import min_version, optional_import
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows

trt, trt_imported = optional_import("tensorrt")
trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
polygraphy, polygraphy_imported = optional_import("polygraphy")
build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

Expand Down

0 comments on commit aea46ff

Please sign in to comment.