
add 4 grad ops

tags/v0.7.0-beta · fangzehua · 5 years ago · commit a80432f08e
11 changed files with 228 additions and 18 deletions
  1. mindspore/ccsrc/transform/graph_ir/convert.cc (+6, -1)
  2. mindspore/ccsrc/transform/graph_ir/op_declare.cc (+15, -0)
  3. mindspore/ccsrc/transform/graph_ir/op_declare.h (+6, -0)
  4. mindspore/ops/_grad/grad_math_ops.py (+6, -15)
  5. mindspore/ops/_grad/grad_nn_ops.py (+1, -0)
  6. mindspore/ops/_op_impl/tbe/__init__.py (+3, -0)
  7. mindspore/ops/_op_impl/tbe/reciprocal.py (+2, -2)
  8. mindspore/ops/_op_impl/tbe/reciprocal_grad.py (+38, -0)
  9. mindspore/ops/_op_impl/tbe/rsqrt_grad.py (+40, -0)
  10. mindspore/ops/_op_impl/tbe/sqrt_grad.py (+43, -0)
  11. mindspore/ops/operations/_grad_ops.py (+68, -0)

mindspore/ccsrc/transform/graph_ir/convert.cc (+6, -1)

@@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
 const char kNameRsqrt[] = "Rsqrt";
-const char kNameRsqrtGrad[] = "RsqrtGrad";
 const char kNameSqrt[] = "Sqrt";
 const char kNameSquare[] = "Square";
 const char kNameSquaredDifference[] = "SquaredDifference";
@@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
 const char kNameConvolution[] = "Convolution";
 const char kNameBiasAdd[] = "BiasAdd";
 const char kNameMaxPoolGrad[] = "MaxPoolGrad";
+const char kNameRsqrtGrad[] = "RsqrtGrad";
+const char kNameSqrtGrad[] = "SqrtGrad";
+const char kNameReciprocalGrad[] = "ReciprocalGrad";
 const char kNameAvgPoolGrad[] = "AvgPoolGrad";
 const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
 const char kNameApplyMomentum[] = "ApplyMomentum";
@@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
   {string(kNameAllgather), ADPT_DESC(HcomAllGather)},
   {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
   {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
+  {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
+  {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
+  {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
   {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
   {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
   {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},


mindspore/ccsrc/transform/graph_ir/op_declare.cc (+15, -0)

@@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
                          {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
 OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
 
+// RsqrtGrad
+INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// SqrtGrad
+INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// ReciprocalGrad
+INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
+
 // avgpoolgrad
 INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
 ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},


mindspore/ccsrc/transform/graph_ir/op_declare.h (+6, -0)

@@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
 DECLARE_OP_USE_OUTPUT(MaxPool)
 DECLARE_OP_ADAPTER(MaxPoolGrad)
 DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
+DECLARE_OP_ADAPTER(SqrtGrad)
+DECLARE_OP_USE_OUTPUT(SqrtGrad)
+DECLARE_OP_ADAPTER(ReciprocalGrad)
+DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
+DECLARE_OP_ADAPTER(RsqrtGrad)
+DECLARE_OP_USE_OUTPUT(RsqrtGrad)
 DECLARE_OP_ADAPTER(AvgPool)
 DECLARE_OP_USE_OUTPUT(AvgPool)
 DECLARE_OP_ADAPTER(AvgPoolGrad)


mindspore/ops/_grad/grad_math_ops.py (+6, -15)

@@ -366,15 +366,10 @@ def get_bprop_square(self):
 @bprop_getters.register(P.Sqrt)
 def get_bprop_sqrt(self):
     """Grad definition for `Sqrt` operation."""
-    mul_func = P.Mul()
-    fill_func = P.Fill()
-    div_op = P.RealDiv()
-    sqrt = P.Sqrt()
-    dtype = P.DType()
+    sqrt_grad = G.SqrtGrad()
 
     def bprop(x, out, dout):
-        temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
-        dx = mul_func(dout, temp)
+        dx = sqrt_grad(out, dout)
         return (dx,)
 
     return bprop
@@ -383,10 +378,10 @@ def get_bprop_sqrt(self):
 @bprop_getters.register(P.Rsqrt)
 def get_bprop_rsqrt(self):
     """Grad definition for `Rsqrt` operation."""
+    rsqrt_grad = G.RsqrtGrad()
 
     def bprop(x, out, dout):
-        grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
-        dx = dout * grad
+        dx = rsqrt_grad(out, dout)
         return (dx,)
 
     return bprop
@@ -395,14 +390,10 @@ def get_bprop_rsqrt(self):
 @bprop_getters.register(P.Reciprocal)
 def get_bprop_reciprocal(self):
     """Grad definition for `Reciprocal` operation."""
-    neg = P.Neg()
-    mul = P.Mul()
-    square = P.Square()
-    reciprocal = P.Reciprocal()
+    reciprocal_grad = G.ReciprocalGrad()
 
     def bprop(x, out, dout):
-        g = neg(reciprocal(square(x)))
-        dx = mul(dout, g)
+        dx = reciprocal_grad(out, dout)
        return (dx,)
 
     return bprop
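
The replaced composite bprops and the fused grad ops agree mathematically: with y = sqrt(x), 0.5/sqrt(x) = 0.5/y; with y = x**-0.5, -0.5/(sqrt(x)*x) = -0.5*y**3; with y = 1/x, -1/x**2 = -y**2. A minimal NumPy sketch of these identities (NumPy stands in for the TBE kernels; the closed forms in terms of out are the ones the fused ops are expected to compute):

import numpy as np

x = np.array([0.25, 1.0, 4.0], dtype=np.float32)
dout = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Sqrt: old bprop computed dout * 0.5 / sqrt(x); SqrtGrad(out, dout) = dout * 0.5 / out.
y = np.sqrt(x)
assert np.allclose(dout * 0.5 / np.sqrt(x), dout * 0.5 / y)

# Rsqrt: old bprop computed dout * -0.5 / (sqrt(x) * x); RsqrtGrad(out, dout) = dout * -0.5 * out**3.
y = 1.0 / np.sqrt(x)
assert np.allclose(dout * -0.5 / (np.sqrt(x) * x), dout * -0.5 * y ** 3)

# Reciprocal: old bprop computed dout * -(1 / x**2); ReciprocalGrad(out, dout) = dout * -out**2.
y = 1.0 / x
assert np.allclose(dout * -(1.0 / x ** 2), dout * -(y ** 2))

Rewriting each gradient in terms of the already-available forward output avoids recomputing sqrt/square of x and collapses several small ops into a single kernel call.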


mindspore/ops/_grad/grad_nn_ops.py (+1, -0)

@@ -442,6 +442,7 @@ def get_bprop_softmax(self):
     sub = P.Sub()
     mul = P.Mul()
     axis = self.axis
+
     def bprop(x, out, dout):
         dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
         return (dx,)
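
The bprop above is the softmax Jacobian-vector product: since dy_i/dx_j = y_i * (delta_ij - y_j), dx = out * (dout - sum(out * dout, axis)). A quick NumPy check of that identity (illustrative sketch, not a MindSpore API):

import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5)).astype(np.float32)
dout = rng.standard_normal((2, 5)).astype(np.float32)

out = softmax(x)
# Same formula as the bprop: dx = out * (dout - sum(out * dout, axis))
dx = out * (dout - (out * dout).sum(axis=-1, keepdims=True))

# Explicit Jacobian for one row: J = diag(y) - y y^T (symmetric), so dx = J @ dout.
J = np.diag(out[0]) - np.outer(out[0], out[0])
assert np.allclose(dx[0], J @ dout[0], atol=1e-5)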


mindspore/ops/_op_impl/tbe/__init__.py (+3, -0)

@@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
 from .apply_rms_prop import _apply_rms_prop_tbe
 from .cumprod import _cumprop_tbe
 from .reduce_prod import _reduce_prod_tbe
+from .reciprocal_grad import _reciprocal_grad_tbe
+from .sqrt_grad import _sqrt_grad_tbe
+from .rsqrt_grad import _rsqrt_grad_tbe
 from .flatten_grad import _flatten_grad_tbe
 from .scatter_add import _scatter_add_tbe
 from .atan2 import _atan2_tbe


mindspore/ops/_op_impl/tbe/reciprocal.py (+2, -2)

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 
-"""Add op"""
+"""Reciprocal op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
 reciprocal_op_info = TBERegOp("Reciprocal") \
@@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
 
 @op_info_register(reciprocal_op_info)
 def _reciprocal_tbe():
-    """Add TBE register"""
+    """Reciprocal TBE register"""
     return

mindspore/ops/_op_impl/tbe/reciprocal_grad.py (+38, -0)

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ReciprocalGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reciprocal_grad.so") \
    .compute_cost(10) \
    .kernel_name("reciprocal_grad") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(reciprocal_grad_op_info)
def _reciprocal_grad_tbe():
    """ReciprocalGrad TBE register"""
    return

mindspore/ops/_op_impl/tbe/rsqrt_grad.py (+40, -0)

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""RsqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("rsqrt_grad.so") \
    .compute_cost(10) \
    .kernel_name("rsqrt_grad") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "x", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(rsqrt_grad_op_info)
def _rsqrt_grad_tbe():
    """RsqrtGrad TBE register"""
    return

mindspore/ops/_op_impl/tbe/sqrt_grad.py (+43, -0)

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

sqrt_grad_op_info = TBERegOp("SqrtGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("sqrt_grad.so") \
    .compute_cost(10) \
    .kernel_name("sqrt_grad") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .get_op_info()


@op_info_register(sqrt_grad_op_info)
def _sqrt_grad_tbe():
    """SqrtGrad TBE register"""
    return

mindspore/ops/operations/_grad_ops.py (+68, -0)

@@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
         return x
 
 
+class ReciprocalGrad(PrimitiveWithInfer):
+    """Performs grad of Reciprocal operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init ReciprocalGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class RsqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Rsqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init RsqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
+        return x_dtype
+
+
+class SoftmaxGrad(PrimitiveWithInfer):
+    """Performs grad of Softmax operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SoftmaxGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class SqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Sqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
 class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""


