|
|
|
@@ -23,6 +23,7 @@ from lr_generator import get_lr |
|
|
|
from config import config |
|
|
|
from mindspore import context |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore import nn |
|
|
|
from mindspore.model_zoo.mobilenet import mobilenet_v2 |
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context |
|
|
|
from mindspore.nn.optim.momentum import Momentum |
|
|
|
@@ -110,16 +111,17 @@ class Monitor(Callback): |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
if run_distribute: |
|
|
|
context.set_context(enable_hccl=True) |
|
|
|
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
parameter_broadcast=True, mirror_mean=True) |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([140]) |
|
|
|
init() |
|
|
|
else: |
|
|
|
context.set_context(enable_hccl=False) |
|
|
|
|
|
|
|
epoch_size = config.epoch_size |
|
|
|
net = mobilenet_v2(num_classes=config.num_classes) |
|
|
|
net.add_flags_recursive(fp16=True) |
|
|
|
for _, cell in net.cells_and_names(): |
|
|
|
if isinstance(cell, nn.Dense): |
|
|
|
cell.add_flags_recursive(fp32=True) |
|
|
|
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') |
|
|
|
|
|
|
|
print("train args: ", args_opt, "\ncfg: ", config, |
|
|
|
@@ -135,8 +137,7 @@ if __name__ == '__main__': |
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, |
|
|
|
config.weight_decay, config.loss_scale) |
|
|
|
|
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0', |
|
|
|
keep_batchnorm_fp32=False) |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) |
|
|
|
|
|
|
|
cb = None |
|
|
|
if rank_id == 0: |
|
|
|
|