diff --git a/example/mobilenetv2_imagenet2012/config.py b/example/mobilenetv2_imagenet2012/config.py index 32df4eabc9..2a8d37b6fc 100644 --- a/example/mobilenetv2_imagenet2012/config.py +++ b/example/mobilenetv2_imagenet2012/config.py @@ -27,6 +27,7 @@ config = ed({ "lr": 0.4, "momentum": 0.9, "weight_decay": 4e-5, + "label_smooth": 0.1, "loss_scale": 1024, "save_checkpoint": True, "save_checkpoint_epochs": 1, diff --git a/example/mobilenetv2_imagenet2012/dataset.py b/example/mobilenetv2_imagenet2012/dataset.py index 9df34d51dc..46f5a1770c 100644 --- a/example/mobilenetv2_imagenet2012/dataset.py +++ b/example/mobilenetv2_imagenet2012/dataset.py @@ -53,8 +53,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): # define map operations decode_op = C.Decode() - resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.2, 1.0)) - horizontal_flip_op = C.RandomHorizontalFlip() + resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333)) + horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5) resize_op = C.Resize((256, 256)) center_crop = C.CenterCrop(resize_width) diff --git a/example/mobilenetv2_imagenet2012/eval.py b/example/mobilenetv2_imagenet2012/eval.py index 6c51fc042b..397b3a37c3 100644 --- a/example/mobilenetv2_imagenet2012/eval.py +++ b/example/mobilenetv2_imagenet2012/eval.py @@ -38,8 +38,6 @@ context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=True) if __name__ == '__main__': - context.set_context(enable_hccl=False) - loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') net = mobilenet_v2() diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py index d97eab5f04..c12f2ef9c0 100644 --- a/example/mobilenetv2_imagenet2012/train.py +++ b/example/mobilenetv2_imagenet2012/train.py @@ -28,6 +28,10 @@ from mindspore.model_zoo.mobilenet import mobilenet_v2 from mindspore.parallel._auto_parallel_context import 
class CrossEntropyWithLabelSmooth(_Loss):
    """
    Softmax cross entropy computed against label-smoothed one-hot targets.

    Each sparse label is expanded to a dense target row in which the labelled
    class gets ``1 - smooth_factor`` and every other class gets
    ``smooth_factor / (num_classes - 1)``. The row therefore sums to 1 when
    the logits have ``num_classes`` columns.

    Args:
        smooth_factor (float): smooth factor in [0, 1], default=0.
        num_classes (int): num classes, must be greater than 1, default=1000.

    Raises:
        ValueError: if `smooth_factor` is outside [0, 1] or `num_classes`
            is not greater than 1.

    Returns:
        The output of SoftmaxCrossEntropyWithLogits reduced by mean over axis 0.

    Examples:
        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
    """

    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropyWithLabelSmooth, self).__init__()
        # Fail early with a clear message instead of building targets that are
        # not a probability distribution (negative on/off values).
        if not 0. <= smooth_factor <= 1.:
            raise ValueError("smooth_factor must be in [0, 1], but got {}".format(smooth_factor))
        # The off value divides by (num_classes - 1), so a single class is invalid.
        if num_classes <= 1:
            raise ValueError("num_classes must be greater than 1, but got {}".format(num_classes))
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(float(smooth_factor) / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)
        self.cast = P.Cast()

    def construct(self, logit, label):
        # Labels are cast to int32 before the one-hot expansion; the one-hot
        # depth is read from axis 1 of the logits.
        indices = self.cast(label, mstype.int32)
        depth = F.shape(logit)[1]
        one_hot_label = self.onehot(indices, depth, self.on_value, self.off_value)
        out_loss = self.ce(logit, one_hot_label)
        out_loss = self.mean(out_loss, 0)
        return out_loss
+ None Examples: >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) @@ -122,7 +155,10 @@ if __name__ == '__main__': for _, cell in net.cells_and_names(): if isinstance(cell, nn.Dense): cell.add_flags_recursive(fp32=True) - loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') + if config.label_smooth > 0: + loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes) + else: + loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') print("train args: ", args_opt, "\ncfg: ", config, "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))