| @@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum"; | |||
| const char kNameIsFinite[] = "isFinite"; | |||
| const char kNameReciprocal[] = "Reciprocal"; | |||
| const char kNameRsqrt[] = "Rsqrt"; | |||
| const char kNameRsqrtGrad[] = "RsqrtGrad"; | |||
| const char kNameSqrt[] = "Sqrt"; | |||
| const char kNameSquare[] = "Square"; | |||
| const char kNameSquaredDifference[] = "SquaredDifference"; | |||
| @@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad"; | |||
| const char kNameConvolution[] = "Convolution"; | |||
| const char kNameBiasAdd[] = "BiasAdd"; | |||
| const char kNameMaxPoolGrad[] = "MaxPoolGrad"; | |||
| const char kNameRsqrtGrad[] = "RsqrtGrad"; | |||
| const char kNameSqrtGrad[] = "SqrtGrad"; | |||
| const char kNameReciprocalGrad[] = "ReciprocalGrad"; | |||
| const char kNameAvgPoolGrad[] = "AvgPoolGrad"; | |||
| const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; | |||
| const char kNameApplyMomentum[] = "ApplyMomentum"; | |||
| @@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, | |||
| {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, | |||
| {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, | |||
| {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)}, | |||
| {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)}, | |||
| {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)}, | |||
| {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, | |||
| {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, | |||
| {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, | |||
| @@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits< | |||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | |||
| OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; | |||
| // RsqrtGrad | |||
| INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||
| ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}}; | |||
| // SqrtGrad | |||
| INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||
| ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}}; | |||
| // ReciprocalGrad | |||
| INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||
| ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}}; | |||
| // avgpoolgrad | |||
| INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; | |||
| ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | |||
| @@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool) | |||
| DECLARE_OP_USE_OUTPUT(MaxPool) | |||
| DECLARE_OP_ADAPTER(MaxPoolGrad) | |||
| DECLARE_OP_USE_OUTPUT(MaxPoolGrad) | |||
| DECLARE_OP_ADAPTER(SqrtGrad) | |||
| DECLARE_OP_USE_OUTPUT(SqrtGrad) | |||
| DECLARE_OP_ADAPTER(ReciprocalGrad) | |||
| DECLARE_OP_USE_OUTPUT(ReciprocalGrad) | |||
| DECLARE_OP_ADAPTER(RsqrtGrad) | |||
| DECLARE_OP_USE_OUTPUT(RsqrtGrad) | |||
| DECLARE_OP_ADAPTER(AvgPool) | |||
| DECLARE_OP_USE_OUTPUT(AvgPool) | |||
| DECLARE_OP_ADAPTER(AvgPoolGrad) | |||
| @@ -366,15 +366,10 @@ def get_bprop_square(self): | |||
| @bprop_getters.register(P.Sqrt) | |||
| def get_bprop_sqrt(self): | |||
| """Grad definition for `Sqrt` operation.""" | |||
| mul_func = P.Mul() | |||
| fill_func = P.Fill() | |||
| div_op = P.RealDiv() | |||
| sqrt = P.Sqrt() | |||
| dtype = P.DType() | |||
| sqrt_grad = G.SqrtGrad() | |||
| def bprop(x, out, dout): | |||
| temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x)) | |||
| dx = mul_func(dout, temp) | |||
| dx = sqrt_grad(out, dout) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -383,10 +378,10 @@ def get_bprop_sqrt(self): | |||
| @bprop_getters.register(P.Rsqrt) | |||
| def get_bprop_rsqrt(self): | |||
| """Grad definition for `Rsqrt` operation.""" | |||
| rsqrt_grad = G.RsqrtGrad() | |||
| def bprop(x, out, dout): | |||
| grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x) | |||
| dx = dout * grad | |||
| dx = rsqrt_grad(out, dout) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -395,14 +390,10 @@ def get_bprop_rsqrt(self): | |||
| @bprop_getters.register(P.Reciprocal) | |||
| def get_bprop_reciprocal(self): | |||
| """Grad definition for `Reciprocal` operation.""" | |||
| neg = P.Neg() | |||
| mul = P.Mul() | |||
| square = P.Square() | |||
| reciprocal = P.Reciprocal() | |||
| reciprocal_grad = G.ReciprocalGrad() | |||
| def bprop(x, out, dout): | |||
| g = neg(reciprocal(square(x))) | |||
| dx = mul(dout, g) | |||
| dx = reciprocal_grad(out, dout) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -442,6 +442,7 @@ def get_bprop_softmax(self): | |||
| sub = P.Sub() | |||
| mul = P.Mul() | |||
| axis = self.axis | |||
| def bprop(x, out, dout): | |||
| dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) | |||
| return (dx,) | |||
| @@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe | |||
| from .apply_rms_prop import _apply_rms_prop_tbe | |||
| from .cumprod import _cumprop_tbe | |||
| from .reduce_prod import _reduce_prod_tbe | |||
| from .reciprocal_grad import _reciprocal_grad_tbe | |||
| from .sqrt_grad import _sqrt_grad_tbe | |||
| from .rsqrt_grad import _rsqrt_grad_tbe | |||
| from .flatten_grad import _flatten_grad_tbe | |||
| from .scatter_add import _scatter_add_tbe | |||
| from .atan2 import _atan2_tbe | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Add op""" | |||
| """Reciprocal op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reciprocal_op_info = TBERegOp("Reciprocal") \ | |||
| @@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \ | |||
| @op_info_register(reciprocal_op_info) | |||
| def _reciprocal_tbe(): | |||
| """Add TBE register""" | |||
| """Reciprocal TBE register""" | |||
| return | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ReciprocalGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("reciprocal_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("reciprocal_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "dy", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("broadcast") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(reciprocal_grad_op_info) | |||
| def _reciprocal_grad_tbe(): | |||
| """ReciprocalGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RsqrtGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("rsqrt_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("rsqrt_grad") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("broadcast") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "dy", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||
| .get_op_info() | |||
| @op_info_register(rsqrt_grad_op_info) | |||
| def _rsqrt_grad_tbe(): | |||
| """RsqrtGrad TBE register""" | |||
| return | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SqrtGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sqrt_grad_op_info = TBERegOp("SqrtGrad") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sqrt_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sqrt_grad") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "dy", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||
| .get_op_info() | |||
| @op_info_register(sqrt_grad_op_info) | |||
| def _sqrt_grad_tbe(): | |||
| """SqrtGrad TBE register""" | |||
| return | |||
| @@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer): | |||
| return x | |||
| class ReciprocalGrad(PrimitiveWithInfer): | |||
| """Performs grad of Reciprocal operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init ReciprocalGrad""" | |||
| def infer_shape(self, x_shape, dout_shape): | |||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| class RsqrtGrad(PrimitiveWithInfer): | |||
| """Performs grad of Rsqrt operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init RsqrtGrad""" | |||
| def infer_shape(self, x_shape, dout_shape): | |||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||
| return x_dtype | |||
| class SoftmaxGrad(PrimitiveWithInfer): | |||
| """Performs grad of Softmax operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init SoftmaxGrad""" | |||
| def infer_shape(self, x_shape, dout_shape): | |||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| class SqrtGrad(PrimitiveWithInfer): | |||
| """Performs grad of Sqrt operation.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init SqrtGrad""" | |||
| def infer_shape(self, x_shape, dout_shape): | |||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| class BatchNormGrad(PrimitiveWithInfer): | |||
| """Performs grad of BatchNorm operation.""" | |||