@@ -64,6 +64,7 @@
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h"
#include "backend/optimizer/pass/erase_visit_attr.h"
@@ -231,6 +232,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto mixed_precision_pm = std::make_shared<PassManager>("cast_pm");
  mixed_precision_pm->AddPass(std::make_shared<InsertCast>());
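  // Runs after InsertCast so the Reshape inserted in front of ExtractImagePatches
  // picks up the input's final device data type.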
  mixed_precision_pm->AddPass(std::make_shared<InsertReshapeForExtractImagePatchesOp>());
  mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
  mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
  mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
@@ -0,0 +1,65 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/format_type/insert_reshape_for_extract_image_patches_op.h"
#include <memory>
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "base/core_ops.h"

namespace mindspore {
namespace opt {
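// Match any ExtractImagePatches node; Xs captures its variadic inputs.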
const BaseRef InsertReshapeForExtractImagePatchesOp::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimExtractImagePatches, Xs});
}

const AnfNodePtr InsertReshapeForExtractImagePatchesOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                                const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(equiv);
  auto extract = CheckAnfNodeIfCNodeAndInputSize(node, 2);
  MS_EXCEPTION_IF_NULL(extract);
  auto in_node = extract->input(1);
  MS_EXCEPTION_IF_NULL(in_node);
  auto extract_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(extract);
  auto in_node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(in_node);
  MS_EXCEPTION_IF_NULL(extract_kernel_build_info);
  MS_EXCEPTION_IF_NULL(in_node_kernel_build_info);
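  // Build a Reshape between the input and ExtractImagePatches. It keeps the
  // input's NC1HWC0 device format, device dtype, and kernel attributes, so only
  // the host-side inferred shape changes.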
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            in_node};
  auto reshape_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  reshape_builder->SetInputsFormat({kOpFormat_NC1HWC0});
  reshape_builder->SetOutputsFormat({kOpFormat_NC1HWC0});
  reshape_builder->SetInputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
  reshape_builder->SetOutputsDeviceType({AnfAlgo::GetOutputDeviceDataType(in_node, 0)});
  reshape_builder->SetKernelType(in_node_kernel_build_info->kernel_type());
  reshape_builder->SetFusionType(in_node_kernel_build_info->fusion_type());
  reshape_builder->SetProcessor(in_node_kernel_build_info->processor());
  auto reshape = func_graph->NewCNode(reshape_inputs);
  reshape->set_scope(in_node->scope());
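  // Permute the inferred shape from NCHW (N, C, H, W) to NHWC (N, H, W, C):
  // ExtractImagePatches is registered with io_format "NHWC".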
  auto shape_tmp = AnfAlgo::GetOutputInferShape(in_node, 0);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputDeviceDataType(in_node, 0)},
                                      {{shape_tmp[0], shape_tmp[2], shape_tmp[3], shape_tmp[1]}}, reshape.get());
  AnfAlgo::SetSelectKernelBuildInfo(reshape_builder->Build(), reshape.get());
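  // Mark the Reshape as a no-op kernel: the NC1HWC0 device tensor is untouched,
  // only the shape annotation differs.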
  AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reshape);
  AnfAlgo::SetNodeInput(extract, reshape, 0);
  return extract;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,41 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
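// Pattern pass that inserts a Reshape in front of ExtractImagePatches so the op
// sees an NHWC inferred shape while the surrounding graph stays NCHW.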
class InsertReshapeForExtractImagePatchesOp : public PatternProcessPass {
 public:
  explicit InsertReshapeForExtractImagePatchesOp(bool multigraph = true)
      : PatternProcessPass("insert_reshape_for_extract_image_patches_op", multigraph) {}
  ~InsertReshapeForExtractImagePatchesOp() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_INSERT_RESHAPE_FOR_EXTRACT_IMAGE_PATCHES_OP_H
@@ -516,6 +516,10 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
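  // The input of ExtractImagePatches carries the NHWC inferred shape set by
  // InsertReshapeForExtractImagePatchesOp, so permute it back to NCHW before
  // translating to the device shape.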
  if (node->isa<CNode>() && GetCNodeName(node) == kExtractImagePatchesOpName) {
    std::vector<size_t> shape_tmp = {infer_shape[0], infer_shape[3], infer_shape[1], infer_shape[2]};
    return trans::TransShapeToDevice(shape_tmp, format);
  }
  return trans::TransShapeToDevice(infer_shape, format);
}
@@ -104,6 +104,7 @@ inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");

// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -542,12 +542,16 @@ class Unfold(Cell):
        self.transpose = P.Transpose()
        self.format_NHWC = (0, 2, 3, 1)
        self.format_NCHW = (0, 3, 1, 2)
        self.is_ge = context.get_context("enable_ge")

    def construct(self, input_x):
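        # Under GE the kernel keeps its original NHWC layout, so transpose the
        # input to NHWC and the result back to NCHW; otherwise the op consumes
        # NCHW directly.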
        if self.is_ge:
            x_transpose = self.transpose(input_x, self.format_NHWC)
            ret = self.extract_image_patches(x_transpose)
            result = self.transpose(ret, self.format_NCHW)
        else:
            result = self.extract_image_patches(input_x)
        return result

@constexpr
@@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ...common import dtype as mstype
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
@@ -75,11 +76,43 @@ def get_bprop_extract_image_patches(self):
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()
    _, ksizes_row, ksizes_col, _ = self.ksizes

    def bprop(x, out, dout):
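        # Build a one-hot (x_row * x_col, out_indices_num) matrix recording which
        # output patch positions each input pixel lands in (pixel indices start at
        # 1 so row 0 marks padding and is sliced away), then route dout back
        # through it with a single matmul.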
        x_shape = get_shape(x)
        x_batch, x_depth, x_row, x_col = x_shape
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))

        out_shape = get_shape(out)
        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))

        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))

        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)
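
    # Under GE, ExtractImagePatches keeps its NHWC layout, so the shape is
    # unpacked as NHWC; the rest follows the same construction as bprop above.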
    def bprop_ge(x, out, dout):
        x_shape = get_shape(x)
        x_batch, x_row, x_col, x_depth = x_shape
        x_indices_num = x_row * x_col + 1
@@ -109,6 +142,9 @@ def get_bprop_extract_image_patches(self):
        return (dx,)

    if context.get_context("enable_ge"):
        return bprop_ge
    return bprop
@@ -17,6 +17,7 @@
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
@@ -200,10 +201,13 @@ class ExtractImagePatches(PrimitiveWithInfer):
        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
        self.add_prim_attr("padding", self.padding)
        self.add_prim_attr("io_format", "NHWC")
        self.is_ge = context.get_context("enable_ge")

    def infer_shape(self, input_x):
        """infer shape"""
        in_batch, in_depth, in_row, in_col = input_x
        if self.is_ge:
            in_batch, in_row, in_col, in_depth = input_x
        _, ksize_row, ksize_col, _ = self.ksizes
        _, stride_row, stride_col, _ = self.strides
        _, rate_row, rate_col, _ = self.rates
@@ -223,7 +227,9 @@ class ExtractImagePatches(PrimitiveWithInfer):
            out_row = (in_row - 1) // stride_row + 1
            out_col = (in_col - 1) // stride_col + 1
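        # Output layout mirrors the input: NCHW by default, NHWC under GE.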
        out_shape = [out_batch, out_depth, out_row, out_col]
        if self.is_ge:
            out_shape = [out_batch, out_row, out_col, out_depth]
        return out_shape

    def infer_dtype(self, input_x):