Browse Source

!5502 Mod SoftmaxCrossEntropyWithlogits

Merge pull request !5502 from wanyiming/mod_SoftmaxCrossEntropyWithlogits
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
fc79997de5
66 changed files with 170 additions and 126 deletions
  1. +6
    -17
      mindspore/nn/loss/loss.py
  2. +2
    -2
      mindspore/nn/probability/toolbox/uncertainty_evaluation.py
  3. +1
    -1
      model_zoo/official/cv/alexnet/eval.py
  4. +1
    -1
      model_zoo/official/cv/alexnet/train.py
  5. +1
    -1
      model_zoo/official/cv/googlenet/eval.py
  6. +1
    -1
      model_zoo/official/cv/googlenet/train.py
  7. +1
    -1
      model_zoo/official/cv/lenet/eval.py
  8. +1
    -1
      model_zoo/official/cv/lenet/train.py
  9. +1
    -1
      model_zoo/official/cv/lenet_quant/eval_quant.py
  10. +1
    -1
      model_zoo/official/cv/lenet_quant/train_quant.py
  11. +1
    -2
      model_zoo/official/cv/mobilenetv2/eval.py
  12. +2
    -3
      model_zoo/official/cv/mobilenetv2/train.py
  13. +1
    -1
      model_zoo/official/cv/mobilenetv2_quant/eval.py
  14. +2
    -2
      model_zoo/official/cv/mobilenetv2_quant/train.py
  15. +1
    -2
      model_zoo/official/cv/mobilenetv3/eval.py
  16. +1
    -2
      model_zoo/official/cv/mobilenetv3/train.py
  17. +3
    -2
      model_zoo/official/cv/resnet/eval.py
  18. +38
    -0
      model_zoo/official/cv/resnet/src/CrossEntropySmooth.py
  19. +6
    -6
      model_zoo/official/cv/resnet/train.py
  20. +1
    -1
      model_zoo/official/cv/vgg16/eval.py
  21. +1
    -1
      model_zoo/official/cv/vgg16/train.py
  22. +1
    -1
      model_zoo/official/nlp/lstm/eval.py
  23. +1
    -1
      model_zoo/official/nlp/lstm/train.py
  24. +1
    -1
      tests/st/fusion/test_conv_bn1_fusion.py
  25. +1
    -1
      tests/st/host_device/test_host_device_lenet.py
  26. +1
    -1
      tests/st/nccl/test_nccl_lenet.py
  27. +38
    -0
      tests/st/networks/models/resnet50/src/CrossEntropySmooth.py
  28. +5
    -6
      tests/st/networks/models/resnet50/test_resnet50_imagenet.py
  29. +1
    -1
      tests/st/networks/test_cpu_lenet.py
  30. +1
    -1
      tests/st/networks/test_gpu_alexnet.py
  31. +2
    -2
      tests/st/networks/test_gpu_lenet.py
  32. +1
    -1
      tests/st/networks/test_gpu_lstm.py
  33. +3
    -3
      tests/st/networks/test_gpu_resnet.py
  34. +1
    -1
      tests/st/networks/test_network_main.py
  35. +1
    -1
      tests/st/ops/cpu/test_momentum_op.py
  36. +1
    -1
      tests/st/ops/gpu/test_adam_op.py
  37. +1
    -1
      tests/st/ops/gpu/test_ftrl_op.py
  38. +1
    -1
      tests/st/ops/gpu/test_momentum_op.py
  39. +1
    -1
      tests/st/ops/gpu/test_sgd_op.py
  40. +5
    -18
      tests/st/ops/gpu/test_sparse_softmax_cross_entropy_with_logits_op.py
  41. +1
    -1
      tests/st/probability/test_bnn_layer.py
  42. +1
    -1
      tests/st/probability/test_transform_bnn_layer.py
  43. +1
    -1
      tests/st/probability/test_transform_bnn_model.py
  44. +1
    -3
      tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py
  45. +1
    -1
      tests/st/ps/full_ps/test_full_ps_lenet.py
  46. +1
    -3
      tests/st/ps/multi_full_ps/test_multi_full_ps.py
  47. +1
    -1
      tests/st/pynative/test_pynative_hook.py
  48. +2
    -2
      tests/st/pynative/test_pynative_mindarmour.py
  49. +3
    -3
      tests/st/quantization/lenet_quant/test_lenet_quant.py
  50. +1
    -1
      tests/st/summary/test_summary.py
  51. +1
    -1
      tests/ut/python/exec/test_train.py
  52. +1
    -1
      tests/ut/python/exec/test_train_with_lars.py
  53. +1
    -1
      tests/ut/python/parallel/test_allreduce_fusion.py
  54. +1
    -1
      tests/ut/python/parallel/test_alltoall.py
  55. +1
    -1
      tests/ut/python/parallel/test_batchnorm_batch_parallel.py
  56. +1
    -1
      tests/ut/python/parallel/test_bn_prelu_cell.py
  57. +1
    -1
      tests/ut/python/parallel/test_dataset_interface.py
  58. +1
    -1
      tests/ut/python/parallel/test_full_batch.py
  59. +1
    -1
      tests/ut/python/parallel/test_one_dev.py
  60. +2
    -2
      tests/ut/python/parallel/test_operator_model_parallel.py
  61. +1
    -1
      tests/ut/python/parallel/test_prelu_cell.py
  62. +1
    -1
      tests/ut/python/parallel/test_reshape.py
  63. +1
    -1
      tests/ut/python/parallel/test_transpose.py
  64. +1
    -1
      tests/ut/python/pynative_mode/test_hook.py
  65. +1
    -1
      tests/ut/python/pynative_mode/test_pynative_model.py
  66. +1
    -1
      tests/ut/python/utils/test_serialize.py

