|
|
|
@@ -1884,6 +1884,7 @@ class SmoothL1Loss(PrimitiveWithInfer): |
|
|
|
validator.check_value_type('beta', beta, [float], self.name) |
|
|
|
validator.check('beta', beta, '', 0, Rel.GT, self.name) |
|
|
|
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) |
|
|
|
self.add_prim_attr('sigma', beta) |
|
|
|
|
|
|
|
def infer_shape(self, prediction, target): |
|
|
|
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) |
|
|
|
|