|
|
|
@@ -20,8 +20,9 @@ from importlib import import_module |
|
|
|
from importlib.util import find_spec |
|
|
|
|
|
|
|
import mindinsight |
|
|
|
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied |
|
|
|
from mindinsight.mindconverter.graph_based_converter.constant import BINARY_HEADER_PYTORCH_FILE, FrameworkType, \ |
|
|
|
BINARY_HEADER_PYTORCH_BITS |
|
|
|
BINARY_HEADER_PYTORCH_BITS, ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER |
|
|
|
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper |
|
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
|
from mindinsight.mindconverter.common.exceptions import GraphInitFail, TreeCreateFail, SourceFilesSaveFail, \ |
|
|
|
@@ -92,11 +93,30 @@ def tf_installation_validation(func): |
|
|
|
output_folder: str, report_folder: str = None, |
|
|
|
input_nodes: str = None, output_nodes: str = None): |
|
|
|
# Check whether tensorflow is installed. |
|
|
|
if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnxruntime"): |
|
|
|
error = ModuleNotFoundError("Tensorflow and tf2onnx are required when using " |
|
|
|
"graph based scripts converter.") |
|
|
|
if not find_spec("tensorflow") or not find_spec("tf2onnx") or not find_spec("onnx") \ |
|
|
|
or not find_spec("onnxruntime"): |
|
|
|
error = ModuleNotFoundError( |
|
|
|
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " |
|
|
|
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " |
|
|
|
f"based scripts converter for TensorFlow conversion." |
|
|
|
) |
|
|
|
log.error(str(error)) |
|
|
|
raise error |
|
|
|
|
|
|
|
onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx") |
|
|
|
ort = import_module("onnxruntime") |
|
|
|
|
|
|
|
if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \ |
|
|
|
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \ |
|
|
|
or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER): |
|
|
|
error = ModuleNotFoundError( |
|
|
|
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and " |
|
|
|
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph " |
|
|
|
f"based scripts converter for TensorFlow conversion." |
|
|
|
) |
|
|
|
log.error(str(error)) |
|
|
|
raise error |
|
|
|
|
|
|
|
func(graph_path=graph_path, sample_shape=sample_shape, |
|
|
|
output_folder=output_folder, report_folder=report_folder, |
|
|
|
input_nodes=input_nodes, output_nodes=output_nodes) |
|
|
|
|