| @@ -46,6 +46,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"reduce_max", "reduce_max_d"}, | |||
| {"reduce_min", "reduce_min_d"}, | |||
| {"avg_pool_grad", "avg_pool_grad_d"}, | |||
| {"avg_pool_grad_vm", "avg_pool_grad_d"}, | |||
| {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, | |||
| {"conv2d_backprop_input", "conv2d_backprop_input_d"}, | |||
| {"depthwise_conv2d_native", "depthwise_conv2d"}, | |||
| @@ -26,6 +26,7 @@ namespace opt { | |||
| ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(prim::kPrimCast->name(), {1}); | |||
| Register(prim::kPrimAvgPoolGrad->name(), {0}); | |||
| Register(prim::kPrimAvgPoolGradVm->name(), {0}); | |||
| Register(prim::kPrimConv2DBackpropInput->name(), {2}); | |||
| Register(prim::kPrimConv2DBackpropFilter->name(), {2}); | |||
| Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); | |||
| @@ -128,6 +128,7 @@ inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||
| @@ -33,7 +33,6 @@ from .activation import get_activation | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', | |||
| 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] | |||
| @@ -14,7 +14,10 @@ | |||
| # ============================================================================ | |||
| """Define the grad rules of neural network related operations.""" | |||
| import numpy as np | |||
| from mindspore.ops import _selected_grad_ops as SG | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.common.tensor import Tensor | |||
| from .grad_base import bprop_getters | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| @@ -24,7 +27,6 @@ from ..operations import _inner_ops as inner | |||
| from ... import context | |||
| @bprop_getters.register(P.BiasAdd) | |||
| def get_bprop_bias_add(self): | |||
| """Grad definition for `BiasAdd` operation.""" | |||
| @@ -195,33 +197,133 @@ def get_bprop_max_pool_grad(self): | |||
| return bprop | |||
| def _windowed_output_size(input_size, ksize, stride, padding): | |||
| """ | |||
| helper func for AvgPoolGrad | |||
| """ | |||
| tmp_output = 0 | |||
| tmp_pad_need = 0 | |||
| tmp_pad_before = 0 | |||
| tmp_pad_after = 0 | |||
| if padding == 'VALID': | |||
| tmp_output = (input_size - ksize + stride) // stride | |||
| tmp_pad_before = 0 | |||
| tmp_pad_after = 0 | |||
| elif padding == 'SAME': | |||
| tmp_output = (input_size + stride - 1) // stride | |||
| tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size) | |||
| tmp_pad_before = tmp_pad_need // 2 | |||
| tmp_pad_after = tmp_pad_need - tmp_pad_before | |||
| return tmp_output, tmp_pad_before, tmp_pad_after | |||
@constexpr
def _get_mean_matrix(x_shape, ksize, stride, padding, x_dtype):
    """
    Helper func for AvgPoolGrad: build the per-output-element averaging factors.

    `assist_input_matrix` is a 2d matrix with the input shape after padding;
    padded elements are 0 and all others are 1.
    Each output element `(h, w)` maps to the sliding window
    `[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize]` of
    `assist_input_matrix`, so the sum over that window is the number of real
    (non-padded) input elements associated with the output element.

    Args:
        x_shape (tuple): 4-element shape of the forward input, unpacked as (n, c, h, w).
        ksize (tuple): Window sizes; elements 2 and 3 are used as the h and w sizes.
        stride (tuple): Strides; elements 2 and 3 are used as the h and w strides.
        padding (str): Padding mode forwarded to `_windowed_output_size`.
        x_dtype: Dtype passed to `Tensor` for the result.

    Returns:
        Tensor of shape `(n, c, h_output, w_output)` whose element `[:, :, h, w]`
        is `1 / window_sum`, or 0 where the window sum is not positive.
    """
    n_input, c_input, h_input, w_input = x_shape
    h_ksize, w_ksize = ksize[2], ksize[3]
    h_stride, w_stride = stride[2], stride[3]

    h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize, h_stride, padding)
    w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize, w_stride, padding)

    # Allocate the result directly instead of building a Python list of
    # n * c * h_output * w_output floats and reshaping it.
    output = np.zeros((n_input, c_input, h_output, w_output))

    in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
    assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
    # The `> 0` guards matter for the trailing sides: `[-0:]` would select the whole axis.
    if pad_top > 0:
        assist_input_matrix[:pad_top, :] = 0
    if pad_bottom > 0:
        assist_input_matrix[-pad_bottom:, :] = 0
    if pad_left > 0:
        assist_input_matrix[:, :pad_left] = 0
    if pad_right > 0:
        assist_input_matrix[:, -pad_right:] = 0

    for h in range(h_output):
        h_start = h * h_stride
        for w in range(w_output):
            w_start = w * w_stride
            curr_input = assist_input_matrix[h_start:h_start + h_ksize, w_start:w_start + w_ksize]
            curr_sum = np.sum(curr_input)
            if curr_sum > 0:
                # The factor is identical for every batch item and channel.
                output[:, :, h, w] = 1. / curr_sum
    return Tensor(output, x_dtype)
