From 094e306611824ef3d2dfd2548ee4061af3940e4d Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Fri, 15 Jan 2021 09:53:58 +0800
Subject: [PATCH] fix nll_loss

---
 mindspore/ops/operations/_grad_ops.py | 5 ++++-
 mindspore/ops/operations/nn_ops.py    | 7 +++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 6cce3c7f24..364725c6b7 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1761,7 +1761,10 @@ class NLLLossGrad(PrimitiveWithInfer):
         validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
         validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
         validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
-        validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        if len(x_shape) == 1:
+            validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        else:
+            validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
         return x_shape

     def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 45ee28227e..8dc2279a6f 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1956,7 +1956,7 @@ class NLLLoss(PrimitiveWithInfer):
     def __init__(self, reduction="mean"):
         """Initialize NLLLoss"""
         self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
-        self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name)
+        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
         self.add_prim_attr('reduction', self.reduction)

     def infer_shape(self, x_shape, t_shape, w_shape):
@@ -1964,7 +1964,10 @@ class NLLLoss(PrimitiveWithInfer):
         validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
         validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
         validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
-        validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        if len(x_shape) == 1:
+            validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
+        else:
+            validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
         if self.reduction == "none":
             return t_shape, ()
         return (), ()
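
Note (not part of the patch): a minimal standalone sketch of the shape rule this change introduces. It mirrors the patched infer_shape checks without importing MindSpore; the function name and the example shapes are illustrative only, and the assumption is that logits may now be 1-D of shape (C,) as well as 2-D of shape (N, C).

    # Mirrors the patched validation: the weight length is compared against
    # x_shape[0] for 1-D logits and against x_shape[1] for 2-D logits.
    def check_nll_loss_shapes(x_shape, t_shape, w_shape):
        assert len(t_shape) == 1, "target must be rank 1"
        assert len(w_shape) == 1, "weight must be rank 1"
        assert x_shape[0] == t_shape[0], "input_shape[0] must equal target_shape[0]"
        class_dim = x_shape[0] if len(x_shape) == 1 else x_shape[1]
        assert w_shape[0] == class_dim, "weight_shape[0] must equal the class dimension"

    check_nll_loss_shapes((4, 10), (4,), (10,))  # 2-D logits: weight matches x_shape[1]
    check_nll_loss_shapes((3,), (3,), (3,))      # 1-D logits: weight matches x_shape[0]

Before the fix, the 1-D case indexed x_shape[1] unconditionally and raised an index error; the reduction string is also no longer lower-cased before validation, so only the exact strings 'none', 'sum', and 'mean' are accepted.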