diff --git a/src/tamm/gpu_streams.hpp b/src/tamm/gpu_streams.hpp
index 6e5cf15a3..05d12a6b4 100644
--- a/src/tamm/gpu_streams.hpp
+++ b/src/tamm/gpu_streams.hpp
@@ -13,9 +13,11 @@
 #include
 #include
 #include
+#include <nvml.h>
 #elif defined(USE_HIP)
 #include
 #include
+#include <rocm_smi/rocm_smi.h>
 #elif defined(USE_DPCPP)
 #include "sycl_device.hpp"
 #include
@@ -117,36 +119,82 @@ static inline void getDeviceCount(int* id) {
 #endif
 }
 
+#if defined(USE_CUDA)
+// Returns the number of GPUs that NVML reports via nvmlDeviceGetCount().
+// Calls tamm_terminate() with the NVML error text if initialization or the query fails.
+static inline int nv_device_count() {
+  nvmlReturn_t result;
+  unsigned int device_count = 0;
+
+  result = nvmlInit();
+  if(NVML_SUCCESS != result) {
+    std::string nvml_error = "Failed to initialize NVML: " + std::string(nvmlErrorString(result));
+    tamm_terminate(nvml_error);
+  }
+
+  result = nvmlDeviceGetCount(&device_count);
+  if(NVML_SUCCESS != result) {
+    std::string nvml_error =
+      "Failed to query device count: " + std::string(nvmlErrorString(result));
+    tamm_terminate(nvml_error);
+  }
+  nvmlShutdown();
+  return static_cast<int>(device_count);
+}
+
+#elif defined(USE_HIP)
+// Returns the number of devices that ROCm SMI reports via rsmi_num_monitor_devices().
+// Calls tamm_terminate() if rsmi_init() or the query does not return RSMI_STATUS_SUCCESS.
+static inline int amd_device_count() {
+  rsmi_status_t result;
+  uint32_t      device_count = 0;
+
+  result = rsmi_init(0);
+  if(result != RSMI_STATUS_SUCCESS) tamm_terminate("rsmi_init failed");
+  result = rsmi_num_monitor_devices(&device_count);
+  if(result != RSMI_STATUS_SUCCESS) tamm_terminate("Failed to query device count");
+  rsmi_shut_down();
+
+  return static_cast<int>(device_count);
+}
+
+#endif
+
 // The following API is to get the hardware count of
 // GPUs/GCDs/Xe-stacks/tiles on a given node. Unlike the
 // above API, this method is not affected by the masking
 // env variables like CUDA/ROCR_VISIBLE_DEVICES or ZE_AFFINITY_MASK
 static inline void getHardwareGPUCount(int* gpus_per_node) {
-  std::array<char, 128> buffer;
-  std::string           result, m_call;
-
 #if defined(USE_CUDA)
-  m_call = "nvidia-smi --query-gpu=name --format=csv,noheader | wc -l";
+  *gpus_per_node = nv_device_count();
 #elif defined(USE_HIP)
-  m_call = "rocm-smi -i |grep GPU|wc -l";
+  *gpus_per_node = amd_device_count();
 #elif defined(USE_DPCPP)
+  std::array<char, 128> buffer;
+  std::string           result, m_call;
+
   sycl::platform pltf = sycl_get_device(0)->get_platform();
   if(pltf.get_backend() == sycl::backend::ext_oneapi_level_zero ||
      pltf.get_backend() == sycl::backend::opencl) {
     m_call = "cat /sys/class/drm/card*/gt/gt*/id | wc -l";
   }
   else if(pltf.get_backend() == sycl::backend::ext_oneapi_cuda) {
+    // TODO: can we use the NVML API here? Probably not.
     m_call = "nvidia-smi --query-gpu=name --format=csv,noheader | wc -l";
   }
   else if(pltf.get_backend() == sycl::backend::ext_oneapi_hip) {
+    // TODO: can we use the ROCm SMI API here? Probably not.
     m_call = "rocm-smi -i |grep GPU|wc -l";
   }
-#endif
 
   std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(m_call.c_str(), "r"), pclose);
   if(!pipe) { throw std::runtime_error("popen() failed!"); }
   while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { result += buffer.data(); }
   *gpus_per_node = stoi(result);
+#endif
+
+  // Checked after the #endif so the CUDA and HIP paths are validated too, not only SYCL.
+  if(*gpus_per_node == 0) { tamm_terminate("[TAMM ERROR] No GPUs detected on node!"); }
 }
 
 static inline std::string getDeviceName() {