|
|
|
@@ -2076,7 +2076,7 @@ class NLLLoss(PrimitiveWithInfer): |
|
|
|
def __init__(self, reduction="mean"): |
|
|
|
"""Initialize NLLLoss""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss']) |
|
|
|
self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name) |
|
|
|
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name) |
|
|
|
self.add_prim_attr('reduction', self.reduction) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, t_shape, w_shape): |
|
|
|
@@ -2084,7 +2084,10 @@ class NLLLoss(PrimitiveWithInfer): |
|
|
|
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name) |
|
|
|
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name) |
|
|
|
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name) |
|
|
|
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name) |
|
|
|
if len(x_shape) == 1: |
|
|
|
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name) |
|
|
|
else: |
|
|
|
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name) |
|
|
|
if self.reduction == "none": |
|
|
|
return t_shape, () |
|
|
|
return (), () |
|
|
|
|