|
|
|
@@ -23,6 +23,7 @@ import numpy as np |
|
|
|
from mindspore import log as logger |
|
|
|
from .serialization import save_checkpoint |
|
|
|
from .callback._checkpoint import ModelCheckpoint |
|
|
|
from .callback._checkpoint import _chg_ckpt_file_name_if_same_exist |
|
|
|
from ..common.tensor import Tensor |
|
|
|
from ..nn.metrics import get_metrics |
|
|
|
from .._checkparam import check_input_data, check_output_data, Validator |
|
|
|
@@ -59,29 +60,6 @@ class _StepSync(Callback): |
|
|
|
_pynative_executor.sync() |
|
|
|
|
|
|
|
|
|
|
|
def _check_bpckpt_file_name_if_same_exist(directory, prefix):
    """
    Return a prefix for a new exception checkpoint that does not clash with existing ones.

    Files in `directory` whose names end with "_breakpoint.ckpt" and start with `prefix`
    are inspected. A name of the form "<prefix>-..." counts as occupying suffix 0, and a
    name of the form "<prefix>_<N>-..." counts as occupying suffix N. When at least one such
    file exists, the returned prefix is "<prefix>_<max occupied suffix + 1>".

    Args:
        directory (str): Directory that is listed for existing checkpoint files.
        prefix (str): Desired file name prefix of the exception checkpoint.

    Returns:
        str, `prefix` itself when no matching file exists, otherwise `prefix` followed by
        "_" and the next free suffix number.

    Raises:
        OSError: Propagated from `os.listdir` when `directory` cannot be listed.
    """
    breakpoint_suffix = "_breakpoint.ckpt"
    pre_len = len(prefix)
    suffix_num = 0

    for filename in os.listdir(directory):
        if not filename.endswith(breakpoint_suffix) or not filename.startswith(prefix):
            continue

        remainder = filename[pre_len:]
        # An empty remainder means the file name equals the prefix; there is no character
        # after the prefix to inspect, so indexing it would raise IndexError.
        if not remainder:
            continue
        # A letter right after the prefix means this file belongs to a longer, different prefix.
        if remainder[0].isalpha():
            continue

        head, dash, _ = remainder.partition("-")
        if not dash:
            continue
        if not head:
            # "<prefix>-...": the un-suffixed name is taken, so at least "_1" is needed.
            suffix_num = max(suffix_num, 1)
            continue

        # "<prefix>_<N>-...": drop the single separator character and read N.
        num = head[1:]
        # isdecimal() accepts exactly the characters int() can parse.
        if num.isdecimal():
            suffix_num = max(suffix_num, int(num) + 1)

    if suffix_num:
        return prefix + "_" + str(suffix_num)
    return prefix
|
|
|
|
|
|
|
|
|
|
|
def _save_final_ckpt(func): |
|
|
|
""" |
|
|
|
Decorator function, which saves the current checkpoint when an exception occurs during training. |
|
|
|
@@ -89,18 +67,17 @@ def _save_final_ckpt(func): |
|
|
|
@wraps(func) |
|
|
|
def wrapper(self, *args, **kwargs): |
|
|
|
obj = None |
|
|
|
if kwargs['callbacks']: |
|
|
|
if isinstance(kwargs['callbacks'], ModelCheckpoint): |
|
|
|
obj = kwargs['callbacks'] |
|
|
|
if isinstance(kwargs['callbacks'], list): |
|
|
|
for item in kwargs['callbacks']: |
|
|
|
if isinstance(item, ModelCheckpoint): |
|
|
|
obj = item |
|
|
|
if kwargs['callbacks'] and isinstance(kwargs['callbacks'], ModelCheckpoint): |
|
|
|
obj = kwargs['callbacks'] |
|
|
|
if kwargs['callbacks'] and isinstance(kwargs['callbacks'], list): |
|
|
|
for item in kwargs['callbacks']: |
|
|
|
if isinstance(item, ModelCheckpoint): |
|
|
|
obj = item |
|
|
|
if obj and obj._config and obj._config.exception_save: |
|
|
|
try: |
|
|
|
func(self, *args, **kwargs) |
|
|
|
except BaseException as e: |
|
|
|
prefix = _check_bpckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix) |
|
|
|
prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True) |
|
|
|
cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \ |
|
|
|
+ str(self._current_step_num) + "_breakpoint.ckpt" |
|
|
|
cur_file = os.path.join(obj._directory, cur_ckpoint_file) |
|
|
|
|