From: @penny369 Reviewed-by: @guoqi1024,@pandoublefeng Signed-off-by: @guoqi1024 tags/v1.2.0-rc1
| @@ -14,15 +14,13 @@ | |||
| # ============================================================================ | |||
| """train Xception.""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.train.model import Model, ParallelMode | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| @@ -37,59 +35,6 @@ from src.loss import CrossEntropySmooth | |||
| set_seed(1) | |||
class Monitor(Callback):
    """
    Monitor loss and time.

    Prints the loss, elapsed time and learning rate after every step, and the
    total time, per-step time and average loss after every epoch.

    Args:
        lr_init (numpy array): train lr, indexed by zero-based global step.
            May be None (the declared default), in which case the learning
            rate is reported as nan instead of failing.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
    """

    def __init__(self, lr_init=None):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        # The default is None, and len(None) raises TypeError, so guard it.
        self.lr_init_len = 0 if lr_init is None else len(lr_init)
        # Initialise the state that the *_begin hooks normally set, so that the
        # *_end hooks never hit a missing attribute.
        self.losses = []
        self.epoch_time = time.time()
        self.step_time = time.time()

    def _current_lr(self, step_index):
        """Return the lr for `step_index`, clamped to the last schedule entry."""
        if self.lr_init_len == 0:
            return float("nan")
        # NOTE(review): steps beyond the schedule reuse the final value rather
        # than raising IndexError from a logging callback; confirm this is the
        # desired behaviour for resumed runs.
        return self.lr_init[min(step_index, self.lr_init_len - 1)]

    def epoch_begin(self, run_context):
        """Reset the per-epoch loss history and start the epoch timer."""
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Print epoch time, per-step time (both in ms) and the average loss."""
        cb_params = run_context.original_args()

        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))

    def step_begin(self, run_context):
        """Start the step timer."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Record the step loss and print step progress, loss, time and lr."""
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        # The network output may be a tuple/list whose first element is the loss.
        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        # `config` is a module-level name defined elsewhere in the training
        # script; its `finish_epoch` value offsets the printed epoch counters.
        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
            cb_params.cur_epoch_num - 1 + config.finish_epoch, cb_params.epoch_num + config.finish_epoch,
            cur_step_in_epoch, cb_params.batch_num, step_loss,
            np.mean(self.losses), step_mseconds, self._current_lr(cb_params.cur_step_num - 1)), flush=True)
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description='image classification training') | |||
| parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training') | |||
| @@ -153,7 +98,7 @@ if __name__ == '__main__': | |||
| amp_level='O3', keep_batchnorm_fp32=True) | |||
| # define callbacks | |||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||
| cb = [TimeMonitor(), LossMonitor()] | |||
| if config.save_checkpoint: | |||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(rank) + '/') | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, | |||
| @@ -83,6 +83,7 @@ After installing MindSpore via the official website, you can start training and | |||
| │ ├── config.py // parameter configuration | |||
| ├── train.py // training script | |||
| ├── eval.py // evaluation script | |||
| ├── export.py // export checkpoint to other format file | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| @@ -175,4 +176,4 @@ For more configuration details, please refer to the script `config.py`. | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ##############export checkpoint file into air, onnx, mindir models################# | |||
| python export.py | |||
| """ | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context | |||
| from src.config import cfg | |||
| from src.textcnn import TextCNN | |||
| from src.dataset import MovieReview | |||
# Command-line options for the TextCNN export script.
parser = argparse.ArgumentParser(description='TextCNN export')
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="checkpoint file path.")
parser.add_argument("--file_name", type=str, default="textcnn", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR',
                    help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
parser.add_argument('--dataset_name', type=str, default='MR', choices=['MR'],
                    help='dataset name.')
args = parser.parse_args()

# Configure graph-mode execution on the requested device before building the network.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

if __name__ == '__main__':
    # Only the MR dataset is handled; anything else is rejected up front.
    if args.dataset_name != 'MR':
        raise ValueError("dataset is not support.")
    dataset = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)

    # Build the network with the dataset's dictionary length and the configured sizes.
    network = TextCNN(vocab_len=dataset.get_dict_len(),
                      word_len=cfg.word_len,
                      num_classes=cfg.num_classes,
                      vec_length=cfg.vec_length)

    # Load the trained weights from the checkpoint into the network.
    checkpoint_params = load_checkpoint(args.ckpt_file)
    load_param_into_net(network, checkpoint_params)

    # An all-ones int32 array of shape (batch_size, word_len) is passed to export as the network input.
    dummy_input = Tensor(np.ones([cfg.batch_size, cfg.word_len], np.int32))
    export(network, dummy_input, file_name=args.file_name, file_format=args.file_format)