
!6747 Adjust the current NASNet-A-Mobile training setting

Merge pull request !6747 from dessyang/master
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 05c4c7593f
3 changed files with 12 additions and 11 deletions
1. model_zoo/official/cv/nasnet/README.md (+1, -1)
2. model_zoo/official/cv/nasnet/src/config.py (+2, -2)
3. model_zoo/official/cv/nasnet/train.py (+9, -8)

model_zoo/official/cv/nasnet/README.md (+1, -1)

@@ -40,7 +40,7 @@ Parameters for both training and evaluating can be set in config.py
 'rank': 0,                        # local rank of distributed
 'group_size': 1,                  # world size of distributed
 'work_nums': 8,                   # number of workers to read the data
-'epoch_size': 250,                # total epoch numbers
+'epoch_size': 500,                # total epoch numbers
 'keep_checkpoint_max': 100,       # max numbers to keep checkpoints
 'ckpt_path': './checkpoint/',     # save checkpoint path
 'is_save_on_master': 1            # save checkpoint on rank0, distributed parameters
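For orientation, here is a minimal sketch of how the documented checkpoint settings ('keep_checkpoint_max', 'ckpt_path', 'is_save_on_master', 'rank') are typically wired into MindSpore's checkpoint callbacks. The actual callback code in train.py is not part of this diff, so the prefix string and the per-epoch save frequency below are assumptions, not the repository's exact implementation.

from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from src.config import nasnet_a_mobile_config_gpu as cfg

# Assumed stand-in; in train.py this value comes from dataset.get_dataset_size().
batches_per_epoch = 1251

# Keep at most cfg.keep_checkpoint_max checkpoints, saving once per epoch.
ckpt_config = CheckpointConfig(save_checkpoint_steps=batches_per_epoch,
                               keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="nasnet-a-mobile",        # hypothetical prefix
                          directory=cfg.ckpt_path,         # './checkpoint/' after this change
                          config=ckpt_config)

# 'is_save_on_master' restricts checkpointing to rank 0 in distributed runs.
callbacks = [ckpt_cb] if (not cfg.is_save_on_master or cfg.rank == 0) else []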


model_zoo/official/cv/nasnet/src/config.py (+2, -2)

@@ -23,9 +23,9 @@ nasnet_a_mobile_config_gpu = edict({
 'rank': 0,
 'group_size': 1,
 'work_nums': 8,
-'epoch_size': 312,
+'epoch_size': 500,
 'keep_checkpoint_max': 100,
-'ckpt_path': './',
+'ckpt_path': './checkpoint/',
 'is_save_on_master': 0,

 ### Dataset Config
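Read in isolation, the hunk above is just part of an EasyDict literal. The following self-contained sketch shows how the updated values sit in src/config.py after this change, using the easydict package the file already imports; keys outside this hunk are omitted here.

from easydict import EasyDict as edict

nasnet_a_mobile_config_gpu = edict({
    'rank': 0,                     # local rank in a distributed job
    'group_size': 1,               # world size of a distributed job
    'work_nums': 8,                # dataset reader workers
    'epoch_size': 500,             # total training epochs (raised from 312)
    'keep_checkpoint_max': 100,    # maximum checkpoints kept on disk
    'ckpt_path': './checkpoint/',  # checkpoint directory (was './')
    'is_save_on_master': 0,        # restrict checkpointing to rank 0 when set to 1
})

# edict allows attribute-style access, which is how train.py consumes these values.
assert nasnet_a_mobile_config_gpu.epoch_size == 500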


model_zoo/official/cv/nasnet/train.py (+9, -8)

@@ -28,7 +28,7 @@ from mindspore.common import set_seed
 from src.config import nasnet_a_mobile_config_gpu as cfg
 from src.dataset import create_dataset
-from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
+from src.nasnet_a_mobile import NASNetAMobile, CrossEntropy
 from src.lr_generator import get_lr
@@ -68,10 +68,13 @@ if __name__ == '__main__':
     batches_per_epoch = dataset.get_dataset_size()
     # network
-    net_with_loss = NASNetAMobileWithLoss(cfg)
+    net = NASNetAMobile(cfg.num_classes)
     if args_opt.resume:
         ckpt = load_checkpoint(args_opt.resume)
-        load_param_into_net(net_with_loss, ckpt)
+        load_param_into_net(net, ckpt)
+    # loss
+    loss = CrossEntropy(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes, factor=cfg.aux_factor)
     # learning rate schedule
     lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
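The CrossEntropy class constructed above is defined in src/nasnet_a_mobile.py and is not shown in this diff. Below is a hypothetical, heavily simplified stand-in that illustrates what a label-smoothing cross entropy with an auxiliary-head factor typically looks like in MindSpore, under the assumption that the network returns a (main_logits, aux_logits) tuple during training; the real implementation may differ.

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

class CrossEntropySketch(nn.Cell):
    """Hypothetical label-smoothing cross entropy with a weighted auxiliary-head term."""
    def __init__(self, smooth_factor=0.1, num_classes=1000, factor=0.4):
        super().__init__()
        self.onehot = ops.OneHot()
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(smooth_factor / (num_classes - 1), ms.float32)
        self.num_classes = num_classes
        self.factor = factor
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')

    def construct(self, logits, label):
        main_logits, aux_logits = logits  # assumes a (main, aux) tuple from the network
        one_hot = self.onehot(label, self.num_classes, self.on_value, self.off_value)
        loss = self.ce(main_logits, one_hot)
        return loss + self.factor * self.ce(aux_logits, one_hot)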
@@ -82,20 +85,18 @@ if __name__ == '__main__':
     # optimizer
     decayed_params = []
     no_decayed_params = []
-    for param in net_with_loss.trainable_params():
+    for param in net.trainable_params():
         if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
             decayed_params.append(param)
         else:
             no_decayed_params.append(param)
     group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                     {'params': no_decayed_params},
-                    {'order_params': net_with_loss.trainable_params()}]
+                    {'order_params': net.trainable_params()}]
     optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                         momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)
-    net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
-    net_with_grads.set_train()
-    model = Model(net_with_grads)
+    model = Model(net, loss_fn=loss, optimizer=optimizer)
     print("============== Starting Training ==============")
     loss_cb = LossMonitor(per_print_times=batches_per_epoch)
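The diff ends just after the LossMonitor is created; the remaining wiring is not shown and presumably follows MindSpore's standard Model.train pattern. A hedged sketch of that continuation, reusing net, loss, group_params, lr, dataset, cfg and batches_per_epoch exactly as they appear in the hunks above:

from mindspore.nn import RMSProp
from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model

optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                    momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)

# Model now owns the forward/backward wiring that the removed
# NASNetAMobileTrainOneStepWithClipGradient cell used to provide.
model = Model(net, loss_fn=loss, optimizer=optimizer)

loss_cb = LossMonitor(per_print_times=batches_per_epoch)
model.train(cfg.epoch_size, dataset, callbacks=[loss_cb])  # epoch_size is now 500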

