@@ -0,0 +1,83 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/mindir/space_batch_nd_attr_update.h"
+
+#include <memory>
+#include <vector>
+
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "base/core_ops.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBlockShapeDimNum = 2;
+constexpr auto kAttrBlockShape = "block_shape";
+constexpr auto kAttrPaddings = "paddings";
+constexpr auto kAttrCrops = "crops";
+}  // namespace
+
+const BaseRef SpaceToBatchNDAttrUpdate::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimSpaceToBatchND, X});
+  return pattern;
+}
+
+const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto block_shape = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
+  if (block_shape.size() == kBlockShapeDimNum) {
+    block_shape.insert(block_shape.begin(), 1);
+    AnfAlgo::SetNodeAttr(kAttrBlockShape, MakeValue(block_shape), node);
+  }
+  auto paddings = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(node, kAttrPaddings);
+  if (paddings.size() == kBlockShapeDimNum) {
+    paddings.emplace(paddings.begin(), std::vector<int64_t>{0, 0});
+    AnfAlgo::SetNodeAttr(kAttrPaddings, MakeValue(paddings), node);
+  }
+  return node;
+}
+
+const BaseRef BatchToSpaceNDAttrUpdate::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimBatchToSpaceND, X});
+  return pattern;
+}
+
+const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto block_shape = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
+  if (block_shape.size() == kBlockShapeDimNum) {
+    block_shape.insert(block_shape.begin(), 1);
+    AnfAlgo::SetNodeAttr(kAttrBlockShape, MakeValue(block_shape), node);
+  }
+  auto crops = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(node, kAttrCrops);
+  if (crops.size() == kBlockShapeDimNum) {
+    crops.emplace(crops.begin(), std::vector<int64_t>{0, 0});
+    AnfAlgo::SetNodeAttr(kAttrCrops, MakeValue(crops), node);
+  }
+  return node;
+}
+}  // namespace opt
+}  // namespace mindspore
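The two passes above restore, at the MindIR level, the leading block/padding entries that the Python frontend no longer prepends (see the ops.py hunks below). A minimal Python sketch of the rewrite they perform, with a plain dict standing in for the node's attribute table:

    # Sketch only: models the attribute rewrite done by the C++ pass above.
    def update_space_to_batch_nd_attrs(attrs):
        block_shape = list(attrs["block_shape"])
        if len(block_shape) == 2:                     # kBlockShapeDimNum
            attrs["block_shape"] = [1] + block_shape  # prepend batch block of 1
        paddings = [list(p) for p in attrs["paddings"]]
        if len(paddings) == 2:
            attrs["paddings"] = [[0, 0]] + paddings   # no padding on the new dim
        return attrs

    attrs = {"block_shape": [2, 2], "paddings": [[1, 1], [1, 1]]}
    print(update_space_to_batch_nd_attrs(attrs))
    # {'block_shape': [1, 2, 2], 'paddings': [[0, 0], [1, 1], [1, 1]]}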
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SpaceToBatchNDAttrUpdate : public PatternProcessPass {
+ public:
+  explicit SpaceToBatchNDAttrUpdate(bool multigraph = true)
+      : PatternProcessPass("space_to_batch_nd_attr_update", multigraph) {}
+  ~SpaceToBatchNDAttrUpdate() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+
+class BatchToSpaceNDAttrUpdate : public PatternProcessPass {
+ public:
+  explicit BatchToSpaceNDAttrUpdate(bool multigraph = true)
+      : PatternProcessPass("batch_to_space_nd_attr_update", multigraph) {}
+  ~BatchToSpaceNDAttrUpdate() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
@@ -33,6 +33,7 @@
 #include "runtime/device/ascend/profiling/profiling_manager.h"
 #include "backend/optimizer/ascend/ascend_backend_optimization.h"
 #include "backend/optimizer/common/common_backend_optimization.h"
+#include "backend/optimizer/ascend/mindir/space_batch_nd_attr_update.h"
 #include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
 #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"
 #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
@@ -209,6 +210,8 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
   }
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
+  unify_mindir_pm->AddPass(std::make_shared<opt::SpaceToBatchNDAttrUpdate>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::BatchToSpaceNDAttrUpdate>());
   unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
   unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
@@ -129,6 +129,8 @@ inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform");
 inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
 inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
 inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
+inline const PrimitivePtr kPrimSpaceToBatchND = std::make_shared<Primitive>("SpaceToBatchND");
+inline const PrimitivePtr kPrimBatchToSpaceND = std::make_shared<Primitive>("BatchToSpaceND");
 
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -940,7 +940,7 @@ def get_bprop_batch_to_space(self):
 @bprop_getters.register(P.SpaceToBatchND)
 def get_bprop_space_to_batch_nd(self):
     """Generate bprop for SpaceToBatchND"""
-    space_to_batch_nd_grad = P.BatchToSpaceND(self.ori_block_shape, self.ori_paddings)
+    space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
     def bprop(x, out, dout):
         dx = space_to_batch_nd_grad(dout)
         return (dx,)
@@ -950,7 +950,7 @@ def get_bprop_space_to_batch_nd(self):
 @bprop_getters.register(P.BatchToSpaceND)
 def get_bprop_batch_to_space_nd(self):
     """Generate bprop for BatchToSpaceND"""
-    batch_to_space_nd_grad = P.SpaceToBatchND(self.ori_block_shape, self.ori_crops)
+    batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
     def bprop(x, out, dout):
         dx = batch_to_space_nd_grad(dout)
         return (dx,)
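These bprop fixes work together with the `__init__` changes below: `self.block_shape`, `self.paddings`, and `self.crops` now keep their user-facing 2-D form, so they can be handed directly to the inverse op. As a sketch of why the inverse op is the right gradient, here is a NumPy round trip, assuming NCHW layout and the usual reshape/transpose definition of the two ops (the exact batch ordering used by MindSpore's kernels may differ):

    import numpy as np

    def space_to_batch_nd(x, block, paddings):
        # x: (N, C, H, W); pad the spatial dims, then move block elements into batch.
        n, c, _, _ = x.shape
        x = np.pad(x, ((0, 0), (0, 0), tuple(paddings[0]), tuple(paddings[1])))
        h, w = x.shape[2], x.shape[3]
        bh, bw = block
        x = x.reshape(n, c, h // bh, bh, w // bw, bw)
        x = x.transpose(3, 5, 0, 1, 2, 4)  # (bh, bw, N, C, H/bh, W/bw)
        return x.reshape(n * bh * bw, c, h // bh, w // bw)

    def batch_to_space_nd(y, block, crops):
        # Inverse of the above.
        bh, bw = block
        nb, c, h, w = y.shape
        n = nb // (bh * bw)
        y = y.reshape(bh, bw, n, c, h, w).transpose(2, 3, 4, 0, 5, 1)
        y = y.reshape(n, c, h * bh, w * bw)
        return y[:, :, crops[0][0]:y.shape[2] - crops[0][1],
                 crops[1][0]:y.shape[3] - crops[1][1]]

    x = np.arange(16.0).reshape(1, 1, 4, 4)
    y = space_to_batch_nd(x, (2, 2), [[0, 0], [0, 0]])
    assert np.array_equal(batch_to_space_nd(y, (2, 2), [[0, 0], [0, 0]]), x)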
@@ -992,7 +992,7 @@ class Split(PrimitiveWithCheck):
         if output_valid_check != 0:
             raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
                              f" output_num {self.output_num}")
-        size_splits = [x_shape[self.axis] / self.output_num] * self.output_num
+        size_splits = [x_shape[self.axis] // self.output_num] * self.output_num
         self.add_prim_attr('size_splits', size_splits)
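The `//` fix is needed because Python 3's `/` is true division, so `size_splits` would otherwise be registered as floats:

    dim, output_num = 4, 2
    print([dim / output_num] * output_num)   # [2.0, 2.0] -- floats, wrong attr type
    print([dim // output_num] * output_num)  # [2, 2]     -- ints, as the attr expects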
@@ -3893,6 +3893,8 @@ class SpaceToBatch(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_size, paddings):
         """Initialize SpaceToBatch"""
+        logger.warning("WARN_DEPRECATED: The usage of SpaceToBatch is deprecated."
+                       " Please use SpaceToBatchND.")
         validator.check_value_type('block_size', block_size, [int], self.name)
         validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
         self.block_size = block_size
@@ -3969,6 +3971,8 @@ class BatchToSpace(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_size, crops):
         """Initialize BatchToSpace"""
+        logger.warning("WARN_DEPRECATED: The usage of BatchToSpace is deprecated."
+                       " Please use BatchToSpaceND.")
         validator.check_value_type('block_size', block_size, [int], self.name)
         validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
         self.block_size = block_size
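With the warnings in place, migrating user code is mechanical: the scalar `block_size` of the deprecated ops maps directly onto the new int form of `block_shape` accepted below. A hypothetical before/after:

    from mindspore.ops import operations as P

    # Deprecated: emits WARN_DEPRECATED at construction time.
    op_old = P.SpaceToBatch(block_size=2, paddings=[[0, 0], [0, 0]])
    # Preferred: SpaceToBatchND now accepts the same scalar block size.
    op_new = P.SpaceToBatchND(block_shape=2, paddings=[[0, 0], [0, 0]])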
@@ -4009,8 +4013,10 @@ class SpaceToBatchND(PrimitiveWithInfer):
     the spatial dimensions of the input are zero padded according to paddings if necessary.
 
     Args:
-        block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value greater than 1.
-            The length of `block_shape` is M corresponding to the number of spatial dimensions. M must be 2.
+        block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block, with all values
+            greater than 1. If `block_shape` is a list or tuple, its length M corresponds to the number of
+            spatial dimensions. If `block_shape` is an int, the block sizes of all M dimensions are the same
+            and equal to `block_shape`. M must be 2.
         paddings (Union[tuple, list]): The padding values for H and W dimension, containing 2 subtraction list.
             Each contains 2 integer value. All values must be greater than 0.
             `paddings[i]` specifies the paddings for the spatial dimension i,
@@ -4051,8 +4057,9 @@ class SpaceToBatchND(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_shape, paddings):
         """Initialize SpaceToBatchND"""
-        self.ori_block_shape = block_shape
-        self.ori_paddings = paddings
+        if isinstance(block_shape, int):
+            block_shape = (block_shape,) * 2
+        self.add_prim_attr("block_shape", block_shape)
         validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
         validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
         block_rank = len(block_shape)
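The int-to-tuple normalization is the only new behavior here; everything after it sees a list or tuple exactly as before. A standalone sketch of the rule:

    def normalize_block_shape(block_shape, m=2):
        # An int means the same block size on all M spatial dimensions.
        if isinstance(block_shape, int):
            block_shape = (block_shape,) * m
        return block_shape

    assert normalize_block_shape(2) == (2, 2)
    assert normalize_block_shape((2, 3)) == (2, 3)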
@@ -4069,10 +4076,6 @@ class SpaceToBatchND(PrimitiveWithInfer):
             validator.check_non_negative_int(elem, 'paddings element', self.name)
             validator.check_value_type('paddings element', elem, [int], self.name)
         self.paddings = paddings
-        block_shape_append = [1] + list(self.block_shape)
-        self.add_prim_attr("block_shape", block_shape_append)
-        paddings_append = [[0, 0]] + list(self.paddings)
-        self.add_prim_attr("paddings", paddings_append)
 
     def infer_dtype(self, x_dtype):
         validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
@@ -4085,8 +4088,6 @@ class SpaceToBatchND(PrimitiveWithInfer):
         block_shape_prod = 1
         offset = 2
-        if x_rank <= 4:
-            offset = 1
         for i in range(len(self.block_shape)):
             padded = out_shape[i + offset] + self.paddings[i][0] + \
                 self.paddings[i][1]
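With the prepended entries gone, `self.block_shape` always has length 2 and the spatial dimensions of the 4-D input always start at index 2, so the rank-dependent `offset` is no longer needed. A worked sketch of the resulting shape inference (consistent with the visible hunk, not a copy of it):

    def space_to_batch_nd_out_shape(x_shape, block_shape, paddings, offset=2):
        # Batch is multiplied by the block product; each spatial dim is
        # padded, then divided by its block size.
        out = list(x_shape)
        block_prod = 1
        for i, b in enumerate(block_shape):
            padded = out[i + offset] + paddings[i][0] + paddings[i][1]
            out[i + offset] = padded // b
            block_prod *= b
        out[0] *= block_prod
        return out

    # (1, 3, 4, 4), 2x2 blocks, no padding -> (4, 3, 2, 2)
    print(space_to_batch_nd_out_shape([1, 3, 4, 4], [2, 2], [[0, 0], [0, 0]]))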
@@ -4108,8 +4109,10 @@ class BatchToSpaceND(PrimitiveWithInfer):
     dimension and block_shape with given amount to crop from dimension, respectively.
 
     Args:
-        block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
-            The length of block_shape is M corresponding to the number of spatial dimensions. M must be 2.
+        block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block, with all values
+            greater than 1. If `block_shape` is a list or tuple, its length M corresponds to the number of
+            spatial dimensions. If `block_shape` is an int, the block sizes of all M dimensions are the same
+            and equal to `block_shape`. M must be 2.
         crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list,
             each containing 2 int value.
             All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
@@ -4149,8 +4152,9 @@ class BatchToSpaceND(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, block_shape, crops):
         """Initialize BatchToSpaceND"""
-        self.ori_block_shape = block_shape
-        self.ori_crops = crops
+        if isinstance(block_shape, int):
+            block_shape = (block_shape,) * 2
+        self.add_prim_attr("block_shape", block_shape)
         validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
         validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
         block_rank = len(block_shape)
@@ -4167,10 +4171,6 @@ class BatchToSpaceND(PrimitiveWithInfer):
             validator.check_non_negative_int(elem, 'crops element', self.name)
             validator.check_value_type('crops element', elem, [int], self.name)
         self.crops = crops
-        block_shape_append = [1] + list(self.block_shape)
-        self.add_prim_attr("block_shape", block_shape_append)
-        crops_append = [[0, 0]] + list(self.crops)
-        self.add_prim_attr("crops", crops_append)
 
     def infer_dtype(self, x_dtype):
         validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
@@ -4183,8 +4183,6 @@ class BatchToSpaceND(PrimitiveWithInfer):
         block_shape_prod = 1
         offset = 2
-        if x_rank <= 4:
-            offset = 1
         for i in range(len(self.block_shape)):
             block_shape_prod = block_shape_prod * self.block_shape[i]
             x_block_prod = out_shape[i + offset] * self.block_shape[i]
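The same `offset` simplification applies in the inverse direction; for symmetry, a sketch of BatchToSpaceND's output shape under the same assumptions:

    def batch_to_space_nd_out_shape(x_shape, block_shape, crops, offset=2):
        # Batch is divided by the block product; each spatial dim is
        # multiplied by its block size, then cropped.
        out = list(x_shape)
        block_prod = 1
        for i, b in enumerate(block_shape):
            out[i + offset] = out[i + offset] * b - crops[i][0] - crops[i][1]
            block_prod *= b
        out[0] //= block_prod
        return out

    # (4, 3, 2, 2), 2x2 blocks, no crops -> (1, 3, 4, 4)
    print(batch_to_space_nd_out_shape([4, 3, 2, 2], [2, 2], [[0, 0], [0, 0]]))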