@constexpr
def _get_kernel_matrix(kernel_matrix_shape, x_dtype):
    """
    Helper func for AvgPoolGrad: build an all-ones kernel matrix.

    Args:
        kernel_matrix_shape (tuple): Shape of the matrix to create.
        x_dtype: Dtype passed to `Tensor` for the result.

    Returns:
        Tensor of shape `kernel_matrix_shape` filled with ones.
    """
    return Tensor(np.ones(kernel_matrix_shape), x_dtype)
@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """
    Grad definition for `AvgPool` operation.

    The grad operator is chosen by `self.target`, because its inputs differ per target:
    "GPU" uses `G.AvgPoolGradGpu(x, out, dout)`, "GE" uses
    `G.AvgPoolGrad(shape(x), dout)`, and any other target uses
    `G.AvgPoolGradVm(shape(x), dout, mean_matrix, kernel_matrix)`.
    Only the operator for the active target is constructed.

    Returns:
        function, the bprop taking `(x, out, dout)` and returning the 1-tuple `(dx,)`.
    """
    if self.target == "GPU":
        avgpool_grad_gpu = G.AvgPoolGradGpu(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)

        def bprop_gpu(x, out, dout):
            dx = avgpool_grad_gpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_gpu

    elif self.target == "GE":
        avgpool_grad_ge = G.AvgPoolGrad(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)
        shape_op = P.Shape()

        def bprop_ge(x, out, dout):
            # This operator receives the forward input's shape, not the input itself.
            dx = avgpool_grad_ge(shape_op(x), dout)
            return (dx,)

        bprop_fn = bprop_ge

    else:
        avgpool_grad_vm = G.AvgPoolGradVm(
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding)
        # Read ksize/strides back from the primitive rather than from `self`;
        # presumably the primitive stores them in 4-element NCHW form, since
        # elements 2 and 3 are used as the h/w values below -- TODO confirm.
        k_size_nchw = avgpool_grad_vm.ksize
        stride_nchw = avgpool_grad_vm.strides
        padding = self.padding

        def bprop_vm(x, out, dout):
            x_shape_nchw = F.shape(x)
            x_dtype = F.dtype(x)
            kernel_matrix_shape = (1, x_shape_nchw[1],
                                   k_size_nchw[2],
                                   k_size_nchw[3])
            # Both helper tensors are produced by @constexpr helpers from the shape/dtype of x.
            mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, padding, x_dtype)
            kernel_matrix = _get_kernel_matrix(kernel_matrix_shape, x_dtype)
            dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
            return (dx,)

        bprop_fn = bprop_vm

    return bprop_fn
