# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """
- Train CenterNet and get network model files(.ckpt)
- """
-
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config

_current_dir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                    help="Run distributed training, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
                    help="Profile runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default="", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training steps, default is -1, "
                    "i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Number of devices to use, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                    help="Enable saving checkpoints, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                    help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                    help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Number of checkpoints to keep, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="",
                    help="MindRecord files directory. If it is empty, MindRecord files will be generated "
                         "based on the original dataset and annotation information. If mindrecord_dir isn't empty, "
                         "mindrecord_dir will be used in place of data_dir and anno_path.")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory; "
                    "the absolute image path is joined by the data_dir "
                    "and the relative path in anno_path.")
parser.add_argument("--visual_image", type=str, default="false", choices=["true", "false"],
                    help="Visualize the ground truth and predicted image")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the prediction results")

args_opt = parser.parse_args()
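
# Example single-device launch (a usage sketch built from the flags defined above;
# the paths below are hypothetical placeholders, not paths shipped with this repo):
#
#   python train.py --device_id=0 --epoch_size=1 \
#       --mindrecord_dir=/path/to/mindrecord --save_checkpoint_path=./ckpt/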


def _set_parallel_all_reduce_split():
    """Set the CenterNet all_reduce fusion split points."""
    if net_config.last_level == 5:
        context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
    elif net_config.last_level == 6:
        context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
    else:
        raise ValueError("The total number of all-reduced gradients for last_level = {} is unknown, "
                         "please re-split once the true value is known.".format(net_config.last_level))


def _get_params_groups(network, optimizer):
    """
    Split trainable parameters into weight-decay and no-decay groups.

    Note: `optimizer` here is the optimizer config (e.g. train_config.Adam),
    which provides `decay_filter` and `weight_decay`.
    """
    params = network.trainable_params()
    decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
    other_params = list(filter(optimizer.decay_filter, params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    return group_params
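
# The structure returned above follows MindSpore's grouped-parameter convention:
# each dict may carry its own 'weight_decay', and 'order_params' preserves the
# original parameter ordering. A minimal usage sketch, mirroring the call made
# in _get_optimizer below:
#
#   group_params = _get_params_groups(net_with_loss, train_config.Adam)
#   optimizer = Adam(group_params, learning_rate=lr_schedule, loss_scale=1.0)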


def _get_optimizer(network, dataset_size):
    """Get the optimizer; only Adam is supported right now."""
    if train_config.optimizer == 'Adam':
        group_params = _get_params_groups(network, train_config.Adam)
        if train_config.lr_schedule == "PolyDecay":
            lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
                                                     end_learning_rate=train_config.PolyDecay.end_learning_rate,
                                                     warmup_steps=train_config.PolyDecay.warmup_steps,
                                                     decay_steps=args_opt.train_steps,
                                                     power=train_config.PolyDecay.power)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
        elif train_config.lr_schedule == "MultiDecay":
            multi_epochs = train_config.MultiDecay.multi_epochs
            if not isinstance(multi_epochs, (list, tuple)):
                raise TypeError("multi_epochs must be list or tuple.")
            if not multi_epochs:
                multi_epochs = [args_opt.epoch_size]
            lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
                                                      warmup_steps=train_config.MultiDecay.warmup_steps,
                                                      multi_epochs=multi_epochs,
                                                      steps_per_epoch=dataset_size,
                                                      factor=train_config.MultiDecay.factor)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
        else:
            raise ValueError("Unsupported lr_schedule {}, only [PolyDecay, MultiDecay] are supported.".
                             format(train_config.lr_schedule))
    else:
        raise ValueError("Unsupported optimizer {}, only [Adam] is supported.".
                         format(train_config.optimizer))
    return optimizer


def train():
    """Train CenterNet."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_auto_mixed_precision=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        _set_parallel_all_reduce_split()
    else:
        rank = 0
        device_num = 1
    num_workers = device_num * 8

    # Create the training dataset. MindRecord files will be generated under
    # args_opt.mindrecord_dir, e.g. centernet.mindrecord0, 1, ..., file_num.
    logger.info("Begin creating dataset for CenterNet")
    prefix = "coco_hp.train.mind"
    coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode="train")
    coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
    dataset = coco.create_train_dataset(args_opt.mindrecord_dir, prefix, batch_size=train_config.batch_size,
                                        device_num=device_num, rank=rank, num_parallel_workers=num_workers,
                                        do_shuffle=args_opt.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetMultiPoseLossCell(net_config)

    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                sens=train_config.loss_scale_value)

    model = Model(net_with_grads)

    model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)


if __name__ == '__main__':
    if args_opt.need_profiler == "true":
        profiler = Profiler(output_path=args_opt.profiler_path)
    set_seed(0)
    train()
    if args_opt.need_profiler == "true":
        profiler.analyse()