Browse Source

add export file name check

tags/v1.1.0
changzherui 5 years ago
parent
commit
70f4c51173
2 changed files with 15 additions and 4 deletions
  1. +10
    -0
      mindspore/_checkparam.py
  2. +5
    -4
      mindspore/train/serialization.py

+ 10
- 0
mindspore/_checkparam.py View File

@@ -427,6 +427,16 @@ class Validator:
target, prim_name, reg, flag)) target, prim_name, reg, flag))
return True return True


@staticmethod
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
if reg is None:
reg = r"^[0-9a-zA-Z\_\.\/\\]*$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
target, prim_name, reg, flag))
return True

@staticmethod @staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name): def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode""" """Validates value of padding according to pad_mode"""


+ 5
- 4
mindspore/train/serialization.py View File

@@ -530,6 +530,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
if not isinstance(file_name, str): if not isinstance(file_name, str):
raise ValueError("Args file_name {} must be string, please check it".format(file_name)) raise ValueError("Args file_name {} must be string, please check it".format(file_name))


Validator.check_file_name_by_regular(file_name)
net = _quant_export(net, *inputs, file_format=file_format, **kwargs) net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
_export(net, file_name, file_format, *inputs) _export(net, file_name, file_format, *inputs)


@@ -552,14 +553,14 @@ def _export(net, file_name, file_format, *inputs):
is_dump_onnx_in_training = net.training and file_format == 'ONNX' is_dump_onnx_in_training = net.training and file_format == 'ONNX'
if is_dump_onnx_in_training: if is_dump_onnx_in_training:
net.set_train(mode=False) net.set_train(mode=False)
# export model
net.init_parameters_data() net.init_parameters_data()
if file_format == 'AIR': if file_format == 'AIR':
phase_name = 'export.air' phase_name = 'export.air'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
file_name += ".air" file_name += ".air"
_executor.export(file_name, graph_id) _executor.export(file_name, graph_id)
elif file_format == 'ONNX': # file_format is 'ONNX'
elif file_format == 'ONNX':
phase_name = 'export.onnx' phase_name = 'export.onnx'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(net, graph_id) onnx_stream = _executor._get_func_graph_proto(net, graph_id)
@@ -567,7 +568,7 @@ def _export(net, file_name, file_format, *inputs):
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream) f.write(onnx_stream)
elif file_format == 'MINDIR': # file_format is 'MINDIR'
elif file_format == 'MINDIR':
phase_name = 'export.mindir' phase_name = 'export.mindir'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
@@ -575,7 +576,7 @@ def _export(net, file_name, file_format, *inputs):
with open(file_name, 'wb') as f: with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream) f.write(onnx_stream)
# restore network training mode
if is_dump_onnx_in_training: if is_dump_onnx_in_training:
net.set_train(mode=True) net.set_train(mode=True)




Loading…
Cancel
Save