@@ -14,15 +14,13 @@
 # ============================================================================
 """train Xception."""
 import os
-import time
 import argparse
-import numpy as np
 
 from mindspore import context
 from mindspore import Tensor
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model, ParallelMode
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
@@ -37,59 +35,6 @@ from src.loss import CrossEntropySmooth
 
 set_seed(1)
 
-class Monitor(Callback):
-    """
-    Monitor loss and time.
-
-    Args:
-        lr_init (numpy array): train lr
-
-    Returns:
-        None
-
-    Examples:
-        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
-    """
-
-    def __init__(self, lr_init=None):
-        super(Monitor, self).__init__()
-        self.lr_init = lr_init
-        self.lr_init_len = len(lr_init)
-
-    def epoch_begin(self, run_context):
-        self.losses = []
-        self.epoch_time = time.time()
-
-    def epoch_end(self, run_context):
-        cb_params = run_context.original_args()
-
-        epoch_mseconds = (time.time() - self.epoch_time) * 1000
-        per_step_mseconds = epoch_mseconds / cb_params.batch_num
-        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
-                                                                                      per_step_mseconds,
-                                                                                      np.mean(self.losses)))
-
-    def step_begin(self, run_context):
-        self.step_time = time.time()
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        step_mseconds = (time.time() - self.step_time) * 1000
-        step_loss = cb_params.net_outputs
-
-        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
-            step_loss = step_loss[0]
-        if isinstance(step_loss, Tensor):
-            step_loss = np.mean(step_loss.asnumpy())
-
-        self.losses.append(step_loss)
-        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
-
-        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
-            cb_params.cur_epoch_num - 1 + config.finish_epoch, cb_params.epoch_num + config.finish_epoch,
-            cur_step_in_epoch, cb_params.batch_num, step_loss,
-            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]), flush=True)
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='image classification training')
     parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
@@ -153,7 +98,7 @@ if __name__ == '__main__':
                   amp_level='O3', keep_batchnorm_fp32=True)
 
     # define callbacks
-    cb = [Monitor(lr_init=lr.asnumpy())]
+    cb = [TimeMonitor(), LossMonitor()]
     if config.save_checkpoint:
         save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(rank) + '/')
         config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
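
For reference, the `TimeMonitor`/`LossMonitor` pair that replaces the hand-rolled `Monitor` callback plugs straight into `Model.train`. Below is a minimal, self-contained sketch of that wiring; the toy network, optimizer, synthetic dataset, and names such as `net`, `opt`, and `dataset` are illustrative stand-ins, not code from this repository:

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor

# Toy stand-ins so the sketch runs without the repo's src/ package.
net = nn.Dense(16, 10)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

data = {"data": np.random.randn(32, 16).astype(np.float32),
        "label": np.random.randint(0, 10, size=(32,)).astype(np.int32)}
dataset = ds.NumpySlicesDataset(data, shuffle=False).batch(8)

model = Model(net, loss_fn=loss, optimizer=opt)

# TimeMonitor reports epoch/step timing and LossMonitor prints the running
# loss, which together cover what the deleted Monitor class did by hand.
cb = [TimeMonitor(data_size=dataset.get_dataset_size()), LossMonitor()]

# Checkpointing is appended the same way the surrounding context lines do it.
config_ck = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(),
                             keep_checkpoint_max=5)
cb.append(ModelCheckpoint(prefix="Xception", directory="./ckpt", config=config_ck))

model.train(2, dataset, callbacks=cb, dataset_sink_mode=False)
```

One thing the built-in monitors do not print is the per-step learning rate that `Monitor` pulled from `lr_init`; if that readout matters, it has to be logged separately.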