diff --git a/mindinsight/mindconverter/graph_based_converter/common/utils.py b/mindinsight/mindconverter/graph_based_converter/common/utils.py index 7dedd7b7..b946e832 100644 --- a/mindinsight/mindconverter/graph_based_converter/common/utils.py +++ b/mindinsight/mindconverter/graph_based_converter/common/utils.py @@ -109,3 +109,29 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple], log.error(str(error)) log.exception(error) raise error + + +def lib_version_satisfied(current_ver: str, mini_ver_limited: str, + newest_ver_limited: str = ""): + """ + Check python lib version whether is satisfied. + + Notes: + Version number must be format of x.x.x, e.g. 1.1.0. + + Args: + current_ver (str): Current lib version. + mini_ver_limited (str): Mini lib version. + newest_ver_limited (str): Newest lib version. + + Returns: + bool, true or false. + """ + required_version_number_len = 3 + if len(list(current_ver.split("."))) != required_version_number_len or \ + len(list(mini_ver_limited.split("."))) != required_version_number_len or \ + (newest_ver_limited and len(newest_ver_limited.split(".")) != required_version_number_len): + raise ValueError("Version number must be format of x.x.x.") + if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited): + return False + return True diff --git a/mindinsight/mindconverter/graph_based_converter/constant.py b/mindinsight/mindconverter/graph_based_converter/constant.py index 8a48be18..590e67cd 100644 --- a/mindinsight/mindconverter/graph_based_converter/constant.py +++ b/mindinsight/mindconverter/graph_based_converter/constant.py @@ -37,6 +37,10 @@ DYNAMIC_SHAPE = -1 SCALAR_WITHOUT_SHAPE = 0 UNKNOWN_DIM_VAL = "unk__001" +ONNX_MIN_VER = "1.8.0" +TF2ONNX_MIN_VER = "1.7.1" +ONNXRUNTIME_MIN_VER = "1.5.2" + BINARY_HEADER_PYTORCH_FILE = \ b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00' diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index 8f34444a..28f792eb 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -20,8 +20,9 @@ from importlib import import_module from importlib.util import find_spec import mindinsight +from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ - BINARY_HEADER_PYTORCH_BITS + BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ @@ -92,11 +93,30 @@ def tf_installation_validation(func): output_folder: str, report_folder: str = None, input_nodes: str = None, output_nodes: str = None): # Check whether tensorflow is installed. - if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnxruntime"): - error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using " - "graph based scripts converter.") + if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \ + or not find_spec("onnxruntime"): + error = ModuleNotFoundError( + f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"based scripts converter for TensorFlow conversion." + ) log.error(str(error)) raise error + + onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") + ort = import_module("onnxruntime") + + if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ + or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ + or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): + error = ModuleNotFoundError( + f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " + f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " + f"based scripts converter for TensorFlow conversion." + ) + log.error(str(error)) + raise error + func(graph_path=graph_path, sample_shape=sample_shape, output_folder=output_folder, report_folder=report_folder, input_nodes=input_nodes, output_nodes=output_nodes)