|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Train Retinaface_resnet50."""
- from __future__ import print_function
- import random
- import math
- import numpy as np
-
- import mindspore.nn as nn
- import mindspore.dataset as de
- from mindspore import context
- from mindspore.context import ParallelMode
- from mindspore.train import Model
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
- from src.config import cfg_res50
- from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
- from src.loss import MultiBoxLoss
- from src.dataset import create_dataset
-
def setup_seed(seed):
    """Seed every random number source this script relies on.

    Args:
        seed: Value handed unchanged to Python's ``random``, to NumPy's
            global generator and to the MindSpore dataset configuration.
    """
    # Same three seeders, applied in the same order as before.
    for apply_seed in (random.seed, np.random.seed, de.config.set_seed):
        apply_seed(seed)
-
def adjust_learning_rate(initial_lr, gamma, stepvalues, steps_per_epoch, total_epochs, warmup_epoch=5):
    """Build the per-step learning-rate schedule for a whole training run.

    Epochs are numbered from 1. The schedule has three regimes:

    * epochs ``<= warmup_epoch``: linear ramp starting at ``1e-6`` and heading
      towards ``initial_lr`` over ``steps_per_epoch * warmup_epoch`` steps;
    * epochs in ``[stepvalues[0], stepvalues[1]]``: ``initial_lr * gamma``;
    * epochs ``> stepvalues[1]``: ``initial_lr * gamma ** 2``;
    * any other epoch: ``initial_lr``.

    Args:
        initial_lr: Base learning rate.
        gamma: Multiplicative decay factor.
        stepvalues: Indexable pair ``(first_decay_epoch, second_decay_epoch)``.
        steps_per_epoch: Number of optimizer steps in one epoch.
        total_epochs: Number of epochs to generate values for.
        warmup_epoch: Number of leading epochs that use the linear warmup.

    Returns:
        list: One learning rate per step, ordered by epoch and then by step.
    """

    def rate_at(epoch, step):
        # Warmup: interpolate on the global step index within the warmup span.
        if epoch <= warmup_epoch:
            return 1e-6 + (initial_lr - 1e-6) * ((epoch - 1) * steps_per_epoch + step) / \
                (steps_per_epoch * warmup_epoch)

        first_decay = stepvalues[0]
        second_decay = stepvalues[1]
        if epoch > second_decay:
            return initial_lr * (gamma ** (2))
        if first_decay <= epoch <= second_decay:
            return initial_lr * (gamma ** (1))
        return initial_lr

    return [
        rate_at(epoch, step)
        for epoch in range(1, total_epochs + 1)
        for step in range(steps_per_epoch)
    ]
-
def train(cfg):
    """Train RetinaFace with a ResNet-50 backbone on multiple GPUs.

    The function configures the MindSpore GPU graph-mode context, initialises
    NCCL data-parallel training, builds the dataset, network, loss, learning
    rate schedule and optimizer from ``cfg``, and then runs ``Model.train``
    with loss, timing and checkpoint callbacks.

    Args:
        cfg: Mapping of training options. Keys read here: ``ngpu``,
            ``ckpt_path``, ``batch_size``, ``epoch``, ``momentum``,
            ``weight_decay``, ``initial_lr``, ``gamma``, ``training_dataset``,
            ``decay1``, ``decay2``, ``num_workers``, ``num_anchor``,
            ``pretrain``, ``pretrain_path``, ``resume_net``, ``warmup_epoch``,
            ``optim``, ``save_checkpoint_steps`` and ``keep_checkpoint_max``.
            ``cfg['ckpt_path']`` is modified in place: a per-rank
            ``ckpt_<rank>/`` suffix is appended.

    Raises:
        ValueError: If ``cfg['ngpu']`` is not greater than 1, or if
            ``cfg['optim']`` is neither ``'momentum'`` nor ``'sgd'``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)

    # Only the distributed (more than one GPU) path is implemented.
    if not cfg['ngpu'] > 1:
        raise ValueError('cfg_num_gpu <= 1')
    init("nccl")
    context.set_auto_parallel_context(
        device_num=get_group_size(),
        parallel_mode=ParallelMode.DATA_PARALLEL,
        gradients_mean=True,
    )
    # Each rank saves its checkpoints into its own sub-directory.
    cfg['ckpt_path'] = cfg['ckpt_path'] + "ckpt_" + str(get_rank()) + "/"

    # Hyper-parameters, read up front in a fixed order.
    batch_size, max_epoch = cfg['batch_size'], cfg['epoch']
    momentum, weight_decay = cfg['momentum'], cfg['weight_decay']
    initial_lr, gamma = cfg['initial_lr'], cfg['gamma']
    training_dataset = cfg['training_dataset']
    num_classes, negative_ratio = 2, 7
    stepvalues = (cfg['decay1'], cfg['decay2'])

    # Data pipeline.
    ds_train = create_dataset(
        training_dataset,
        cfg,
        batch_size,
        multiprocessing=True,
        num_worker=cfg['num_workers'],
    )
    print('dataset size is : \n', ds_train.get_dataset_size())
    steps_per_epoch = math.ceil(ds_train.get_dataset_size())

    # Loss and backbone.
    criterion = MultiBoxLoss(num_classes, cfg['num_anchor'], negative_ratio, cfg['batch_size'])
    res50_backbone = resnet50(1001)
    res50_backbone.set_train(True)

    # Backbone weights are loaded only when no full-model checkpoint is resumed.
    if cfg['pretrain'] and cfg['resume_net'] is None:
        backbone_ckpt = cfg['pretrain_path']
        load_param_into_net(res50_backbone, load_checkpoint(backbone_ckpt))
        print('Load resnet50 from [{}] done.'.format(backbone_ckpt))

    detector = RetinaFace(phase='train', backbone=res50_backbone)
    detector.set_train(True)

    if cfg['resume_net'] is not None:
        resume_ckpt = cfg['resume_net']
        load_param_into_net(detector, load_checkpoint(resume_ckpt))
        print('Resume Model from [{}] Done.'.format(cfg['resume_net']))

    loss_net = RetinaFaceWithLossCell(detector, criterion, cfg)

    # One learning-rate value per training step.
    lr_schedule = adjust_learning_rate(
        initial_lr,
        gamma,
        stepvalues,
        steps_per_epoch,
        max_epoch,
        warmup_epoch=cfg['warmup_epoch'],
    )

    if cfg['optim'] == 'momentum':
        optimizer = nn.Momentum(loss_net.trainable_params(), lr_schedule, momentum)
    elif cfg['optim'] == 'sgd':
        optimizer = nn.SGD(
            params=loss_net.trainable_params(),
            learning_rate=lr_schedule,
            momentum=momentum,
            weight_decay=weight_decay,
            loss_scale=1,
        )
    else:
        raise ValueError('optim is not define.')

    model = Model(TrainingWrapper(loss_net, optimizer))

    # Callbacks: checkpointing, per-epoch timing and loss printing.
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=cfg['save_checkpoint_steps'],
        keep_checkpoint_max=cfg['keep_checkpoint_max'],
    )
    ckpt_callback = ModelCheckpoint(prefix="RetinaFace", directory=cfg['ckpt_path'], config=ckpt_config)
    timer_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    callbacks = [LossMonitor(), timer_callback, ckpt_callback]

    print("============== Starting Training ==============")
    model.train(max_epoch, ds_train, callbacks=callbacks, dataset_sink_mode=False)
-
-
-
if __name__ == '__main__':
    # Script entry point: pick the ResNet-50 configuration, fix the random
    # seeds, echo the configuration and start training.
    config = cfg_res50
    setup_seed(1)
    print('train config:\n', config)
    train(cfg=config)
|