diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c index 9f2077f064f..ff83a066283 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c @@ -21,14 +21,14 @@ #include +typedef enum { + UCT_CUDA_IPC_DEVICE_ADDR_FLAG_MNNVL = UCS_BIT(0) +} uct_cuda_ipc_device_addr_flags_t; + + typedef struct { uint64_t system_uuid; -#if HAVE_NVML_FABRIC_INFO - struct { - uint32_t clique_id; - uint8_t cluster_uuid[NVML_GPU_FABRIC_UUID_LEN]; - } mnnvl_addr; -#endif + uint8_t flags; /* uct_cuda_ipc_device_addr_flags_t */ } UCS_S_PACKED uct_cuda_ipc_device_addr_t; @@ -76,18 +76,14 @@ ucs_status_t uct_cuda_ipc_iface_get_device_address(uct_iface_t *tl_iface, uct_device_addr_t *addr) { uct_cuda_ipc_device_addr_t *dev_addr = (uct_cuda_ipc_device_addr_t*)addr; -#if HAVE_NVML_FABRIC_INFO uct_cuda_ipc_iface_t *iface = ucs_derived_of(tl_iface, uct_cuda_ipc_iface_t); uct_cuda_ipc_md_t *md = ucs_derived_of(iface->super.super.md, uct_cuda_ipc_md_t); if (md->enable_mnnvl) { - dev_addr->mnnvl_addr.clique_id = md->fabric_info.cliqueId; - memcpy(dev_addr->mnnvl_addr.cluster_uuid, md->fabric_info.clusterUuid, - sizeof(dev_addr->mnnvl_addr.cluster_uuid)); + dev_addr->flags = UCT_CUDA_IPC_DEVICE_ADDR_FLAG_MNNVL; } -#endif dev_addr->system_uuid = ucs_get_system_id(); @@ -102,26 +98,16 @@ static ucs_status_t uct_cuda_ipc_iface_get_address(uct_iface_h tl_iface, } static int -uct_cuda_ipc_iface_mnnvl_reachable(uct_cuda_ipc_md_t *md, +uct_cuda_ipc_iface_mnnvl_supported(uct_cuda_ipc_md_t *md, const uct_cuda_ipc_device_addr_t *dev_addr, - size_t dev_addr_len, - const uct_iface_is_reachable_params_t *params) + size_t dev_addr_len) { -#if HAVE_NVML_FABRIC_INFO - if (memcmp(dev_addr->mnnvl_addr.cluster_uuid, md->fabric_info.clusterUuid, - sizeof(dev_addr->mnnvl_addr.cluster_uuid))) { - uct_iface_fill_info_str_buf(params, "cluster uuid doesn't match"); - return 0; - } - - if (dev_addr->mnnvl_addr.clique_id != md->fabric_info.cliqueId){ - uct_iface_fill_info_str_buf(params, "clique id doesn't match"); - return 0; + if (md->enable_mnnvl && (dev_addr_len != sizeof(uint64_t))) { + ucs_assertv(dev_addr_len >= sizeof(uct_cuda_ipc_device_addr_t), + "dev_addr_len=%zu", dev_addr_len); + return (dev_addr->flags & UCT_CUDA_IPC_DEVICE_ADDR_FLAG_MNNVL); } - return 1; -#endif - return 0; } @@ -151,21 +137,13 @@ uct_cuda_ipc_iface_is_reachable_v2(const uct_iface_h tl_iface, return 0; } - if (md->enable_mnnvl && (dev_addr_len != sizeof(uint64_t))) { - ucs_assertv(dev_addr_len >= sizeof(uct_cuda_ipc_device_addr_t), - "dev_addr_len=%zu", dev_addr_len); - if (!uct_cuda_ipc_iface_mnnvl_reachable(md, dev_addr, dev_addr_len, - params)) { - return 0; - } - } else if (!same_uuid) { - uct_iface_fill_info_str_buf(params, - "different system id %"PRIx64" vs %"PRIx64"", - ucs_get_system_id(), dev_addr->system_uuid); - return 0; + if (same_uuid || + uct_cuda_ipc_iface_mnnvl_supported(md, dev_addr, dev_addr_len)) { + return uct_iface_scope_is_reachable(tl_iface, params); } - return uct_iface_scope_is_reachable(tl_iface, params); + uct_iface_fill_info_str_buf(params, "MNNVL is not supported"); + return 0; } static double uct_cuda_ipc_iface_get_bw() diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c index bac706f4db4..aeff3b52a4b 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c @@ -409,11 +409,12 @@ uct_cuda_ipc_mem_dereg(uct_md_h md, const uct_md_mem_dereg_params_t *params) } static int -uct_cuda_ipc_md_init_fabric_info(uct_cuda_ipc_md_t *md, - ucs_ternary_auto_value_t mnnvl_enable) +uct_cuda_ipc_md_check_fabric_info(uct_cuda_ipc_md_t *md, + ucs_ternary_auto_value_t mnnvl_enable) { int mnnvl_supported = 0; #if HAVE_NVML_FABRIC_INFO + nvmlGpuFabricInfoV_t fabric_info; nvmlDevice_t device; ucs_status_t status; char buf[64]; @@ -432,22 +433,22 @@ uct_cuda_ipc_md_init_fabric_info(uct_cuda_ipc_md_t *md, goto out_sd; } - md->fabric_info.version = nvmlGpuFabricInfo_v2; - status = UCT_NVML_FUNC_LOG_ERR( - nvmlDeviceGetGpuFabricInfoV(device, &md->fabric_info)); + fabric_info.version = nvmlGpuFabricInfo_v2; + status = UCT_NVML_FUNC_LOG_ERR( + nvmlDeviceGetGpuFabricInfoV(device, &fabric_info)); if (status != UCS_OK) { goto out_sd; } ucs_debug("fabric_info: healthmask=%u state=%u status=%u clique=%u uuid=%s", - md->fabric_info.healthMask, md->fabric_info.state, - md->fabric_info.status, md->fabric_info.cliqueId, + fabric_info.healthMask, fabric_info.state, fabric_info.status, + fabric_info.cliqueId, ucs_str_dump_hex( - md->fabric_info.clusterUuid, NVML_GPU_FABRIC_UUID_LEN, buf, + fabric_info.clusterUuid, NVML_GPU_FABRIC_UUID_LEN, buf, sizeof(buf), SIZE_MAX)); - if ((md->fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED) || - (md->fabric_info.status != NVML_SUCCESS)) { + if ((fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED) || + (fabric_info.status != NVML_SUCCESS)) { goto out_sd; } @@ -493,7 +494,7 @@ uct_cuda_ipc_md_open(uct_component_t *component, const char *md_name, md->super.ops = &md_ops; md->super.component = &uct_cuda_ipc_component.super; - md->enable_mnnvl = uct_cuda_ipc_md_init_fabric_info( + md->enable_mnnvl = uct_cuda_ipc_md_check_fabric_info( md, ipc_config->enable_mnnvl); *md_p = &md->super; diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h index 4b7e0bb0875..eb621bd5ce8 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h @@ -43,9 +43,6 @@ typedef CUipcMemHandle uct_cuda_ipc_md_handle_t; typedef struct uct_cuda_ipc_md { uct_md_t super; /**< Domain info */ int enable_mnnvl; /**< Multi-node NVLINK support status */ -#if HAVE_NVML_FABRIC_INFO - nvmlGpuFabricInfoV_t fabric_info; /**< GPU fabric information */ -#endif } uct_cuda_ipc_md_t;