@@ -75,7 +75,7 @@ PrimitivePy::~PrimitivePy() {
 void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
 void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
   signatures_ = signatures;
-  set_has_signature(true);
+  set_has_signature(!signatures.empty());
 }

 py::function PrimitivePy::GetBpropFunction() {
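
With this change, a primitive's has-signature flag tracks whether the signature list is actually non-empty, so clearing the signatures (as `BatchNorm.__init__` does below for inference mode) also clears the flag instead of leaving it stuck at true. A minimal Python-side sketch of the intended behavior, assuming `mindspore.ops.Primitive` exposes the bound `set_signatures` method the same way `BatchNorm` uses it in this patch:

```python
# Sketch only: illustrates the intended has_signature semantics, not project test code.
from mindspore.ops import Primitive

prim = Primitive('ExamplePrim')   # 'ExamplePrim' is a made-up name for illustration
prim.set_signatures(tuple())      # empty tuple: the flag is no longer forced to True,
                                  # so the primitive reports "no signature" again
```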
@@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer):
         [ 1.00000000e+00, 1.00000000e+00]))
     """

+    __mindspore_signature__ = (
+        sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
+        sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
+        sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
+        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
+        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3)
+    )
+
     @prim_attr_register
     def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
+        if is_training is False:
+            self.set_signatures(tuple())
         validator.check_value_type('is_training', is_training, (bool,), self.name)
         validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
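
The signature block marks `scale`, `bias`, `mean`, and `variance` as writable inputs grouped by dtype, which only matters when `is_training=True` and the statistics are updated in place; with `is_training=False` the constructor now drops the signatures. A hedged usage sketch (shapes and parameter names are illustrative, not taken from the patch):

```python
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, ops

x = Tensor(np.ones([1, 2, 4, 4]), ms.float32)
scale = Parameter(Tensor(np.ones([2]), ms.float32), name='scale')
bias = Parameter(Tensor(np.zeros([2]), ms.float32), name='bias')
mean = Parameter(Tensor(np.zeros([2]), ms.float32), name='mean')
variance = Parameter(Tensor(np.ones([2]), ms.float32), name='variance')

# Training mode: the RW_WRITE signatures apply, so mean/variance are expected
# to be Parameters the op can update in place.
bn_train = ops.BatchNorm(is_training=True)
y, _, _, _, _ = bn_train(x, scale, bias, mean, variance)

# Inference mode: __init__ clears the signatures, so no in-place update is implied.
bn_infer = ops.BatchNorm(is_training=False)
y, _, _, _, _ = bn_infer(x, scale, bias, mean, variance)
```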
@@ -129,7 +129,7 @@ def test_sit_auto_mix_precision_model_o0():
     model.train(1, dataset1, dataset_sink_mode=False)
     contend = read_validateir_file('./test_amp_o0')
     castnum = re.findall("Cast", contend)
-    assert len(castnum) == 17
+    assert len(castnum) == 5
     model.predict(Tensor(input_data))
     contend = read_validateir_file('./test_amp_o0')
     castnum = re.findall("Cast", contend)
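
The expected Cast count drops from 17 to 5, presumably because fewer casts are inserted around BatchNorm once its inputs carry the dtype-grouped signatures above. A rough sketch of the counting pattern the test relies on; `read_validateir_file` is the test suite's own helper, and the stand-in below is only an assumption about what it does:

```python
import os
import re

def read_ir_dump(ir_dir):
    """Assumed stand-in for read_validateir_file: concatenate dumped validate IR files."""
    texts = []
    for name in sorted(os.listdir(ir_dir)):
        if name.endswith('.ir'):
            with open(os.path.join(ir_dir, name)) as f:
                texts.append(f.read())
    return '\n'.join(texts)

contend = read_ir_dump('./test_amp_o0')
castnum = re.findall("Cast", contend)
assert len(castnum) == 5   # the updated expectation after training one step
```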
@@ -109,8 +109,8 @@ class FusedBatchNorm(nn.Cell):
         self.bn_train(x,
                       self.gamma,
                       self.beta,
-                      None,
-                      None)
+                      self.moving_mean,
+                      self.moving_variance)
         mean_sub = self.sub_mean(self.moving_mean, batch_mean)
         temp_mean = self.mul_mean(mean_sub, self.momentum)
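
Passing `self.moving_mean` and `self.moving_variance` instead of `None` matches the new writable `mean`/`variance` inputs of `BatchNorm`. A hedged skeleton of how such attributes are typically declared in a cell like the test's `FusedBatchNorm` (attribute names follow the hunk; initializer values and the class body are assumptions):

```python
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops

class FusedBatchNormSketch(nn.Cell):
    """Illustrative skeleton only, not the test's actual FusedBatchNorm cell."""
    def __init__(self, num_features):
        super().__init__()
        self.bn_train = ops.BatchNorm(is_training=True)
        self.gamma = Parameter(Tensor(np.ones([num_features]), ms.float32), name='gamma')
        self.beta = Parameter(Tensor(np.zeros([num_features]), ms.float32), name='beta')
        # The moving statistics are Parameters so the op's writable inputs have
        # something to update, rather than the previous None placeholders.
        self.moving_mean = Parameter(Tensor(np.zeros([num_features]), ms.float32),
                                     name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(Tensor(np.ones([num_features]), ms.float32),
                                         name='moving_variance', requires_grad=False)

    def construct(self, x):
        y, _, _, _, _ = self.bn_train(x, self.gamma, self.beta,
                                      self.moving_mean, self.moving_variance)
        return y
```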