From 70f4c5117390908831458ea4f4fd0a62e39fbd8c Mon Sep 17 00:00:00 2001 From: changzherui Date: Mon, 7 Dec 2020 21:22:16 +0800 Subject: [PATCH] add export file name check --- mindspore/_checkparam.py | 10 ++++++++++ mindspore/train/serialization.py | 9 +++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 6f6a50c360..a60420ee04 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -427,6 +427,16 @@ class Validator: target, prim_name, reg, flag)) return True + @staticmethod + def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): + if reg is None: + reg = r"^[0-9a-zA-Z\_\.\/\\]*$" + if re.match(reg, target, flag) is None: + prim_name = f'in `{prim_name}`' if prim_name else "" + raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( + target, prim_name, reg, flag)) + return True + @staticmethod def check_pad_value_by_mode(pad_mode, padding, prim_name): """Validates value of padding according to pad_mode""" diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 855b449f0b..ca2680de7f 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -530,6 +530,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): if not isinstance(file_name, str): raise ValueError("Args file_name {} must be string, please check it".format(file_name)) + Validator.check_file_name_by_regular(file_name) net = _quant_export(net, *inputs, file_format=file_format, **kwargs) _export(net, file_name, file_format, *inputs) @@ -552,14 +553,14 @@ def _export(net, file_name, file_format, *inputs): is_dump_onnx_in_training = net.training and file_format == 'ONNX' if is_dump_onnx_in_training: net.set_train(mode=False) - # export model + net.init_parameters_data() if file_format == 'AIR': phase_name = 'export.air' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) file_name += ".air" _executor.export(file_name, graph_id) - elif file_format == 'ONNX': # file_format is 'ONNX' + elif file_format == 'ONNX': phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(net, graph_id) @@ -567,7 +568,7 @@ def _export(net, file_name, file_format, *inputs): with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) - elif file_format == 'MINDIR': # file_format is 'MINDIR' + elif file_format == 'MINDIR': phase_name = 'export.mindir' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') @@ -575,7 +576,7 @@ def _export(net, file_name, file_format, *inputs): with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) - # restore network training mode + if is_dump_onnx_in_training: net.set_train(mode=True)