Browse Source

update nasnet scripts

tags/v1.1.0
panfengfeng 5 years ago
parent
commit
2b0550b77a
1 changed files with 8 additions and 4 deletions
  1. +8
    -4
      model_zoo/official/cv/nasnet/train.py

+ 8
- 4
model_zoo/official/cv/nasnet/train.py View File

@@ -29,7 +29,7 @@ from mindspore.common import dtype as mstype


from src.config import nasnet_a_mobile_config_gpu as cfg
from src.dataset import create_dataset
from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
from src.nasnet_a_mobile import NASNetAMobileWithLoss
from src.lr_generator import get_lr




@@ -104,9 +104,13 @@ if __name__ == '__main__':
optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                    momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)


net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
net_with_grads.set_train()
model = Model(net_with_grads)
# net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
# net_with_grads.set_train()
# model = Model(net_with_grads)

# high performance
net_with_loss.set_train()
model = Model(net_with_loss, optimizer=optimizer)


print("============== Starting Training ==============") print("============== Starting Training ==============")
loss_cb = LossMonitor(per_print_times=batches_per_epoch) loss_cb = LossMonitor(per_print_times=batches_per_epoch)


Loading…
Cancel
Save