| @@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum"; | |||||
| const char kNameIsFinite[] = "isFinite"; | const char kNameIsFinite[] = "isFinite"; | ||||
| const char kNameReciprocal[] = "Reciprocal"; | const char kNameReciprocal[] = "Reciprocal"; | ||||
| const char kNameRsqrt[] = "Rsqrt"; | const char kNameRsqrt[] = "Rsqrt"; | ||||
| const char kNameRsqrtGrad[] = "RsqrtGrad"; | |||||
| const char kNameSqrt[] = "Sqrt"; | const char kNameSqrt[] = "Sqrt"; | ||||
| const char kNameSquare[] = "Square"; | const char kNameSquare[] = "Square"; | ||||
| const char kNameSquaredDifference[] = "SquaredDifference"; | const char kNameSquaredDifference[] = "SquaredDifference"; | ||||
| @@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad"; | |||||
| const char kNameConvolution[] = "Convolution"; | const char kNameConvolution[] = "Convolution"; | ||||
| const char kNameBiasAdd[] = "BiasAdd"; | const char kNameBiasAdd[] = "BiasAdd"; | ||||
| const char kNameMaxPoolGrad[] = "MaxPoolGrad"; | const char kNameMaxPoolGrad[] = "MaxPoolGrad"; | ||||
| const char kNameRsqrtGrad[] = "RsqrtGrad"; | |||||
| const char kNameSqrtGrad[] = "SqrtGrad"; | |||||
| const char kNameReciprocalGrad[] = "ReciprocalGrad"; | |||||
| const char kNameAvgPoolGrad[] = "AvgPoolGrad"; | const char kNameAvgPoolGrad[] = "AvgPoolGrad"; | ||||
| const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; | const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; | ||||
| const char kNameApplyMomentum[] = "ApplyMomentum"; | const char kNameApplyMomentum[] = "ApplyMomentum"; | ||||
| @@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, | {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, | ||||
| {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, | {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, | ||||
| {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, | {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, | ||||
| {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)}, | |||||
| {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)}, | |||||
| {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)}, | |||||
| {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, | {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, | ||||
| {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, | {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, | ||||
| {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, | {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, | ||||
| @@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits< | |||||
| {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; | ||||
| OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; | OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; | ||||
| // RsqrtGrad | |||||
| INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||||
| ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}}; | |||||
| // SqrtGrad | |||||
| INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||||
| ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}}; | |||||
| // ReciprocalGrad | |||||
| INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; | |||||
| ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}}; | |||||
| // avgpoolgrad | // avgpoolgrad | ||||
| INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; | INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; | ||||
| ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, | ||||
| @@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool) | |||||
| DECLARE_OP_USE_OUTPUT(MaxPool) | DECLARE_OP_USE_OUTPUT(MaxPool) | ||||
| DECLARE_OP_ADAPTER(MaxPoolGrad) | DECLARE_OP_ADAPTER(MaxPoolGrad) | ||||
| DECLARE_OP_USE_OUTPUT(MaxPoolGrad) | DECLARE_OP_USE_OUTPUT(MaxPoolGrad) | ||||
| DECLARE_OP_ADAPTER(SqrtGrad) | |||||
| DECLARE_OP_USE_OUTPUT(SqrtGrad) | |||||
| DECLARE_OP_ADAPTER(ReciprocalGrad) | |||||
| DECLARE_OP_USE_OUTPUT(ReciprocalGrad) | |||||
| DECLARE_OP_ADAPTER(RsqrtGrad) | |||||
| DECLARE_OP_USE_OUTPUT(RsqrtGrad) | |||||
| DECLARE_OP_ADAPTER(AvgPool) | DECLARE_OP_ADAPTER(AvgPool) | ||||
| DECLARE_OP_USE_OUTPUT(AvgPool) | DECLARE_OP_USE_OUTPUT(AvgPool) | ||||
| DECLARE_OP_ADAPTER(AvgPoolGrad) | DECLARE_OP_ADAPTER(AvgPoolGrad) | ||||
| @@ -366,15 +366,10 @@ def get_bprop_square(self): | |||||
| @bprop_getters.register(P.Sqrt) | @bprop_getters.register(P.Sqrt) | ||||
| def get_bprop_sqrt(self): | def get_bprop_sqrt(self): | ||||
| """Grad definition for `Sqrt` operation.""" | """Grad definition for `Sqrt` operation.""" | ||||
| mul_func = P.Mul() | |||||
| fill_func = P.Fill() | |||||
| div_op = P.RealDiv() | |||||
| sqrt = P.Sqrt() | |||||
| dtype = P.DType() | |||||
| sqrt_grad = G.SqrtGrad() | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x)) | |||||
| dx = mul_func(dout, temp) | |||||
| dx = sqrt_grad(out, dout) | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -383,10 +378,10 @@ def get_bprop_sqrt(self): | |||||
| @bprop_getters.register(P.Rsqrt) | @bprop_getters.register(P.Rsqrt) | ||||
| def get_bprop_rsqrt(self): | def get_bprop_rsqrt(self): | ||||
| """Grad definition for `Rsqrt` operation.""" | """Grad definition for `Rsqrt` operation.""" | ||||
| rsqrt_grad = G.RsqrtGrad() | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x) | |||||
| dx = dout * grad | |||||
| dx = rsqrt_grad(out, dout) | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -395,14 +390,10 @@ def get_bprop_rsqrt(self): | |||||
| @bprop_getters.register(P.Reciprocal) | @bprop_getters.register(P.Reciprocal) | ||||
| def get_bprop_reciprocal(self): | def get_bprop_reciprocal(self): | ||||
| """Grad definition for `Reciprocal` operation.""" | """Grad definition for `Reciprocal` operation.""" | ||||
| neg = P.Neg() | |||||
| mul = P.Mul() | |||||
| square = P.Square() | |||||
| reciprocal = P.Reciprocal() | |||||
| reciprocal_grad = G.ReciprocalGrad() | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| g = neg(reciprocal(square(x))) | |||||
| dx = mul(dout, g) | |||||
| dx = reciprocal_grad(out, dout) | |||||
| return (dx,) | return (dx,) | ||||
| return bprop | return bprop | ||||
| @@ -442,6 +442,7 @@ def get_bprop_softmax(self): | |||||
| sub = P.Sub() | sub = P.Sub() | ||||
| mul = P.Mul() | mul = P.Mul() | ||||
| axis = self.axis | axis = self.axis | ||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) | dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) | ||||
| return (dx,) | return (dx,) | ||||
| @@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe | |||||
| from .apply_rms_prop import _apply_rms_prop_tbe | from .apply_rms_prop import _apply_rms_prop_tbe | ||||
| from .cumprod import _cumprop_tbe | from .cumprod import _cumprop_tbe | ||||
| from .reduce_prod import _reduce_prod_tbe | from .reduce_prod import _reduce_prod_tbe | ||||
| from .reciprocal_grad import _reciprocal_grad_tbe | |||||
| from .sqrt_grad import _sqrt_grad_tbe | |||||
| from .rsqrt_grad import _rsqrt_grad_tbe | |||||
| from .flatten_grad import _flatten_grad_tbe | from .flatten_grad import _flatten_grad_tbe | ||||
| from .scatter_add import _scatter_add_tbe | from .scatter_add import _scatter_add_tbe | ||||
| from .atan2 import _atan2_tbe | from .atan2 import _atan2_tbe | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Add op""" | |||||
| """Reciprocal op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| reciprocal_op_info = TBERegOp("Reciprocal") \ | reciprocal_op_info = TBERegOp("Reciprocal") \ | ||||
| @@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \ | |||||
| @op_info_register(reciprocal_op_info) | @op_info_register(reciprocal_op_info) | ||||
| def _reciprocal_tbe(): | def _reciprocal_tbe(): | ||||
| """Add TBE register""" | |||||
| """Reciprocal TBE register""" | |||||
| return | return | ||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ReciprocalGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("reciprocal_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("reciprocal_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "dy", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .op_pattern("broadcast") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(reciprocal_grad_op_info) | |||||
| def _reciprocal_grad_tbe(): | |||||
| """ReciprocalGrad TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RsqrtGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("rsqrt_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("rsqrt_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .op_pattern("broadcast") \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "dy", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ | |||||
| .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ | |||||
| .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ | |||||
| .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ | |||||
| .get_op_info() | |||||
| @op_info_register(rsqrt_grad_op_info) | |||||
| def _rsqrt_grad_tbe(): | |||||
| """RsqrtGrad TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SqrtGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sqrt_grad_op_info = TBERegOp("SqrtGrad") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sqrt_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sqrt_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "dy", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sqrt_grad_op_info) | |||||
| def _sqrt_grad_tbe(): | |||||
| """SqrtGrad TBE register""" | |||||
| return | |||||
| @@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| class ReciprocalGrad(PrimitiveWithInfer): | |||||
| """Performs grad of Reciprocal operation.""" | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init ReciprocalGrad""" | |||||
| def infer_shape(self, x_shape, dout_shape): | |||||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | |||||
| args = {"x": x_dtype, "dout": dout_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class RsqrtGrad(PrimitiveWithInfer): | |||||
| """Performs grad of Rsqrt operation.""" | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init RsqrtGrad""" | |||||
| def infer_shape(self, x_shape, dout_shape): | |||||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | |||||
| args = {"x": x_dtype, "dout": dout_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||||
| return x_dtype | |||||
| class SoftmaxGrad(PrimitiveWithInfer): | |||||
| """Performs grad of Softmax operation.""" | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init SoftmaxGrad""" | |||||
| def infer_shape(self, x_shape, dout_shape): | |||||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | |||||
| args = {"x": x_dtype, "dout": dout_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class SqrtGrad(PrimitiveWithInfer): | |||||
| """Performs grad of Sqrt operation.""" | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """init SqrtGrad""" | |||||
| def infer_shape(self, x_shape, dout_shape): | |||||
| validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | |||||
| args = {"x": x_dtype, "dout": dout_dtype} | |||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class BatchNormGrad(PrimitiveWithInfer): | class BatchNormGrad(PrimitiveWithInfer): | ||||
| """Performs grad of BatchNorm operation.""" | """Performs grad of BatchNorm operation.""" | ||||