| @@ -196,6 +196,7 @@ from .floor_mod import _floor_mod_tbe | |||
| from .scatter_nd_update import _scatter_nd_update_tbe | |||
| from .avg_pool import _avg_pool_tbe | |||
| from .avg_pool_grad import _avg_pool_grad_tbe | |||
| from .avg_pool_grad_vm import _avg_pool_grad_vm_tbe | |||
| from .ones_like import _ones_like_tbe | |||
| from .batch_to_space import _batch_to_space_tbe | |||
| from .space_to_batch import _space_to_batch_tbe | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """AvgPoolGradVm op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# TBE op info for the "AvgPoolGradVm" primitive; it is bound to the
# "avg_pool_grad_d" kernel in "avg_pool_grad_d.so".
avg_pool_grad_vm_op_info = (
    TBERegOp("AvgPoolGradVm")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("avg_pool_grad_d.so")
    .compute_cost(10)
    .kernel_name("avg_pool_grad_d")
    .partial_flag(True)
    # Attributes.
    .attr("x_origin", "required", "listInt", "all")
    .attr("ksize", "required", "listInt", "all")
    .attr("strides", "required", "listInt", "all")
    .attr("padding", "required", "str", "all")
    .attr("data_format", "optional", "str", "all")
    # Tensor inputs and output.
    .input(0, "input_grad", False, "required", "all")
    .input(1, "mean_matrix", False, "optional", "all")
    .input(2, "kernel_matrix", False, "optional", "all")
    .output(0, "out_grad", True, "required", "all")
    # Four formats for three inputs plus one output; presumably listed in that
    # order (input_grad, mean_matrix, kernel_matrix, out_grad) -- TODO confirm.
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_C1HWNCoC0, DataType.F16_5HD)
    .get_op_info()
)
@op_info_register(avg_pool_grad_vm_op_info)
def _avg_pool_grad_vm_tbe():
    """AvgPoolGradVm TBE register; the body does nothing, the decorator receives the op info."""
    return None
| @@ -23,7 +23,6 @@ from .._utils import get_concat_offset | |||
| from ...common import dtype as mstype | |||
| from .. import functional as F | |||
| class AbsGrad(PrimitiveWithInfer): | |||
| """Computes gradients for abs operation.""" | |||
| @@ -492,7 +491,7 @@ class _PoolGrad(PrimitiveWithInfer): | |||
| class AvgPoolGrad(_PoolGrad): | |||
| """Gradients of the avg pool operation.""" | |||
| """Gradients of the avg pool operation for ge.""" | |||
| @prim_attr_register | |||
| def __init__(self, ksize=1, strides=1, padding="VALID"): | |||
| @@ -508,6 +507,24 @@ class AvgPoolGrad(_PoolGrad): | |||
| return out | |||
class AvgPoolGradVm(_PoolGrad):
    """
    Gradients of the avg pool operation for vm.

    Inputs are named `x_origin`, `grad`, `mean_matrix` and `kernel_matrix`;
    the single output is named `output`.
    """

    @prim_attr_register
    def __init__(self, ksize=1, strides=1, padding="VALID"):
        super(AvgPoolGradVm, self).__init__(ksize, strides, padding)
        self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])

    def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
        """
        Infer the output: its shape is the value carried by `origin_input`, its dtype is that of `dout`.

        `mean_matrix` and `kernel_matrix` are accepted but not used for inference.
        """
        # `origin_input['value']` must be iterable here; tuple() fails if it is None.
        inferred_shape = tuple(origin_input['value'])
        return {'value': None, 'shape': inferred_shape, 'dtype': dout['dtype']}
| class AvgPoolGradGpu(_PoolGrad): | |||
| """Gradients of the avg pool operation for gpu.""" | |||
| @@ -1276,6 +1276,8 @@ class AvgPool(_Pool): | |||
| def __init__(self, ksize=1, strides=1, padding="valid"): | |||
| if context.get_context("device_target") == "GPU": | |||
| self.target = "GPU" | |||
| elif context.get_context("enable_ge"): | |||
| self.target = "GE" | |||
| else: | |||
| self.target = "OTHER" | |||
| super(AvgPool, self).__init__(ksize, strides, padding) | |||
| @@ -1311,13 +1311,6 @@ test_case_nn_ops = [ | |||
| 'block': P.AvgPool(ksize=(2, 2), strides=(2, 2), padding="VALID"), | |||
| 'desc_inputs': [[100, 3, 28, 28]], | |||
| 'desc_bprop': [[100, 3, 14, 14]]}), | |||
| ('AvgPoolGrad', { | |||
| 'block': G.AvgPoolGrad(ksize=(2, 2), strides=(2, 2), padding="VALID"), | |||
| 'desc_const': [(3, 4, 6, 6)], | |||
| 'const_first': True, | |||
| 'desc_inputs': [[3, 4, 6, 6]], | |||
| 'desc_bprop': [[3, 4, 6, 6]], | |||
| 'skip': ['backward']}), | |||
| ('MaxPoolWithArgmax', { | |||
| 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), | |||
| 'desc_inputs': [[128, 32, 32, 64]], | |||