|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from ._callback import Callback |
|
|
|
|
|
|
|
|
|
|
|
@@ -31,6 +32,13 @@ class History(Callback): |
|
|
|
Note: |
|
|
|
Normally used in `mindspore.Model.train`. |
|
|
|
|
|
|
|
Args: |
|
|
|
has_trained_epoch (int): How many epochs has trained. If this parameter is set, History will record the |
|
|
|
network output information after has_trained_epoch's epoch. Default: 0. |
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: If has_trained_epoch is not an integer or less than zero. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import numpy as np |
|
|
|
>>> import mindspore.dataset as ds |
|
|
|
@@ -49,9 +57,11 @@ class History(Callback): |
|
|
|
{'epoch': [1, 2]} |
|
|
|
{'net_output': [1.607877, 1.6033841]} |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
def __init__(self, has_trained_epoch=0): |
|
|
|
super(History, self).__init__() |
|
|
|
Validator.check_non_negative_int(has_trained_epoch) |
|
|
|
self.history = {} |
|
|
|
self._has_trained_epoch = has_trained_epoch |
|
|
|
|
|
|
|
def begin(self, run_context): |
|
|
|
""" |
|
|
|
@@ -70,7 +80,10 @@ class History(Callback): |
|
|
|
run_context (RunContext): Context of the `mindspore.Model.{train | eval}`. |
|
|
|
""" |
|
|
|
cb_params = run_context.original_args() |
|
|
|
epoch = cb_params.get("cur_epoch_num", 1) |
|
|
|
if "cur_epoch_num" in cb_params: |
|
|
|
epoch = cb_params.get("cur_epoch_num") + self._has_trained_epoch |
|
|
|
else: |
|
|
|
epoch = 1 |
|
|
|
self.epoch.get("epoch").append(epoch) |
|
|
|
net_output = cb_params.net_outputs |
|
|
|
if isinstance(net_output, (tuple, list)): |
|
|
|
|