Browse Source

!30432 add ckpt check and mod cn api

Merge pull request !30432 from changzherui/add_ckpt_check
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
99fbcec2ce
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 7 additions and 2 deletions
  1. +1
    -0
      docs/api/api_python/mindspore/mindspore.save_checkpoint.rst
  2. +2
    -1
      docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst
  3. +4
    -1
      mindspore/python/mindspore/train/serialization.py

+ 1
- 0
docs/api/api_python/mindspore/mindspore.save_checkpoint.rst View File

@@ -18,3 +18,4 @@ mindspore.save_checkpoint
**异常:**

- **TypeError** – 如果参数 `save_obj` 类型不为nn.Cell或者list,且如果参数 `integrated_save` 及 `async_save` 非bool类型。
- **TypeError** – 如果参数 `ckpt_file_name` 不是str类型。

+ 2
- 1
docs/api/api_python/train/mindspore.train.callback.LossMonitor.rst View File

@@ -10,11 +10,12 @@
**参数:**

- **per_print_times** (int) - 表示每隔多少个step打印一次loss。默认值:1。
- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如何设置了该参数,LossMonitor将监控该数值之后epoch的loss值。默认值:0。

**异常:**

- **ValueError** - 当 `per_print_times` 不是整数或小于零。
- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。

.. py:method:: step_end(run_context)



+ 4
- 1
mindspore/python/mindspore/train/serialization.py View File

@@ -244,6 +244,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
if not isinstance(ckpt_file_name, str):
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
ckpt_file_name = os.path.realpath(ckpt_file_name)
if os.path.isdir(ckpt_file_name):
raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, "
"it should be a file name.".format(ckpt_file_name))
if not ckpt_file_name.endswith('.ckpt'):
ckpt_file_name += ".ckpt"
integrated_save = Validator.check_bool(integrated_save)
@@ -298,7 +302,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
data = param["data"].asnumpy().reshape(-1)
data_list[key].append(data)

ckpt_file_name = os.path.realpath(ckpt_file_name)
if async_save:
data_copy = copy.deepcopy(data_list)
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")


Loading…
Cancel
Save