|
|
|
@@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer): |
|
|
|
[ 1.00000000e+00, 1.00000000e+00])) |
|
|
|
""" |
|
|
|
|
|
|
|
__mindspore_signature__ = ( |
|
|
|
sig.make_sig('input_x', dtype=sig.sig_dtype.T1), |
|
|
|
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), |
|
|
|
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), |
|
|
|
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3), |
|
|
|
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3) |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"): |
|
|
|
if is_training is False: |
|
|
|
self.set_signatures(tuple()) |
|
|
|
validator.check_value_type('is_training', is_training, (bool,), self.name) |
|
|
|
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) |
|
|
|
validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) |
|
|
|
|