Skip to content

Commit

Permalink
Add SM architecture version check (#8199)
Browse files Browse the repository at this point in the history
Fixes #8198

NVIDIA Volta support (GPUs with compute capability 7.0) has been removed
starting with TensorRT 10.5. Review the [TensorRT Support
Matrix](https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html)
for which GPUs are supported by this release.
Add an SM architecture (compute capability) version check so that the TensorRT tests are skipped on GPUs with compute capability below 7.0.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
KumoLiu and pre-commit-ci[bot] authored Nov 13, 2024
1 parent 0bb20a8 commit b6663b9
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 8 deletions.
2 changes: 2 additions & 0 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,8 @@ def trt_export(
"""
Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
Currently, this API only supports converting models whose inputs are all tensors.
Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
Review the TensorRT Support Matrix for which GPUs are supported.
There are two ways to export a model:
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
Expand Down
4 changes: 3 additions & 1 deletion monai/networks/trt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,9 @@ def trt_compile(
) -> torch.nn.Module:
"""
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
Review the TensorRT Support Matrix for which GPUs are supported.
Args:
model: module to patch with TrtCompiler object.
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
InvalidPyTorchVersionError,
OptionalImportError,
allow_missing_reference,
compute_capabilities_after,
damerau_levenshtein_distance,
exact_version,
get_full_type_name,
Expand Down
41 changes: 41 additions & 0 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
if is_prerelease:
return False
return True


@functools.lru_cache(None)
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
    """
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system.
    The compared version is a string in the form of "major.minor".

    Results are memoized per argument combination by ``functools.lru_cache``.

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: if None, the current system GPU CUDA compute capability will be used
            (queried through ``pynvml`` for device index 0). Otherwise a "major[.minor]" string
            that is compared directly, which is mainly useful for testing.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
        When ``current_ver_string`` is None, True is also returned if ``pynvml`` cannot be imported,
        and False is returned if ``pynvml`` is importable but CUDA is not available.
    """
    if current_ver_string is None:
        cuda_available = torch.cuda.is_available()
        pynvml, has_pynvml = optional_import("pynvml")
        if not has_pynvml:  # assuming that the user has Ampere and later GPU
            return True
        if not cuda_available:
            return False
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU
            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        finally:
            # always release NVML, even if one of the device queries above raises
            pynvml.nvmlShutdown()
        current_ver_string = f"{major_c}.{minor_c}"

    ver, has_ver = optional_import("packaging.version", name="parse")
    if has_ver:
        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
    # fallback without `packaging`: drop any "+suffix", then compare (major, minor) integer tuples
    parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
    while len(parts) < 2:
        parts += ["0"]  # a bare "8" is treated as "8.0"
    c_major, c_minor = parts[:2]
    c_mn = int(c_major), int(c_minor)
    mn = int(major), int(minor)
    # ">=" (not ">") so that an equal capability counts as supported, matching the `packaging` branch
    return c_mn >= mn
9 changes: 8 additions & 1 deletion tests/test_bundle_trt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from monai.data import load_net_with_metadata
from monai.networks import save_state
from monai.utils import optional_import
from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import (
SkipIfBeforeComputeCapabilityVersion,
command_line_tests,
skip_if_no_cuda,
skip_if_quick,
skip_if_windows,
)

_, has_torchtrt = optional_import(
"torch_tensorrt",
Expand All @@ -47,6 +53,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestTRTExport(unittest.TestCase):

def setUp(self):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_convert_to_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.networks import convert_to_trt
from monai.networks.nets import UNet
from monai.utils import optional_import
from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows

_, has_torchtrt = optional_import(
"torch_tensorrt",
Expand All @@ -38,6 +38,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestConvertToTRT(unittest.TestCase):

def setUp(self):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_trt_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from monai.networks import trt_compile
from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
from monai.utils import min_version, optional_import
from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
from tests.utils import (
SkipIfAtLeastPyTorchVersion,
SkipIfBeforeComputeCapabilityVersion,
skip_if_no_cuda,
skip_if_quick,
skip_if_windows,
)

trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
polygraphy, polygraphy_imported = optional_import("polygraphy")
Expand All @@ -36,6 +42,7 @@
@skip_if_quick
@unittest.skipUnless(trt_imported, "tensorrt is required")
@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestTRTCompile(unittest.TestCase):

def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from parameterized import parameterized

from monai.utils import pytorch_after
from monai.utils import compute_capabilities_after, pytorch_after

TEST_CASES = (
TEST_CASES_PT = (
(1, 5, 9, "1.6.0"),
(1, 6, 0, "1.6.0"),
(1, 6, 1, "1.6.0", False),
Expand All @@ -36,14 +36,30 @@
(1, 6, 1, "1.6.0+cpu", False),
)

# Cases for `compute_capabilities_after(major, minor, sm)`.
# `sm` is passed as the explicit current-capability string, so no GPU is queried.
TEST_CASES_SM = [
    # (major, minor, sm, expected)
    (6, 1, "6.1", True),  # equal versions
    (6, 1, "6.0", False),  # current minor is lower than required
    (6, 0, "8.6", True),  # current major is higher than required
    (7, 0, "8", True),  # current string without a minor component
    (8, 6, "8", False),  # current string without a minor component, required minor is higher
]


class TestPytorchVersionCompare(unittest.TestCase):

@parameterized.expand(TEST_CASES)
@parameterized.expand(TEST_CASES_PT)
def test_compare(self, a, b, p, current, expected=True):
"""Test pytorch_after with a and b"""
self.assertEqual(pytorch_after(a, b, p, current), expected)


class TestComputeCapabilitiesAfter(unittest.TestCase):
    """Checks `compute_capabilities_after` against explicit capability strings."""

    @parameterized.expand(TEST_CASES_SM)
    def test_compute_capabilities_after(self, major, minor, sm, expected):
        """Compare the result for required (major, minor) and current `sm` with `expected`."""
        outcome = compute_capabilities_after(major, minor, sm)
        self.assertEqual(outcome, expected)


if __name__ == "__main__":
unittest.main()
16 changes: 15 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import MONAIEnvVars
from monai.utils.module import pytorch_after
from monai.utils.module import compute_capabilities_after, pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type

Expand Down Expand Up @@ -286,6 +286,20 @@ def __call__(self, obj):
)(obj)


class SkipIfBeforeComputeCapabilityVersion:
    """Decorator that skips the decorated test when the compute capability
    reported by `compute_capabilities_after` is older than the given version."""

    def __init__(self, compute_capability_tuple):
        # the tuple is unpacked into the arguments of `compute_capabilities_after`
        meets_requirement = compute_capabilities_after(*compute_capability_tuple)
        self.min_version = compute_capability_tuple
        self.version_too_old = not meets_requirement

    def __call__(self, obj):
        reason = f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
        skip_decorator = unittest.skipIf(self.version_too_old, reason)
        return skip_decorator(obj)


def is_main_test_process():
ps = torch.multiprocessing.current_process()
if not ps or not hasattr(ps, "name"):
Expand Down

0 comments on commit b6663b9

Please sign in to comment.