Browse Source

add args in history

r1.7
liutongtong 4 years ago
parent
commit
980a438d38
2 changed files with 23 additions and 2 deletions
  1. +8
    -0
      docs/api/api_python/train/mindspore.train.callback.History.rst
  2. +15
    -2
      mindspore/python/mindspore/train/callback/_history.py

+ 8
- 0
docs/api/api_python/train/mindspore.train.callback.History.rst View File

@@ -7,6 +7,14 @@
.. note::
通常使用在 `mindspore.Model.train` 中。

**参数:**

- **has_trained_epoch** (int) - 表示已经训练了多少个epoch,如果设置了该参数,History将监控该数值之后epoch的网络输出信息。默认值:0。

**异常:**

- **ValueError** - 当 `has_trained_epoch` 不是整数或小于零。

.. py:method:: begin(run_context)

训练开始时初始化History对象的epoch属性。


+ 15
- 2
mindspore/python/mindspore/train/callback/_history.py View File

@@ -16,6 +16,7 @@

import numpy as np
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from ._callback import Callback


@@ -31,6 +32,13 @@ class History(Callback):
Note:
Normally used in `mindspore.Model.train`.

Args:
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the
network output information after has_trained_epoch's epoch. Default: 0.

Raises:
ValueError: If has_trained_epoch is not an integer or less than zero.

Examples:
>>> import numpy as np
>>> import mindspore.dataset as ds
@@ -49,9 +57,11 @@ class History(Callback):
{'epoch': [1, 2]}
{'net_output': [1.607877, 1.6033841]}
"""
def __init__(self):
def __init__(self, has_trained_epoch=0):
super(History, self).__init__()
Validator.check_non_negative_int(has_trained_epoch)
self.history = {}
self._has_trained_epoch = has_trained_epoch

def begin(self, run_context):
"""
@@ -70,7 +80,10 @@ class History(Callback):
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`.
"""
cb_params = run_context.original_args()
epoch = cb_params.get("cur_epoch_num", 1)
if "cur_epoch_num" in cb_params:
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch
else:
epoch = 1
self.epoch.get("epoch").append(epoch)
net_output = cb_params.net_outputs
if isinstance(net_output, (tuple, list)):


Loading…
Cancel
Save