From: @zhaosida1994 Reviewed-by: @jjfeing,@kisnwang Signed-off-by: @jjfeingtags/v1.1.0
| @@ -95,6 +95,8 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| @@ -47,6 +47,19 @@ AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||||
| return inp->Clone()->Broaden(); | return inp->Clone()->Broaden(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors. | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 2); | |||||
| auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| (void)CheckDtypeSame(op_name, out, dout); | |||||
| (void)CheckShapeSame(op_name, out, dout); | |||||
| return out->Broaden(); | |||||
| } | |||||
| AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: two tensors. | // Inputs: two tensors. | ||||
| @@ -41,6 +41,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, | {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, | ||||
| {prim::kPrimSquare, {InferImplSquare, true}}, | {prim::kPrimSquare, {InferImplSquare, true}}, | ||||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | {prim::kPrimSqrt, {InferImplSqrt, true}}, | ||||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||||
| {prim::kPrimSub, {InferImplSub, true}}, | {prim::kPrimSub, {InferImplSub, true}}, | ||||
| {prim::kPrimEqual, {InferImplEqual, true}}, | {prim::kPrimEqual, {InferImplEqual, true}}, | ||||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | {prim::kPrimMinimum, {InferImplMinimum, true}}, | ||||
| @@ -241,6 +241,7 @@ inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("Inplace | |||||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | ||||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | ||||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | ||||
| inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); | |||||
| inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | ||||
| inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); | ||||
| inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs"); | inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs"); | ||||
| @@ -260,6 +260,7 @@ from .cumprod import _cumprop_tbe | |||||
| from .reduce_prod import _reduce_prod_tbe | from .reduce_prod import _reduce_prod_tbe | ||||
| from .reciprocal_grad import _reciprocal_grad_tbe | from .reciprocal_grad import _reciprocal_grad_tbe | ||||
| from .sqrt_grad import _sqrt_grad_tbe | from .sqrt_grad import _sqrt_grad_tbe | ||||
| from .sqrt_grad_ds import _sqrt_grad_ds_tbe | |||||
| from .rsqrt_grad import _rsqrt_grad_tbe | from .rsqrt_grad import _rsqrt_grad_tbe | ||||
| from .flatten_grad import _flatten_grad_tbe | from .flatten_grad import _flatten_grad_tbe | ||||
| from .scatter_add import _scatter_add_tbe | from .scatter_add import _scatter_add_tbe | ||||
| @@ -0,0 +1,44 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SqrtGrad op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sqrt_grad_op_info = TBERegOp("SqrtGrad") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sqrt_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sqrt_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .dynamic_shape(True) \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .input(1, "dy", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||||
| .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sqrt_grad_op_info) | |||||
| def _sqrt_grad_ds_tbe(): | |||||
| """SqrtGrad TBE register""" | |||||
| return | |||||