@@ -38,6 +38,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"reduce_mean", "reduce_mean_d"},
   {"reduce_max", "reduce_max_d"},
   {"reduce_min", "reduce_min_d"},
| {"avg_pool_grad", "avg_pool_grad_d"}, | |||||
| {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, | {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, | ||||
| {"conv2d_backprop_input", "conv2d_backprop_input_d"}, | {"conv2d_backprop_input", "conv2d_backprop_input_d"}, | ||||
| {"depthwise_conv2d_native", "depthwise_conv2d"}, | {"depthwise_conv2d_native", "depthwise_conv2d"}, | ||||
@@ -170,6 +170,7 @@ const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
 const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
 const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
 const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
+const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
 const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
 const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
 const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
@@ -178,6 +178,7 @@ extern const PrimitivePtr kPrimFusedBatchNorm;
 extern const PrimitivePtr kPrimConv2D;
 extern const PrimitivePtr kPrimMaxPool;
 extern const PrimitivePtr kPrimMaxPoolGrad;
+extern const PrimitivePtr kPrimAvgPoolGrad;
 extern const PrimitivePtr kPrimFusedBatchNormGrad;
 extern const PrimitivePtr kPrimReluGrad;
 extern const PrimitivePtr kPrimConv2DBackpropInput;
@@ -25,6 +25,7 @@ namespace mindspore {
 namespace opt {
 ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimCast->name(), {1});
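+  // AvgPoolGrad's input 0 is the forward input's shape, a constant tuple that this pass converts to an attribute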
+  Register(prim::kPrimAvgPoolGrad->name(), {0});
   Register(prim::kPrimConv2DBackpropInput->name(), {2});
   Register(prim::kPrimConv2DBackpropFilter->name(), {2});
   Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1});
@@ -178,6 +178,7 @@ const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy";
 const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad";
 const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad";
 const char kNameAcosh[] = "Acosh";
+const char kNameAcoshGrad[] = "AcoshGrad";
 const char kNameFloorMod[] = "FloorMod";
 const char kNameSpaceToDepth[] = "SpaceToDepth";
 const char kNameDepthToSpace[] = "DepthToSpace";
@@ -375,6 +376,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)},
     {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)},
     {string(kNameAcosh), ADPT_DESC(Acosh)},
+    {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)},
     {string(kNameFloorMod), ADPT_DESC(FloorMod)},
     {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)},
     {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
@@ -357,6 +357,11 @@ INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Acosh) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}};
+
+// AcoshGrad
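+// With y = acosh(x), d/dx acosh(x) = 1/sinh(y), so the kernel evaluates z = dy / sinh(y)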
+INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
 // Floor
 INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
@@ -327,13 +327,15 @@ DECLARE_OP_ADAPTER(Const)
 DECLARE_OP_USE_OUTPUT(Const)
 DECLARE_OP_ADAPTER(Cos)
 DECLARE_OP_USE_OUTPUT(Cos)
 DECLARE_OP_ADAPTER(Acos)
 DECLARE_OP_USE_OUTPUT(Acos)
 DECLARE_OP_ADAPTER(AcosGrad)
 DECLARE_OP_USE_OUTPUT(AcosGrad)
 DECLARE_OP_ADAPTER(Acosh)
 DECLARE_OP_USE_OUTPUT(Acosh)
+DECLARE_OP_ADAPTER(AcoshGrad)
+DECLARE_OP_USE_OUTPUT(AcoshGrad)
 DECLARE_OP_ADAPTER(Floor)
 DECLARE_OP_USE_OUTPUT(Floor)
@@ -21,6 +21,7 @@ from mindspore._checkparam import check_int_positive, check_bool
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops.functional import identity
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore._extends import cell_attr_register
 from mindspore.common.api import ms_function
@@ -480,7 +481,7 @@ class Unfold(Cell):
     """
     def __init__(self, ksizes, strides, rates, padding="valid"):
         super(Unfold, self).__init__()
-        self.extract_image_patches = P.ExtractImagePatches(ksizes, strides, rates, padding)
+        self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
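+        # ExtractImagePatches is now an internal primitive in operations._inner_ops and no longer publicly exported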
         self.transpose = P.Transpose()
         self.format_NHWC = (0, 2, 3, 1)
         self.format_NCHW = (0, 3, 1, 2)
@@ -18,6 +18,7 @@ from mindspore.common import dtype as mstype
 from .. import functional as F
 from .. import operations as P
 from ..operations import _grad_ops as G
+from ..operations import _inner_ops as inner
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from .grad_base import bprop_getters
@@ -29,6 +30,7 @@ def get_bprop_bias_add(self):
     def bprop(x, w, out, dout):
         return dout, bias_grad(dout)
     return bprop
+
@@ -49,18 +51,19 @@ def get_bprop_conv2d(self):
         dx = input_grad(dout, w, get_shape(x))
         dw = filter_grad(dout, x, get_shape(w))
         return dx, dw
     return bprop
+

-@bprop_getters.register(P.ExtractImagePatches)
+@bprop_getters.register(inner.ExtractImagePatches)
 def get_bprop_extract_image_patches(self):
     """Grad definition for `ExtractImagePatches` operation."""
     get_shape = P.Shape()
     reshape = P.Reshape()
-    extract_image_patches = P.ExtractImagePatches(ksizes=self.ksizes,
-                                                  strides=self.strides,
-                                                  rates=self.rates,
-                                                  padding=self.padding)
+    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
+                                                      strides=self.strides,
+                                                      rates=self.rates,
+                                                      padding=self.padding)
     concat = P.Concat(axis=-1)
     expand_dims = P.ExpandDims()
     scatter_nd = P.ScatterNd()
@@ -104,6 +107,7 @@ def get_bprop_extract_image_patches(self):
         dx = transpose(dx, (2, 0, 1, 3))
         return (dx,)
     return bprop
+
@@ -124,6 +128,7 @@ def get_bprop_depthwise_conv2d_native(self):
         dx = input_grad(get_shape(x), w, dout)
         dw = filter_grad(x, get_shape(w), dout)
         return dx, dw
     return bprop
+
@@ -133,11 +138,12 @@ def get_bprop_max_pool_with_argmax(self):
     maxpool_grad = G.MaxPoolGradWithArgmax(
         ksize=self.ksize,
         strides=self.strides,
-        padding=self.padding,)
+        padding=self.padding)

     def bprop(x, out, dout):
         dx = maxpool_grad(x, dout[0], out[1])
         return (dx,)
     return bprop
+
@@ -152,6 +158,7 @@ def get_bprop_max_pool_grad(self):
     def bprop(x, out, dout):
         dx = maxpool_grad(x, out, dout)
         return (dx,)
     return bprop
+
@@ -192,6 +199,7 @@ def get_bprop_dropout_gen_mask(self):
     def bprop(shape, keep_prob, out, dout):
         return (zeros_like(shape), zeros_like(keep_prob))
     return bprop
+
@@ -202,6 +210,7 @@ def get_bprop_dropout_do_mask(self):
     def bprop(x, y, keep_prob, out, dout):
         return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))
     return bprop
+
@@ -213,6 +222,7 @@ def get_bprop_relu(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, out)
         return (dx,)
     return bprop
+
@@ -224,6 +234,7 @@ def get_bprop_relu6(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, x)
         return (dx,)
     return bprop
+
@@ -236,6 +247,7 @@ def get_bprop_relu_v2(self):
         mask = out[1]
         dx = input_grad(dout[0], mask)
         return (dx,)
     return bprop
+
@@ -247,6 +259,7 @@ def get_bprop_hswish(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, x)
         return (dx,)
     return bprop
+
@@ -258,6 +271,7 @@ def get_bprop_hsigmoid(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, x)
         return (dx,)
     return bprop
+
@@ -269,6 +283,7 @@ def get_bprop_elu(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, x)
         return (dx,)
     return bprop
+
@@ -280,6 +295,7 @@ def get_bprop_sigmoid(self):
     def bprop(x, out, dout):
         dx = input_grad(out, dout)
         return (dx,)
     return bprop
+
@@ -294,6 +310,7 @@ def get_bprop_softmax(self):
     def bprop(x, out, dout):
         dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out)
         return (dx,)
     return bprop
+
@@ -305,6 +322,7 @@ def get_bprop_log_softmax(self):
     def bprop(x, out, dout):
         dx = logsoftmax_grad(out, dout)
         return (dx,)
     return bprop
+
@@ -316,6 +334,7 @@ def get_bprop_tanh(self):
     def bprop(x, out, dout):
         dx = logsoftmax_grad(out, dout)
         return (dx,)
     return bprop
+
@@ -327,6 +346,7 @@ def get_bprop_gelu(self):
     def bprop(x, out, dout):
         dx = input_grad(dout, x, out)
         return (dx,)
     return bprop
+
@@ -343,6 +363,7 @@ def get_bprop_fused_batch_norm(self):
         dscale = out[1]
         dbias = out[2]
         return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
     return bprop
+
@@ -366,6 +387,7 @@ def get_bprop_batch_norm(self):
         dscale = out[1]
         dbias = out[2]
         return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
     return bprop
+
@@ -377,6 +399,7 @@ def get_bprop_layer_norm(self):
     def bprop(x, gamma, beta, out, dout):
         dx, d_gamma, d_beta = layer_norm_grad(x, dout[0], out[2], out[1], gamma)
         return dx, d_gamma, d_beta
     return bprop
+
@@ -388,6 +411,7 @@ def get_bprop_l2normalize(self):
     def bprop(x, out, dout):
         dx = input_grad(x, out, dout)
         return (dx,)
     return bprop
+
@@ -400,6 +424,7 @@ def get_bprop_softmax_cross_entropy_with_logits(self):
         grad = out[1]
         grad = grad * expand(dout[0], -1)
         return grad, zeros_like(labels)
     return bprop
+
@@ -417,6 +442,7 @@ def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
         grad = F.depend(grad, out)
         grad = grad * dout
         return grad, zeros_like(labels)
     return bprop
+
@@ -428,6 +454,7 @@ def get_bprop_resize_bilinear(self):
     def bprop(x, out, dout):
         dx = resize_grad(dout, x)
         return (dx,)
     return bprop
+
@@ -437,6 +464,7 @@ def get_bprop_onehot(self):
     def bprop(indices, depth, on_value, off_value, out, dout):
         return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)
     return bprop
+
@@ -453,6 +481,7 @@ def get_bprop_top_kv2(self):
         updates = dout[0]
         shapes = shape_op(input_x)
         return scatter(indices, updates, shapes), zeros_like(k)
     return bprop
+
@@ -518,6 +547,7 @@ def get_bprop_lstm(self):
         dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
         dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
         return dx, dhx, dcx, dw
     return bprop
+
@@ -529,6 +559,7 @@ def get_bprop_sigmoid_crossentropy_with_logits(self):
     def bprop(x, y, out, dout):
         dx = op(x, y, dout)
         return (dx, zeros_like(y))
     return bprop
+
@@ -545,6 +576,7 @@ def get_bprop_pad(self):
         shp = shape_op(x)
         dx = P.Slice()(dout, begin, shp)
         return (dx,)
     return bprop
+
@@ -556,6 +588,7 @@ def get_bprop_mirror_pad(self):
     def bprop(x, paddings, out, dout):
         dx = mirror_pad_grad(dout, paddings, x)
         return (dx, zeros_like(paddings))
     return bprop
+
@@ -151,3 +151,5 @@ from .greater_equal import _greater_equal_tbe
 from .not_equal import _not_equal_tbe
 from .floor_mod import _floor_mod_tbe
 from .scatter_nd_update import _scatter_nd_update_tbe
+from .avg_pool import _avg_pool_tbe
+from .avg_pool_grad import _avg_pool_grad_tbe
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""AvgPool op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
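+# AvgPool binds to the TBE avg_pool kernel; per the dtype_format below it runs in float16 with 5HD (NC1HWC0) layout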
+avg_pool_op_info = TBERegOp("AvgPool") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("avg_pool.so") \
+    .compute_cost(10) \
+    .kernel_name("avg_pool") \
+    .partial_flag(True) \
+    .attr("ksize", "required", "listInt", "all") \
+    .attr("strides", "required", "listInt", "all") \
+    .attr("padding", "required", "str", "all") \
+    .attr("data_format", "optional", "str", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .get_op_info()
+
+
+@op_info_register(avg_pool_op_info)
+def _avg_pool_tbe():
+    """AvgPool TBE register"""
+    return
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""AvgPoolGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
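+# AvgPoolGrad binds to the const-input TBE kernel avg_pool_grad_d: the forward input's shape arrives as the
+# x_origin attribute, while mean_matrix/kernel_matrix are optional helper inputs of precomputed pooling coefficients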
+avg_pool_grad_op_info = TBERegOp("AvgPoolGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("avg_pool_grad_d.so") \
+    .compute_cost(10) \
+    .kernel_name("avg_pool_grad_d") \
+    .partial_flag(True) \
+    .attr("x_origin", "required", "listInt", "all") \
+    .attr("ksize", "required", "listInt", "all") \
+    .attr("strides", "required", "listInt", "all") \
+    .attr("padding", "required", "str", "all") \
+    .attr("data_format", "optional", "str", "all") \
+    .input(0, "input_grad", False, "required", "all") \
+    .input(1, "mean_matrix", False, "optional", "all") \
+    .input(2, "kernel_matrix", False, "optional", "all") \
+    .output(0, "out_grad", True, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD) \
+    .get_op_info()
+
+
+@op_info_register(avg_pool_grad_op_info)
+def _avg_pool_grad_tbe():
+    """AvgPoolGrad TBE register"""
+    return
@@ -57,7 +57,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      Gelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss,
                      LogSoftmax,
-                     MaxPool, ExtractImagePatches,
+                     MaxPool,
                      AvgPool, Conv2DBackpropInput, ConfusionMulGrad,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid,
@@ -89,7 +89,6 @@ __all__ = [
     'Sqrt',
     'Square',
     'Conv2D',
-    'ExtractImagePatches',
     'Flatten',
     'MaxPoolWithArgmax',
     'FusedBatchNorm',
@@ -59,6 +59,23 @@ class ACosGrad(PrimitiveWithInfer):
         return x


+class AcoshGrad(PrimitiveWithInfer):
+    """Performs grad of Acosh operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init AcoshGrad"""
+
+    def infer_shape(self, x, dout):
+        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
+        return x
+
+    def infer_dtype(self, x, dout):
+        args = {"x": x, "dout": dout}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x
+
+
 class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""
@@ -0,0 +1,98 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Inner operators."""
+
+from ..._checkparam import Validator as validator
+from ...common import dtype as mstype
+from ..primitive import PrimitiveWithInfer, prim_attr_register
+
+
+class ExtractImagePatches(PrimitiveWithInfer):
+    """
+    Extract patches from images.
+    The input tensor must be a 4-D tensor and the data format is NHWC.
+
+    Args:
+        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, a tuple or list of int
+            in the format [1, ksize_row, ksize_col, 1].
+        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
+            a tuple or list of int in the format [1, stride_row, stride_col, 1].
+        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between corresponding
+            pixel positions, a tuple or list of int in the format [1, rate_row, rate_col, 1].
+        padding (str): The padding algorithm, "same" or "valid" (case insensitive). Default: "valid".
+
+            - same: The patch may extend beyond the original image, and the overhang is filled with 0.
+            - valid: The patch area must be completely contained in the original image.
+
+    Inputs:
+        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
+          data type is int8, float16 or float32.
+
+    Outputs:
+        Tensor, a 4-D tensor with the same data type as `input_x`
+        and shape [out_batch, out_row, out_col, out_depth], where out_batch equals in_batch.
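+
+    Examples:
+        >>> # hypothetical usage sketch: 2x2 patches of a 1x4x4x1 NHWC image, stride 2, no dilation
+        >>> extract = ExtractImagePatches(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 1, 1, 1])
+        >>> output = extract(Tensor(np.ones([1, 4, 4, 1]), mindspore.float16))
+        >>> # output shape: (1, 2, 2, 4), since out_depth = ksize_row * ksize_col * in_depth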
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, ksizes, strides, rates, padding="valid"): | |||||
| """init""" | |||||
| def _check_tuple_or_list(arg_name, arg_val, prim_name): | |||||
| validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) | |||||
| if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: | |||||
| raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " | |||||
| f"{arg_name}_col, 1], but got {arg_val}.") | |||||
| if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: | |||||
| raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an " | |||||
| f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " | |||||
| f"is {arg_val[2]}") | |||||
| _check_tuple_or_list("ksize", ksizes, self.name) | |||||
| _check_tuple_or_list("stride", strides, self.name) | |||||
| _check_tuple_or_list("rate", rates, self.name) | |||||
| self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) | |||||
| self.add_prim_attr("padding", self.padding) | |||||
+    def infer_shape(self, input_x):
+        """infer shape"""
+        if len(input_x) != 4:
+            raise ValueError("The `input_x` should be a 4-D tensor, "
+                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
+        in_batch, in_row, in_col, in_depth = input_x
+        _, ksize_row, ksize_col, _ = self.ksizes
+        _, stride_row, stride_col, _ = self.strides
+        _, rate_row, rate_col, _ = self.rates
+        out_batch = in_batch
+        out_depth = ksize_row * ksize_col * in_depth
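+        # effective kernel extent with dilation: ksize + (ksize - 1) * (rate - 1)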
+        if self.padding == "VALID":
+            out_row = \
+                (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
+            out_col = \
+                (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
+        else:
+            out_row = (in_row - 1) // stride_row + 1
+            out_col = (in_col - 1) // stride_col + 1
+        out_shape = [out_batch, out_row, out_col, out_depth]
+        return out_shape
+
+    def infer_dtype(self, input_x):
+        """infer dtype"""
+        validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
+        return input_x
@@ -2654,82 +2654,6 @@ class ApplyFtrl(PrimitiveWithInfer):
         return var_type


-class ExtractImagePatches(PrimitiveWithInfer):
-    """
-    Extract patches from images.
-    The input tensor must be a 4-D tensor and the data format is NHWC.
-
-    Args:
-        ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
-            and the format is [1, ksize_row, ksize_col, 1].
-        strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
-            should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
-        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
-            pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
-        padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
-            not case sensitive. Default: "valid".
-
-            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
-            - valid: Means that the patch area taken must be completely contained in the original image.
-
-    Inputs:
-        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
-          data type is int8, float16, uint8.
-
-    Outputs:
-        Tensor, a 4-D tensor whose data type is same as 'input_x',
-        and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
-    """
-
-    @prim_attr_register
-    def __init__(self, ksizes, strides, rates, padding="valid"):
-        """init"""
-        def _check_tuple_or_list(arg_name, arg_val, prim_name):
-            validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
-            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
-                raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
-                                 f"{arg_name}_col, 1], but got {arg_val}.")
-            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
-                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
-                                 f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
-                                 f"is {arg_val[2]}")
-
-        _check_tuple_or_list("ksize", ksizes, self.name)
-        _check_tuple_or_list("stride", strides, self.name)
-        _check_tuple_or_list("rate", rates, self.name)
-        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
-        self.add_prim_attr("padding", self.padding)
-
-    def infer_shape(self, input_x):
-        in_batch, in_row, in_col, in_depth = input_x
-        _, ksize_row, ksize_col, _ = self.ksizes
-        _, stride_row, stride_col, _ = self.strides
-        _, rate_row, rate_col, _ = self.rates
-        if len(input_x) != 4:
-            raise ValueError("The `input_x` should be a 4-D tensor, "
-                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
-
-        out_batch = in_batch
-        out_depth = ksize_row * ksize_col * in_depth
-
-        if self.padding == "VALID":
-            out_row = \
-                (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
-            out_col = \
-                (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
-        else:
-            out_row = (in_row - 1) // stride_row + 1
-            out_col = (in_col - 1) // stride_col + 1
-
-        out_shape = [out_batch, out_row, out_col, out_depth]
-        return out_shape
-
-    def infer_dtype(self, input_x):
-        validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
-        return input_x
-
-
 class ConfusionMulGrad(PrimitiveWithInfer):
     """
     `output0` is the result of which input0 dot multily input1.
@@ -265,8 +265,8 @@ test_case_math_ops = [
         'desc_bprop': [[2, 3]]}),
     ('Acosh', {
         'block': P.Acosh(),
-        'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16))],
-        'skip': ['backward']}),
+        'desc_inputs': [[3, 4, 5]],
+        'desc_bprop': [[3, 4, 5]]}),
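+    # backward is exercised now that AcoshGrad provides the gradient kernel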
     ('Sin', {
         'block': P.Sin(),
         'desc_inputs': [[2, 3]],