Merge pull request !3051 from chenzhongming/newnew_mastertags/v0.6.0-beta
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """LossMonitor Callback class.""" | """LossMonitor Callback class.""" | ||||
| import time | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -32,62 +31,32 @@ class LossMonitor(Callback): | |||||
| Args: | Args: | ||||
| per_print_times (int): Print loss every times. Default: 1. | per_print_times (int): Print loss every times. Default: 1. | ||||
| lr_init (numpy array): train learning rate. Default: None. | |||||
| Raises: | Raises: | ||||
| ValueError: If print_step is not int or less than zero. | ValueError: If print_step is not int or less than zero. | ||||
| Examples: | |||||
| >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) | |||||
| """ | """ | ||||
| def __init__(self, per_print_times=1, lr_init=None): | |||||
| def __init__(self, per_print_times=1): | |||||
| super(LossMonitor, self).__init__() | super(LossMonitor, self).__init__() | ||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | if not isinstance(per_print_times, int) or per_print_times < 0: | ||||
| raise ValueError("print_step must be int and >= 0.") | raise ValueError("print_step must be int and >= 0.") | ||||
| self._per_print_times = per_print_times | self._per_print_times = per_print_times | ||||
| self.lr_init = lr_init | |||||
| def epoch_begin(self, run_context): | |||||
| self.losses = [] | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / cb_params.batch_num | |||||
| print("Epoch time: {:5.3f}, per step time: {:5.3f}, " | |||||
| "avg loss: {:5.3f}".format(epoch_mseconds, | |||||
| per_step_mseconds, | |||||
| np.mean(self.losses))) | |||||
| print("*" * 60) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | def step_end(self, run_context): | ||||
| cb_params = run_context.original_args() | cb_params = run_context.original_args() | ||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| step_loss = cb_params.net_outputs | |||||
| loss = cb_params.net_outputs | |||||
| if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): | |||||
| step_loss = step_loss[0] | |||||
| if isinstance(step_loss, Tensor): | |||||
| step_loss = np.mean(step_loss.asnumpy()) | |||||
| if isinstance(loss, (tuple, list)): | |||||
| if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): | |||||
| loss = loss[0] | |||||
| self.losses.append(step_loss) | |||||
| cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 | |||||
| if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): | |||||
| loss = np.mean(loss.asnumpy()) | |||||
| if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | |||||
| raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | |||||
| "Invalid loss, terminating training.".format( | |||||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | |||||
| cur_step_in_epoch, cb_params.batch_num)) | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
| if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): | |||||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | |||||
| cb_params.cur_epoch_num, cur_step_in_epoch)) | |||||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | ||||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | |||||
| "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format( | |||||
| cb_params.cur_epoch_num, cb_params.epoch_num, | |||||
| cur_step_in_epoch, int(cb_params.batch_num), | |||||
| step_loss, np.mean(self.losses), | |||||
| step_mseconds), flush=True) | |||||
| print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) | |||||
| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """LossMonitor Callback class.""" | |||||
| import time | |||||
| import numpy as np | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.train.callback import Callback | |||||
| class LossMonitor(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| If the loss is NAN or INF, it will terminate training. | |||||
| Note: | |||||
| If per_print_times is 0 do not print loss. | |||||
| Args: | |||||
| per_print_times (int): Print loss every times. Default: 1. | |||||
| lr_init (numpy array): train learning rate. Default: None. | |||||
| Raises: | |||||
| ValueError: If print_step is not int or less than zero. | |||||
| Examples: | |||||
| >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) | |||||
| """ | |||||
| def __init__(self, per_print_times=1, lr_init=None): | |||||
| super(LossMonitor, self).__init__() | |||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
| raise ValueError("print_step must be int and >= 0.") | |||||
| self._per_print_times = per_print_times | |||||
| self.lr_init = lr_init | |||||
| def epoch_begin(self, run_context): | |||||
| self.losses = [] | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / cb_params.batch_num | |||||
| print("Epoch time: {:5.3f}, per step time: {:5.3f}, " | |||||
| "avg loss: {:5.3f}".format(epoch_mseconds, | |||||
| per_step_mseconds, | |||||
| np.mean(self.losses))) | |||||
| print("*" * 60) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| step_loss = cb_params.net_outputs | |||||
| if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): | |||||
| step_loss = step_loss[0] | |||||
| if isinstance(step_loss, Tensor): | |||||
| step_loss = np.mean(step_loss.asnumpy()) | |||||
| self.losses.append(step_loss) | |||||
| cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 | |||||
| if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | |||||
| raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | |||||
| "Invalid loss, terminating training.".format( | |||||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | |||||
| cur_step_in_epoch, cb_params.batch_num)) | |||||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | |||||
| "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format( | |||||
| cb_params.cur_epoch_num, cb_params.epoch_num, | |||||
| cur_step_in_epoch, int(cb_params.batch_num), | |||||
| step_loss, np.mean(self.losses), | |||||
| step_mseconds), flush=True) | |||||
| @@ -22,12 +22,13 @@ import os | |||||
| import argparse | import argparse | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | from src.lenet_fusion import LeNet5 as LeNet5Fusion | ||||
| from src.loss_monitor import LossMonitor | |||||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | ||||
| parser.add_argument('--device_target', type=str, default="Ascend", | parser.add_argument('--device_target', type=str, default="Ascend", | ||||
| @@ -23,13 +23,14 @@ import argparse | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.nn.metrics import Accuracy | from mindspore.nn.metrics import Accuracy | ||||
| from mindspore.train.quant import quant | from mindspore.train.quant import quant | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.config import mnist_cfg as cfg | from src.config import mnist_cfg as cfg | ||||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | from src.lenet_fusion import LeNet5 as LeNet5Fusion | ||||
| from src.loss_monitor import LossMonitor | |||||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | ||||
| parser.add_argument('--device_target', type=str, default="Ascend", | parser.add_argument('--device_target', type=str, default="Ascend", | ||||