Browse Source

code check clean

tags/v1.6.0
liuyang_655 4 years ago
parent
commit
3eff2fe8e0
3 changed files with 16 additions and 34 deletions
  1. +1
    -0
      mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
  2. +7
    -3
      mindspore/python/mindspore/train/callback/_checkpoint.py
  3. +8
    -31
      mindspore/python/mindspore/train/model.py

+ 1
- 0
mindspore/ccsrc/transform/express_ir/onnx_exporter.cc View File

@@ -1651,6 +1651,7 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map<AnfNodePtr,
const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
onnx::GraphProto *const graph_proto) {
auto op_map = OpConvertRegistry::GetOpConvertMap();
MS_EXCEPTION_IF_NULL(prim);
auto op_iter = op_map.find(prim->name());
if (op_iter == op_map.end()) {
MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map. "


+ 7
- 3
mindspore/python/mindspore/train/callback/_checkpoint.py View File

@@ -35,15 +35,19 @@ _save_dir = _cur_dir
_info_list = ["epoch_num", "step_num"]


def _chg_ckpt_file_name_if_same_exist(directory, prefix):
def _chg_ckpt_file_name_if_same_exist(directory, prefix, exception=False):
"""Check if there is a file with the same name."""
files = os.listdir(directory)
suffix_num = 0
pre_len = len(prefix)
for filename in files:
name_ext = os.path.splitext(filename)
if name_ext[-1] != ".ckpt":
continue
if not exception:
if name_ext[-1] != ".ckpt" or name_ext[-1] == ".ckpt" and filename[-16:] == "_breakpoint.ckpt":
continue
else:
if filename[-16:] != "_breakpoint.ckpt":
continue
# find same prefix file
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
# add the max suffix + 1


+ 8
- 31
mindspore/python/mindspore/train/model.py View File

@@ -23,6 +23,7 @@ import numpy as np
from mindspore import log as logger
from .serialization import save_checkpoint
from .callback._checkpoint import ModelCheckpoint
from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, Validator
@@ -59,29 +60,6 @@ class _StepSync(Callback):
_pynative_executor.sync()


def _check_bpckpt_file_name_if_same_exist(directory, prefix):
"""Check if there is a exception checkpoint file with the same name."""
files = os.listdir(directory)
suffix_num = 0
pre_len = len(prefix)
for filename in files:
if filename[-16:] != "_breakpoint.ckpt":
continue
# find same prefix file
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
# add the max suffix + 1
index = filename[pre_len:].find("-")
if index == 0:
suffix_num = max(suffix_num, 1)
elif index != -1:
num = filename[pre_len+1:pre_len+index]
if num.isdigit():
suffix_num = max(suffix_num, int(num)+1)
if suffix_num != 0:
prefix = prefix + "_" + str(suffix_num)
return prefix


def _save_final_ckpt(func):
"""
Decorator function, which saves the current checkpoint when an exception occurs during training.
@@ -89,18 +67,17 @@ def _save_final_ckpt(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
obj = None
if kwargs['callbacks']:
if isinstance(kwargs['callbacks'], ModelCheckpoint):
obj = kwargs['callbacks']
if isinstance(kwargs['callbacks'], list):
for item in kwargs['callbacks']:
if isinstance(item, ModelCheckpoint):
obj = item
if kwargs['callbacks'] and isinstance(kwargs['callbacks'], ModelCheckpoint):
obj = kwargs['callbacks']
if kwargs['callbacks'] and isinstance(kwargs['callbacks'], list):
for item in kwargs['callbacks']:
if isinstance(item, ModelCheckpoint):
obj = item
if obj and obj._config and obj._config.exception_save:
try:
func(self, *args, **kwargs)
except BaseException as e:
prefix = _check_bpckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix)
prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
+ str(self._current_step_num) + "_breakpoint.ckpt"
cur_file = os.path.join(obj._directory, cur_ckpoint_file)


Loading…
Cancel
Save