From 8048b52edd0cf60d4f12b8c3e9aa3e444e885075 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Mon, 12 Apr 2021 16:19:07 +0800 Subject: [PATCH] add cudnn version check --- mindspore/_check_version.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mindspore/_check_version.py b/mindspore/_check_version.py index b2f1551cf1..575d4c0e73 100644 --- a/mindspore/_check_version.py +++ b/mindspore/_check_version.py @@ -55,6 +55,7 @@ class GPUEnvChecker(EnvChecker): self.v = "0" self.cuda_lib_path = self._get_lib_path("libcu") self.cuda_bin_path = self._get_bin_path("cuda") + self.cudnn_lib_path = self._get_lib_path("libcudnn") def check_env(self, e): raise e @@ -94,6 +95,20 @@ class GPUEnvChecker(EnvChecker): return line.strip().split("release")[1].split(",")[0].strip() return "" + def _get_cudnn_version(self): + """Get cudnn version by libcudnn.so.""" + cudnn_version = [] + for path in self.cudnn_lib_path: + ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True, + capture_output=True, check=False, shell=True) + if ls_cudnn.returncode == 0: + cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') + if len(cudnn_version) == 2: + cudnn_version.append('0') + break + version_str = ''.join([n for n in cudnn_version]) + return version_str + def check_version(self): """Check cuda version.""" version_match = False @@ -118,6 +133,17 @@ class GPUEnvChecker(EnvChecker): logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} " "does not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") + cudnn_version = self._get_cudnn_version() + if cudnn_version and int(cudnn_version) < 760: + logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} " + "does not match, please refer to the installation guide for version matching " + "information: https://www.mindspore.cn/install. The recommended version is " + "CUDA10.1 with cuDNN7.6.x and CUAD11.1 with cuDNN8.0.x") + if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10: + logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} " + "does not match, please refer to the installation guide for version matching " + "information: https://www.mindspore.cn/install. The recommended version is " + "CUAD11.1 with cuDNN8.0.x") def _check_version(self, version_file): """Check cuda version by version.txt."""