Browse Source

!11294 fix_nll_loss

From: @jiangzg001
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @linqingke
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
4b4363fb66
2 changed files with 9 additions and 3 deletions
  1. +4
    -1
      mindspore/ops/operations/_grad_ops.py
  2. +5
    -2
      mindspore/ops/operations/nn_ops.py

+ 4
- 1
mindspore/ops/operations/_grad_ops.py View File

@@ -1776,7 +1776,10 @@ class NLLLossGrad(PrimitiveWithInfer):
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
if len(x_shape) == 1:
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
else:
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
return x_shape

def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):


+ 5
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -2076,7 +2076,7 @@ class NLLLoss(PrimitiveWithInfer):
def __init__(self, reduction="mean"):
"""Initialize NLLLoss"""
self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name)
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
self.add_prim_attr('reduction', self.reduction)

def infer_shape(self, x_shape, t_shape, w_shape):
@@ -2084,7 +2084,10 @@ class NLLLoss(PrimitiveWithInfer):
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
if len(x_shape) == 1:
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
else:
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
if self.reduction == "none":
return t_shape, ()
return (), ()


Loading…
Cancel
Save