|
|
|
@@ -678,6 +678,13 @@ class FusedBatchNormEx(PrimitiveWithInfer): |
|
|
|
>>> op = P.FusedBatchNormEx() |
|
|
|
>>> output = op(input_x, scale, bias, mean, variance) |
|
|
|
""" |
|
|
|
__mindspore_signature__ = ( |
|
|
|
('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), |
|
|
|
('scale', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('bias', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('mean', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
('variance', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
|
) |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): |
|
|
|
|