diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index bc9491e4cf..dc58bb6653 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -64,6 +64,7 @@
 #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
 #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
+#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
 #include "backend/optimizer/pass/optimize_dependence.h"
 #include "backend/optimizer/pass/erase_visit_attr.h"
@@ -231,6 +232,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
   mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
+  mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
   mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
   mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
   mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
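Reviewer note: the new pass rewrites every matched `ExtractImagePatches(x)` into `ExtractImagePatches(Reshape(x))`, where the inserted `Reshape` only relabels the input's NCHW infer shape as NHWC and is tagged `nop_op`, so no real data movement is generated. A toy sketch of that rewrite, using a hypothetical dict-based mini-IR rather than the real `AnfNode`/`CNode` API:

```python
# Hypothetical mini-IR: nodes are plain dicts, not MindSpore AnfNodes. This
# only illustrates the rewrite done by InsertReshapeForExtractImagePatchesOp.
def process(node):
    if node["op"] != "ExtractImagePatches":
        return node
    x = node["input"]
    n, c, h, w = x["infer_shape"]                  # non-GE infer shapes are NCHW
    reshape = {"op": "Reshape", "input": x,
               "infer_shape": (n, h, w, c),        # same (0, 2, 3, 1) relabel as Process()
               "attrs": {"nop_op": True}}          # no device data movement
    return {**node, "input": reshape}

x = {"op": "Parameter", "infer_shape": (32, 16, 14, 14)}
rewritten = process({"op": "ExtractImagePatches", "input": x})
print(rewritten["input"]["infer_shape"])  # (32, 14, 14, 16)
```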
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc
new file mode 100644
index 0000000000..2cfb9cfb3e
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
+#include <memory>
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "base/core_ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
+  VarPtr Xs = std::make_shared<Var>();
+  return VectorRef({prim::kPrimExtractImagePatches, Xs});
+}
+
+const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                                const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
+  MS_EXCEPTION_IF_NULL(extract);
+  auto in_node = extract->input(1);
+  MS_EXCEPTION_IF_NULL(in_node);
+  auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
+  auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
+  MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
+  MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
+  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
+                                            in_node};
+  auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
+  reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
+  reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
+  reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
+  reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
+  reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
+  reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
+
+  auto reshape = func_graph->NewCNode(reshape_inputs);
+  reshape->set_scope(in_node->scope());
+  auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
+                                      {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
+  AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
+  AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
+  AnfAlgo::SetNodeInput(extract, reshape, 0);
+  return extract;
+}
+}  // namespace opt
+}  // namespace mindspore
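The permutation used above, `(0, 2, 3, 1)` applied to the input's NCHW infer shape, is the inverse of the `(0, 3, 1, 2)` permutation added to `GetInputDeviceShape` below, so the device shape computed for the ExtractImagePatches input round-trips back to NCHW. A quick self-contained check:

```python
# Process() relabels NCHW -> NHWC; GetInputDeviceShape() undoes it before
# calling trans::TransShapeToDevice. Verify the two index maps are inverses.
nchw = [32, 16, 14, 14]
nhwc = [nchw[i] for i in (0, 2, 3, 1)]   # shape written onto the nop Reshape
back = [nhwc[i] for i in (0, 3, 1, 2)]   # recovered inside GetInputDeviceShape
assert nhwc == [32, 14, 14, 16]
assert back == nchw
```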
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h
new file mode 100644
index 0000000000..1344d7f2bc
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
+ public:
+  explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
+      : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
+  ~InsertReshapeForExtractImagePatchesOp() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 47b82d1435..9c5ffc0e58 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -516,6 +516,10 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
   if (trans::IsNeedPadding(format, infer_shape.size())) {
     infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
   }
+  if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
+    std::vector<size_t> shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
+    return trans::TransShapeToDevice(shape_tmp, format);
+  }
   return trans::TransShapeToDevice(infer_shape, format);
 }
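For context (not part of this diff): with `kOpFormat_NC1HWC0`, `trans::TransShapeToDevice` tiles the channel dimension into C1 groups of C0 channels, which is why the NHWC infer shape has to be permuted back to NCHW first. A rough sketch of that device-shape computation, assuming the usual Ascend constant C0 = 16:

```python
import math

def nc1hwc0_device_shape(nchw, c0=16):
    # NC1HWC0 splits C into ceil(C / C0) groups of C0 channels (tail padded).
    # Conceptual mirror of trans::TransShapeToDevice for this one format only.
    n, c, h, w = nchw
    return [n, math.ceil(c / c0), h, w, c0]

# The ExtractImagePatches input relabelled as NHWC (32, 14, 14, 16) is first
# permuted back to NCHW (32, 16, 14, 14) by the hunk above, then converted:
print(nc1hwc0_device_shape([32, 16, 14, 14]))  # [32, 1, 14, 14, 16]
```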
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 58f15c79b4..b9225b62a1 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -104,6 +104,7 @@ inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
 inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
 inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
+inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
 
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index d1ed0ec69f..bb8fc2e02c 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -542,12 +542,16 @@ class Unfold(Cell):
         self.transpose = P.Transpose()
         self.format_NHWC = (0, 2, 3, 1)
         self.format_NCHW = (0, 3, 1, 2)
+        self.is_ge = context.get_context("enable_ge")
 
     def construct(self, input_x):
-        x_transpose = self.transpose(input_x, self.format_NHWC)
-        ret = self.extract_image_patches(x_transpose)
-        ret_transpose = self.transpose(ret, self.format_NCHW)
-        return ret_transpose
+        if self.is_ge:
+            x_transpose = self.transpose(input_x, self.format_NHWC)
+            ret = self.extract_image_patches(x_transpose)
+            result = self.transpose(ret, self.format_NCHW)
+        else:
+            result = self.extract_image_patches(input_x)
+        return result
 
 
 @constexpr
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 246e4a6363..f1a4937dc5 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
 from .grad_base import bprop_getters
 from .. import functional as F
 from .. import operations as P
+from ...common import dtype as mstype
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations import _grad_ops as G
 from ..operations import _inner_ops as inner
@@ -75,11 +76,43 @@ def get_bprop_extract_image_patches(self):
     fill = P.Fill()
     slice_op = P.Slice()
     transpose = P.Transpose()
+    cast = P.Cast()
     matmul = P.MatMul()
     _, ksizes_row, ksizes_col, _ = self.ksizes
 
     def bprop(x, out, dout):
+        x_shape = get_shape(x)
+        x_batch, x_depth, x_row, x_col = x_shape
+        x_indices_num = x_row * x_col + 1
+        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
+        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
+        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
+        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
+
+        out_shape = get_shape(out)
+        _, _, out_row, out_col = out_shape
+        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
+        out_idx = F.tuple_to_array(range(out_indices_num))
+        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
+
+        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
+        idx_tensor = reshape(idx_tensor, (-1, 2))
+        sp_shape = (x_indices_num, out_indices_num)
+        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
+        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
+
+        grad = transpose(dout, (0, 2, 3, 1))
+        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
+        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
+        grad = reshape(grad, (-1, x_batch * x_depth))
+
+        jac = matmul(sp_tensor, grad)
+        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
+        dx = transpose(dx, (2, 3, 0, 1))
+        return (dx,)
+
+    def bprop_ge(x, out, dout):
         x_shape = get_shape(x)
         x_batch, x_row, x_col, x_depth = x_shape
         x_indices_num = x_row * x_col + 1
         x_idx = F.tuple_to_array(range(1, x_indices_num))
@@ -109,6 +142,9 @@ def get_bprop_extract_image_patches(self):
 
         return (dx,)
 
+    if context.get_context("enable_ge"):
+        return bprop_ge
+
     return bprop
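With the new non-GE `bprop` taking NCHW in and out, `nn.Unfold` no longer needs the NHWC round-trip transposes outside GE. A quick usage check of the shapes expected under the new `infer_shape` (backend availability and dtype support aside):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Non-GE path: ExtractImagePatches consumes and produces NCHW directly;
# output depth = C * ksize_row * ksize_col, VALID padding by default.
unfold = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 1, 1, 1])
x = Tensor(np.ones([1, 1, 4, 4], np.float16))
print(unfold(x).shape)  # expected (1, 4, 2, 2): out_row = out_col = 2
```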
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 9d240beff4..71bd64c94b 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -17,6 +17,7 @@
 
 from ..._checkparam import Rel
 from ..._checkparam import Validator as validator
+from ... import context
 from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ..operations.math_ops import _infer_shape_reduce
@@ -200,10 +201,13 @@ class ExtractImagePatches(PrimitiveWithInfer):
         self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
         self.add_prim_attr("padding", self.padding)
         self.add_prim_attr("io_format", "NHWC")
+        self.is_ge = context.get_context("enable_ge")
 
     def infer_shape(self, input_x):
         """infer shape"""
-        in_batch, in_row, in_col, in_depth = input_x
+        in_batch, in_depth, in_row, in_col = input_x
+        if self.is_ge:
+            in_batch, in_row, in_col, in_depth = input_x
         _, ksize_row, ksize_col, _ = self.ksizes
         _, stride_row, stride_col, _ = self.strides
         _, rate_row, rate_col, _ = self.rates
@@ -223,7 +227,9 @@ class ExtractImagePatches(PrimitiveWithInfer):
         out_row = (in_row - 1) // stride_row + 1
         out_col = (in_col - 1) // stride_col + 1
 
-        out_shape = [out_batch, out_row, out_col, out_depth]
+        out_shape = [out_batch, out_depth, out_row, out_col]
+        if self.is_ge:
+            out_shape = [out_batch, out_row, out_col, out_depth]
         return out_shape
 
     def infer_dtype(self, input_x):
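To summarize the two layouts handled by `infer_shape`: the GE path keeps TensorFlow-style NHWC in and out, while the default path is NCHW in and out. A small Python mirror of the logic; the VALID branch is not visible in the hunk above, so its formula (effective kernel size under dilation) is my reconstruction of the usual computation:

```python
def extract_image_patches_shape(x_shape, ksizes, strides, rates,
                                padding="VALID", is_ge=False):
    # Mirrors ExtractImagePatches.infer_shape. The SAME branch matches the
    # hunk above; the VALID branch is an assumed standard formula.
    n, h, w, c = x_shape if is_ge else (x_shape[0], x_shape[2], x_shape[3], x_shape[1])
    _, k_r, k_c, _ = ksizes
    _, s_r, s_c, _ = strides
    _, r_r, r_c, _ = rates
    out_depth = c * k_r * k_c
    if padding == "VALID":
        out_r = (h - (k_r + (k_r - 1) * (r_r - 1))) // s_r + 1
        out_c = (w - (k_c + (k_c - 1) * (r_c - 1))) // s_c + 1
    else:  # SAME
        out_r = (h - 1) // s_r + 1
        out_c = (w - 1) // s_c + 1
    return [n, out_r, out_c, out_depth] if is_ge else [n, out_depth, out_r, out_c]

print(extract_image_patches_shape([1, 1, 4, 4], [1, 2, 2, 1],
                                  [1, 2, 2, 1], [1, 1, 1, 1]))  # [1, 4, 2, 2]
```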