+ 6
- 17
mindspore/nn/loss/loss.py View File

@@ -213,13 +213,9 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
of entry is a valid one.

Args:
is_grad (bool): Specifies whether calculate grad only. Default: True.
sparse (bool): Specifies whether labels use sparse format or not. Default: False.
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
If "none", do not perform reduction. Default: "none".
smooth_factor (float): Label smoothing factor. It is a optional input which should be in range [0, 1].
Default: 0.
num_classes (int): The number of classes in the task. It is a optional input Default: 2.

Inputs:
- **logits** (Tensor) - Tensor of shape (N, C).
@@ -238,29 +234,22 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
>>> loss(logits, labels)
"""
def __init__(self,
is_grad=True,
sparse=False,
reduction='none',
smooth_factor=0,
num_classes=2):
reduction='none'):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.is_grad = is_grad
self.sparse = sparse
validator.check_number_range(
"smooth_factor", smooth_factor, 0, 1, Rel.INC_BOTH, self.cls_name)
self.smooth_factor = smooth_factor
self.num_classes = num_classes
self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0 - self.smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * self.smooth_factor / (self.num_classes - 1), mstype.float32)
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32)
self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]

if self.is_cpugpu:
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=self.is_grad)
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()

def construct(self, logits, labels):
if self.is_cpugpu and self.sparse:
if self.is_cpugpu and self.sparse and self.reduction == 'mean':
x = self.sparse_softmax_cross_entropy(logits, labels)
return x



+ 2
- 2
mindspore/nn/probability/toolbox/uncertainty_evaluation.py View File

@@ -115,7 +115,7 @@ class UncertaintyEvaluation:
self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
if self.epi_uncer_model.drop_count == 0:
if self.task_type == 'classification':
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = Adam(self.epi_uncer_model.trainable_params())
model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
@@ -314,7 +314,7 @@ class AleatoricLoss(Cell):
self.exp = P.Exp()
self.normal = C.normal
self.to_tensor = P.ScalarToArray()
self.entropy = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
self.entropy = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
else:
self.mean = P.ReduceMean()
self.exp = P.Exp()


+ 1
- 1
model_zoo/official/cv/alexnet/eval.py View File

@@ -44,7 +44,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})


+ 1
- 1
model_zoo/official/cv/alexnet/train.py View File

@@ -47,7 +47,7 @@ if __name__ == "__main__":

ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})


+ 1
- 1
model_zoo/official/cv/googlenet/eval.py View File

@@ -41,7 +41,7 @@ if __name__ == '__main__':
net = GoogleNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

if device_target == "Ascend":


+ 1
- 1
model_zoo/official/cv/googlenet/train.py View File

@@ -102,7 +102,7 @@ if __name__ == '__main__':
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

if device_target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},


+ 1
- 1
model_zoo/official/cv/lenet/eval.py View File

@@ -46,7 +46,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})


+ 1
- 1
model_zoo/official/cv/lenet/train.py View File

@@ -50,7 +50,7 @@ if __name__ == "__main__":
cfg.batch_size)

network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,


+ 1
- 1
model_zoo/official/cv/lenet_quant/eval_quant.py View File

@@ -51,7 +51,7 @@ if __name__ == "__main__":
per_channel=[True, False])

# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)



+ 1
- 1
model_zoo/official/cv/lenet_quant/train_quant.py View File

@@ -60,7 +60,7 @@ if __name__ == "__main__":
symmetric=[False, False])

# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)



+ 1
- 2
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -51,8 +51,7 @@ if __name__ == '__main__':
else:
raise ValueError("Unsupported device_target.")

loss = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

if args_opt.device_target == "Ascend":
net.to_float(mstype.float16)


+ 2
- 3
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -173,7 +173,7 @@ if __name__ == '__main__':
loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth,
num_classes=config_gpu.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
epoch_size = config_gpu.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,
@@ -237,8 +237,7 @@ if __name__ == '__main__':
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config_ascend,


+ 1
- 1
model_zoo/official/cv/mobilenetv2_quant/eval.py View File

@@ -53,7 +53,7 @@ if __name__ == '__main__':
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,


+ 2
- 2
model_zoo/official/cv/mobilenetv2_quant/train.py View File

@@ -90,7 +90,7 @@ def train_on_ascend():
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
else:
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
@@ -151,7 +151,7 @@ def train_on_gpu():
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
num_classes=config.num_classes)
else:
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
epoch_size = config.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,


+ 1
- 2
model_zoo/official/cv/mobilenetv3/eval.py View File

@@ -41,8 +41,7 @@ if __name__ == '__main__':
else:
raise ValueError("Unsupported device_target.")

loss = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net = mobilenet_v3_large(num_classes=config.num_classes)

dataset = create_dataset(dataset_path=args_opt.dataset_path,


+ 1
- 2
model_zoo/official/cv/mobilenetv3/train.py View File

@@ -163,8 +163,7 @@ if __name__ == '__main__':
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction='mean')
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
epoch_size = config_gpu.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,


+ 3
- 2
model_zoo/official/cv/resnet/eval.py View File

@@ -22,6 +22,7 @@ from mindspore import dataset as de
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
@@ -79,8 +80,8 @@ if __name__ == '__main__':
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
loss = CrossEntropySmooth(sparse=True, reduction='mean',
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')



+ 38
- 0
model_zoo/official/cv/resnet/src/CrossEntropySmooth.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

+ 6
- 6
model_zoo/official/cv/resnet/train.py View File

@@ -33,6 +33,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
from src.CrossEntropySmooth import CrossEntropySmooth

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
@@ -147,8 +148,8 @@ if __name__ == '__main__':
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
@@ -159,11 +160,10 @@ if __name__ == '__main__':
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
num_classes=config.class_num)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

if args_opt.net == "resnet101" or args_opt.net == "resnet50":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,


+ 1
- 1
model_zoo/official/cv/vgg16/eval.py View File

@@ -134,7 +134,7 @@ def test(cloud_args=None):
net = vgg16(num_classes=args.num_classes, args=args)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum,
weight_decay=args.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

param_dict = load_checkpoint(args.pre_trained)


+ 1
- 1
model_zoo/official/cv/vgg16/train.py View File

@@ -211,7 +211,7 @@ if __name__ == '__main__':
loss_scale=args.loss_scale)

if args.dataset == "cifar10":
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
else:


+ 1
- 1
model_zoo/official/nlp/lstm/eval.py View File

@@ -64,7 +64,7 @@ if __name__ == '__main__':
weight=Tensor(embedding_table),
batch_size=cfg.batch_size)

loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor()



+ 1
- 1
model_zoo/official/nlp/lstm/train.py View File

@@ -70,7 +70,7 @@ if __name__ == '__main__':
if args.pre_trained:
load_param_into_net(network, load_checkpoint(args.pre_trained))

loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor()



+ 1
- 1
tests/st/fusion/test_conv_bn1_fusion.py View File

@@ -39,7 +39,7 @@ class MsWrapper(nn.Cell):


def me_train_tensor(net, input_np, label_np, epoch_size=2):
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])),
filter(lambda x: x.requires_grad, net.get_parameters()))
context.set_context(mode=context.GRAPH_MODE)


+ 1
- 1
tests/st/host_device/test_host_device_lenet.py View File

@@ -66,7 +66,7 @@ def train(net, data, label):
momentum = 0.9

optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 1
- 1
tests/st/nccl/test_nccl_lenet.py View File

@@ -85,7 +85,7 @@ def test_lenet_nccl():
learning_rate = multisteplr(epoch, 2)
momentum = 0.9
mom_optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, mom_optimizer)
train_network.set_train()


+ 38
- 0
tests/st/networks/models/resnet50/src/CrossEntropySmooth.py View File

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

+ 5
- 6
tests/st/networks/models/resnet50/test_resnet50_imagenet.py View File

@@ -36,12 +36,12 @@ from tests.st.networks.models.resnet50.src.dataset import create_dataset
from tests.st.networks.models.resnet50.src.lr_generator import get_learning_rate
from tests.st.networks.models.resnet50.src.config import config
from tests.st.networks.models.resnet50.src.metric import DistAccuracy, ClassifyCorrectCell
from tests.st.networks.models.resnet50.src.CrossEntropySmooth import CrossEntropySmooth
from tests.st.networks.models.resnet50.src_thor.config import config as thor_config
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor
from tests.st.networks.models.resnet50.src_thor.thor import THOR


MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"
dataset_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/train"
@@ -151,8 +151,8 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
config.label_smooth_factor = 0.0

# loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)

# train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
@@ -260,9 +260,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
thor_config.label_smooth_factor = 0.0

# loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=thor_config.label_smooth_factor,
num_classes=thor_config.class_num)
loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=thor_config.label_smooth_factor,
num_classes=thor_config.class_num)

# train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True,


+ 1
- 1
tests/st/networks/test_cpu_lenet.py View File

@@ -60,7 +60,7 @@ def train(net, data, label):
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 1
- 1
tests/st/networks/test_gpu_alexnet.py View File

@@ -76,7 +76,7 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32):
lr = 0.1
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum, weight_decay=0.0001)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()


+ 2
- 2
tests/st/networks/test_gpu_lenet.py View File

@@ -136,7 +136,7 @@ def test_train_lenet():
learning_rate = multisteplr(epoch, 30)

optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
@@ -192,7 +192,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})



+ 1
- 1
tests/st/networks/test_gpu_lstm.py View File

@@ -129,7 +129,7 @@ def test_LSTM():
momentum = 0.9

optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 3
- 3
tests/st/networks/test_gpu_resnet.py View File

@@ -337,7 +337,7 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1):
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad,
net.get_parameters()), lr, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer) # optimizer
@@ -361,7 +361,7 @@ def test_trainTensor_big_batchSize(num_classes=10, epoch=8, batch_size=338):
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad,
net.get_parameters()), lr, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer) # optimizer
@@ -385,7 +385,7 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad,
net.get_parameters()), lr, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
train_network = amp.build_train_network(
net, optimizer, criterion, level="O2")
train_network.set_train()


+ 1
- 1
tests/st/networks/test_network_main.py View File

@@ -39,7 +39,7 @@ def train(net, data, label):
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 1
- 1
tests/st/ops/cpu/test_momentum_op.py View File

@@ -52,7 +52,7 @@ def test_momentum():
momentum = 0.9

optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 1
- 1
tests/st/ops/gpu/test_adam_op.py View File

@@ -49,7 +49,7 @@ def test_adam():
net = NetAdam()
optimizer = Adam(filter(lambda x: x.requires_grad,
net.get_parameters()), learning_rate=0.01)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer)


+ 1
- 1
tests/st/ops/gpu/test_ftrl_op.py View File

@@ -49,7 +49,7 @@ def test_ftrl():
net = NetFtrl()
optimizer = FTRL(filter(lambda x: x.requires_grad,
net.get_parameters()), learning_rate=0.01)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer)


+ 1
- 1
tests/st/ops/gpu/test_momentum_op.py View File

@@ -52,7 +52,7 @@ def test_momentum():
momentum = 0.9

optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 1
- 1
tests/st/ops/gpu/test_sgd_op.py View File

@@ -55,7 +55,7 @@ def test_SGD():
optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening,
weight_decay, nesterov, loss_scale)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()


+ 5
- 18
tests/st/ops/gpu/test_sparse_softmax_cross_entropy_with_logits_op.py View File

@@ -20,15 +20,13 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor


class NetSparseSoftmaxCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(NetSparseSoftmaxCrossEntropyWithLogits, self).__init__()
self.loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
self.dlogits = nn.SoftmaxCrossEntropyWithLogits(is_grad=True, sparse=True)
self.loss = self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

def construct(self, logits, labels):
return (self.loss(logits, labels), self.dlogits(logits, labels))
return self.loss(logits, labels)


@pytest.mark.level0
@@ -39,29 +37,18 @@ def test_sparse_softmax_cross_entropy_with_logits():
[1, 10, 1],
[10, 1, 1]]).astype(np.float32))
labels = Tensor(np.array([2, 1, 0]).astype(np.int32))
expect_loss = 0.0002467
expect_dlogits = np.array([[4.1126452e-05, 4.1126452e-05, -8.2234539e-05],
[4.1126452e-05, -8.2234539e-05, 4.1126452e-05],
[-8.2234539e-05, 4.1126452e-05, 4.1126452e-05]]).astype(np.float32)
expect_loss = [0.00024673, 0.00024673, 0.00024673]

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
sparse_softmax_cross_entropy_with_logits = NetSparseSoftmaxCrossEntropyWithLogits()
output = sparse_softmax_cross_entropy_with_logits(logits, labels)
error0 = 1.0e-6
diff0 = output[0].asnumpy() - expect_loss
diff0 = output.asnumpy() - expect_loss
assert np.all(abs(diff0) < error0)

error1 = np.ones(shape=[3, 3]) * 1.0e-6
diff1 = output[1].asnumpy() - expect_dlogits
assert np.all(abs(diff1) < error1)

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
sparse_softmax_cross_entropy_with_logits = NetSparseSoftmaxCrossEntropyWithLogits()
output = sparse_softmax_cross_entropy_with_logits(logits, labels)
error0 = 1.0e-6
diff0 = output[0].asnumpy() - expect_loss
diff0 = output.asnumpy() - expect_loss
assert np.all(abs(diff0) < error0)

error1 = np.ones(shape=[3, 3]) * 1.0e-6
diff1 = output[1].asnumpy() - expect_dlogits
assert np.all(abs(diff1) < error1)

+ 1
- 1
tests/st/probability/test_bnn_layer.py View File

@@ -124,7 +124,7 @@ def validate_model(net, dataset):
if __name__ == "__main__":
network = BNNLeNet5()

criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)

net_with_loss = bnn_layers.WithBNNLossCell(network, criterion, 60000, 0.000001)


+ 1
- 1
tests/st/probability/test_transform_bnn_layer.py View File

@@ -125,7 +125,7 @@ def validate_model(net, dataset):
if __name__ == "__main__":
network = LeNet5()

criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)

net_with_loss = WithLossCell(network, criterion)


+ 1
- 1
tests/st/probability/test_transform_bnn_model.py View File

@@ -124,7 +124,7 @@ def validate_model(net, dataset):
if __name__ == "__main__":
network = LeNet5()

criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)

net_with_loss = WithLossCell(network, criterion)


+ 1
- 3
tests/st/ps/cmp_sparse_embedding/test_cmp_sparse_embedding.py View File

@@ -73,9 +73,7 @@ def do_sparse_embedding(ps=False):

optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer.sparse_opt.add_prim_attr("primitive_target", "CPU")
criterion = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction="mean"
)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()


+ 1
- 1
tests/st/ps/full_ps/test_full_ps_lenet.py View File

@@ -123,7 +123,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
if __name__ == "__main__":
network = LeNet5(10)
network.set_param_ps()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})



+ 1
- 3
tests/st/ps/multi_full_ps/test_multi_full_ps.py View File

@@ -94,9 +94,7 @@ if __name__ == "__main__":
np.random.seed(0)
network = LeNet5(10)
network.set_param_ps()
criterion = nn.SoftmaxCrossEntropyWithLogits(
is_grad=False, sparse=True, reduction="mean"
)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
if device_target == "GPU":
context.set_auto_parallel_context(parallel_mode="data_parallel", mirror_mean=True, device_num=get_group_size())


+ 1
- 1
tests/st/pynative/test_pynative_hook.py View File

@@ -159,7 +159,7 @@ def test_pynative_lenet_train_hook_function_print_and_save_grad():
cell_hook_function_print_grad)
net = LeNet5(hook_function=function[0], cell_hook_function=function[1])
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrap(net_with_criterion)
train_network.set_train()


+ 2
- 2
tests/st/pynative/test_pynative_mindarmour.py View File

@@ -145,14 +145,14 @@ def test_multi_grads():
net = LeNet()

# grad operation
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=sparse)
with_loss_cell = WithLossCell(net, loss_fn)
grad_all = GradWrapWithLoss(with_loss_cell)
grad_out = grad_all(Tensor(inputs_np), Tensor(labels_np)).asnumpy()
assert np.any(grad_out != 0), 'grad result can not be all zeros'

# train-one-step operation
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=sparse)
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
0.01, 0.9)
loss_net = WithLossCell(net, loss_fn)


+ 3
- 3
tests/st/quantization/lenet_quant/test_lenet_quant.py View File

@@ -42,7 +42,7 @@ def train_lenet():
cfg.batch_size)

network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
@@ -74,7 +74,7 @@ def train_lenet_quant():
symmetric=[False, False])

# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

@@ -104,7 +104,7 @@ def eval_quant():
per_channel=[True, False])

# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)



+ 1
- 1
tests/st/summary/test_summary.py View File

@@ -154,7 +154,7 @@ class TestSummary:
def _run_network(self, dataset_sink_mode=True):
lenet = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Accuracy()})
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)


+ 1
- 1
tests/ut/python/exec/test_train.py View File

@@ -31,7 +31,7 @@ def lr_gen(fn, epoch_size):

def me_train_tensor(net, input_np, label_np, epoch_size=2):
"""me_train_tensor"""
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr_gen(lambda i: 0.1, epoch_size), 0.9,
0.01, 1024)
Model(net, loss, opt)


+ 1
- 1
tests/ut/python/exec/test_train_with_lars.py View File

@@ -78,7 +78,7 @@ def lr_gen(fn, epoch_size):

def me_train_tensor(net, input_np, label_np, epoch_size=2):
"""me_train_tensor"""
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# reorder the net parameters , leave the parameters that need to be passed into lars to the end part

opt = Momentum(get_net_trainable_reordered_params(net)[2], lr_gen(lambda i: 0.1, epoch_size), 0.9, 0.01, 1024)


+ 1
- 1
tests/ut/python/parallel/test_allreduce_fusion.py View File

@@ -114,7 +114,7 @@ def train_common(net):
label = Tensor(np.ones([batch_size]), dtype=ms.int32)
dataset = Dataset(predict, label, 2)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = Momentum(net.trainable_params(), learning_rate, momentum)
model = Model(net, loss, opt)



+ 1
- 1
tests/ut/python/parallel/test_alltoall.py View File

@@ -79,7 +79,7 @@ def all_to_all_common(strategy1):
dataset = Dataset(predict, label, 2)
net = all_to_all_net(strategy1)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
loss.one_hot.set_strategy(((8, 1), (), ()))
opt = Momentum(net.trainable_params(), learning_rate, momentum)


+ 1
- 1
tests/ut/python/parallel/test_batchnorm_batch_parallel.py View File

@@ -134,7 +134,7 @@ def test_batchnorm_batch_parallel():
dataset = DatasetLenet(predict, label, 2)
net = batchnorm_net(num_classes)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((dev_num, 1), (dev_num, 1)))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)



+ 1
- 1
tests/ut/python/parallel/test_bn_prelu_cell.py View File

@@ -209,7 +209,7 @@ def bn_common(parallel_mode, train_flag, strategy_loss=None):
dataset = Dataset(predict, label, 2)
net = bn_net()

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(strategy_loss)
opt = Momentum(net.trainable_params(), learning_rate, momentum, 0.0001, 1024 * rank_size)



+ 1
- 1
tests/ut/python/parallel/test_dataset_interface.py View File

@@ -80,7 +80,7 @@ def loss_scale_manager_common(strategy1):
dataset = Dataset(predict, label, 2)
net = all_to_all_net(strategy1)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
opt = Momentum(net.trainable_params(), learning_rate, momentum)
scale_manager = DynamicLossScaleManager(32, 2, 2000)


+ 1
- 1
tests/ut/python/parallel/test_full_batch.py View File

@@ -76,7 +76,7 @@ def all_to_all_common(strategy1):
dataset = Dataset(predict, label, 2)
net = all_to_all_net(strategy1)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
loss.one_hot.set_strategy(((8, 1), (), ()))
opt = Momentum(net.trainable_params(), learning_rate, momentum)


+ 1
- 1
tests/ut/python/parallel/test_one_dev.py View File

@@ -82,7 +82,7 @@ def all_to_all_common():
dataset = Dataset(predict, label, 2)
net = all_to_all_net()

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = Momentum(net.trainable_params(), learning_rate, momentum)
model = Model(net, loss, opt)



+ 2
- 2
tests/ut/python/parallel/test_operator_model_parallel.py View File

@@ -362,7 +362,7 @@ def test_resnet_operator_batch_parallel():
dataset = DatasetLenet(predict, label, 2)
net = resnet_operator_net(num_classes)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((dev_num, 1), (dev_num, 1)))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)

@@ -387,7 +387,7 @@ def test_resnet_model_parallel():
dataset = DatasetLenet(predict, label, 2)
net = resnet_model_parallel_net(num_classes)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((dev_num, 1), (dev_num, 1)))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)



+ 1
- 1
tests/ut/python/parallel/test_prelu_cell.py View File

@@ -108,7 +108,7 @@ def reshape_common(parallel_mode):
dataset = Dataset(predict, label, 2)
net = prelu_net()

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = Momentum(net.trainable_params(), learning_rate, momentum)
model = Model(net, loss, opt)
model.train(epoch_size, dataset, dataset_sink_mode=False)


+ 1
- 1
tests/ut/python/parallel/test_reshape.py View File

@@ -95,7 +95,7 @@ def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss
dataset = Dataset(predict, label, 2)
net = reshape_net(strategy0, strategy1, strategy2)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(strategy_loss)
loss.one_hot.set_strategy(((8, 1), (), ()))
opt = Momentum(net.trainable_params(), learning_rate, momentum)


+ 1
- 1
tests/ut/python/parallel/test_transpose.py View File

@@ -80,7 +80,7 @@ def transpose_common(strategy1, strategy2):
dataset = Dataset(predict, label, 2)
net = transpose_net(strategy1, strategy2)

loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss.softmax_cross_entropy.set_strategy(((8, 1), (8, 1)))
opt = Momentum(net.trainable_params(), learning_rate, momentum)
context.set_context(mode=context.GRAPH_MODE)


+ 1
- 1
tests/ut/python/pynative_mode/test_hook.py View File

@@ -141,7 +141,7 @@ class GradWrap(nn.Cell):
def test_hook():
net = LeNet5()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrap(net_with_criterion)
train_network.set_train()


+ 1
- 1
tests/ut/python/pynative_mode/test_pynative_model.py View File

@@ -129,7 +129,7 @@ def test_lenet_grad():
verification_step = 0

net = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False)
loss = nn.SoftmaxCrossEntropyWithLogits()
momen_opti = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = GradWrap(NetWithLossClass(net))
train_net.set_train()


+ 1
- 1
tests/ut/python/utils/test_serialize.py View File

@@ -283,7 +283,7 @@ def test_load_param_into_net():
def test_save_checkpoint_for_network():
""" test save_checkpoint for network"""
net = Net()
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

loss_net = WithLossCell(net, loss)


Loading…
Cancel
Save