Merge pull request !6653 from chujinjin/fix_stream_sync_error_for_mixed_precisiontags/v1.1.0
| @@ -260,7 +260,7 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { | |||||
| py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { | py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { | ||||
| auto tensor = py::cast<tensor::TensorPtr>(obj); | auto tensor = py::cast<tensor::TensorPtr>(obj); | ||||
| auto cast_type = tensor->cast_dtype(); | auto cast_type = tensor->cast_dtype(); | ||||
| py::object cast_output; | |||||
| py::object cast_output = obj; | |||||
| if (cast_type != nullptr) { | if (cast_type != nullptr) { | ||||
| auto source_element = tensor->Dtype(); | auto source_element = tensor->Dtype(); | ||||
| if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { | if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { | ||||
| @@ -282,6 +282,8 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { | |||||
| result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]); | result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]); | ||||
| } else if (py::isinstance<py::tuple>(tuple[i])) { | } else if (py::isinstance<py::tuple>(tuple[i])) { | ||||
| result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]); | result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]); | ||||
| } else { | |||||
| result[i] = tuple[i]; | |||||
| } | } | ||||
| } | } | ||||
| return result; | return result; | ||||
| @@ -609,6 +609,13 @@ class FusedBatchNorm(Primitive): | |||||
| >>> op = P.FusedBatchNorm() | >>> op = P.FusedBatchNorm() | ||||
| >>> output = op(input_x, scale, bias, mean, variance) | >>> output = op(input_x, scale, bias, mean, variance) | ||||
| """ | """ | ||||
| __mindspore_signature__ = ( | |||||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | |||||
| sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||||
| ) | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): | def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): | ||||
| @@ -1394,11 +1394,6 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], | 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], | ||||
| 'desc_bprop': [[2, 16], [16], [16]], | 'desc_bprop': [[2, 16], [16], [16]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('FusedBatchNorm', { | |||||
| 'block': P.FusedBatchNorm(), | |||||
| 'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]], | |||||
| 'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]], | |||||
| 'skip': []}), | |||||
| ('FusedBatchNormGrad', { | ('FusedBatchNormGrad', { | ||||
| 'block': G.FusedBatchNormGrad(), | 'block': G.FusedBatchNormGrad(), | ||||
| 'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]], | 'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]], | ||||