diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index a5d005d540..91f1b6c233 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -1085,7 +1085,7 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode) {kTensorAddOpName, "ElemWise"}, {kConv2DBackpropInputOpName, "Conv2d_backprop_input"}, {kConv2DBackpropFilterOpName, "Conv2d_backprop_filter"}, - {kDepthwiseConv2dNativeName, "DepthwiseConvolution"}, + {kDepthwiseConv2dNativeOpName, "DepthwiseConvolution"}, {kAddNOpName, "ElemWise"}, {kReluGradV2OpName, "ElemWise"}, {kRealDivOpName, "ElemWise"}}; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc new file mode 100644 index 0000000000..3ad5166bf3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc @@ -0,0 +1,320 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" + +#include +#include +#include +#include + +#include "utils/utils.h" +#include "utils/ms_context.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kConv2DBackpropInputNum = 4; +constexpr size_t kConv2DAxisNum = 4; +constexpr auto kAttrOffsetA = "offset_a"; +constexpr auto kAttrPadList = "pad_list"; +constexpr auto kAttrPads = "pads"; +constexpr auto kAttrMode = "mode"; +constexpr auto kAttrChannelMultiplier = "channel_multiplier"; + +bool NeedUpdate(const CNodePtr &conv2d, std::vector in_shape, std::vector out_shape) { + MS_EXCEPTION_IF_NULL(conv2d); + auto group = LongToSize(AnfAlgo::GetNodeAttr(conv2d, kAttrGroup)); + if (group == 1) { + return false; + } + auto data_format = AnfAlgo::GetNodeAttr(conv2d, kAttrDataFormat); + if (data_format != "NCHW") { + MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format; + } + if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) { + MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size() + << "output axis num: " << out_shape.size(); + } + auto in_channel = in_shape[1]; + auto out_channel = out_shape[1]; + if (group != in_channel || group != out_channel) { + MS_LOG(EXCEPTION) << "Conv2D's attr group should be equal to in_channel and out_channel when group > 1, but got " + << "group: " << group << " in_channel: " << in_channel << " out_channel: " << out_channel; + } + return true; +} + +ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vector &perm) { + MS_EXCEPTION_IF_NULL(func_graph); + auto kernel_graph 
= func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector axis_values{}; + abstract::AbstractBasePtrList abs{}; + for (const auto &axis : perm) { + axis_values.push_back(MakeValue(axis)); + abs.push_back(std::make_shared(axis)); + } + auto perm_value_tuple = std::make_shared(axis_values); + MS_EXCEPTION_IF_NULL(perm_value_tuple); + auto abstract = std::make_shared(abs); + MS_EXCEPTION_IF_NULL(abstract); + auto perm_value = kernel_graph->NewValueNode(abstract, perm_value_tuple); + MS_EXCEPTION_IF_NULL(perm_value); + kernel_graph->AddValueNodeToGraph(perm_value); + return perm_value; +} + +CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, const AnfNodePtr &input_node, + bool need_trans_output) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(conv2d); + MS_EXCEPTION_IF_NULL(input_node); + auto perm = std::vector{1, 0, 2, 3}; + std::vector transpose_inputs = {NewValueNode(std::make_shared(kTransposeOpName)), input_node, + CreatePermValueNode(graph, perm)}; + auto transpose = graph->NewCNode(transpose_inputs); + MS_EXCEPTION_IF_NULL(transpose); + transpose->set_scope(conv2d->scope()); + + if (need_trans_output) { + auto types = {AnfAlgo::GetOutputInferDataType(input_node, 0)}; + auto out_shape = AnfAlgo::GetOutputInferShape(input_node, 0); + if (out_shape.size() != kConv2DAxisNum) { + MS_LOG(EXCEPTION) << "Conv2D's output axis number should be " << kConv2DAxisNum << ", but got " + << out_shape.size(); + } + std::swap(out_shape[0], out_shape[1]); + auto shapes = {out_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, transpose.get()); + } else { + transpose->set_abstract(conv2d->abstract()); + } + + auto input_names = std::vector{"x", "perm"}; + auto output_names = std::vector{"output"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose); + return transpose; +} + +CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(conv2d); + if (conv2d->inputs().size() != kConvInputNum) { + MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " + << conv2d->inputs().size() - 1; + } + std::vector depth_conv_inputs = {NewValueNode(std::make_shared(kDepthwiseConv2dNativeOpName)), + conv2d->input(1), transpose}; + auto depth_conv = graph->NewCNode(depth_conv_inputs); + MS_EXCEPTION_IF_NULL(depth_conv); + depth_conv->set_abstract(conv2d->abstract()); + depth_conv->set_scope(conv2d->scope()); + return depth_conv; +} + +CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin, + const CNodePtr &transpose) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(conv2d_backin); + if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) { + MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " + << conv2d_backin->inputs().size() - 1; + } + std::vector depth_conv_backin_inputs = { + NewValueNode(std::make_shared(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3), + transpose, conv2d_backin->input(1)}; + auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs); + MS_EXCEPTION_IF_NULL(depth_conv_backin); + depth_conv_backin->set_abstract(conv2d_backin->abstract()); + depth_conv_backin->set_scope(conv2d_backin->scope()); + return depth_conv_backin; +} + +CNodePtr 
CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(conv2d_backfil); + if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) { + MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " + << conv2d_backfil->inputs().size() - 1; + } + auto filter_size_node = conv2d_backfil->input(3); + MS_EXCEPTION_IF_NULL(filter_size_node); + auto filter_size_vnode = filter_size_node->cast(); + MS_EXCEPTION_IF_NULL(filter_size_vnode); + auto filter_size = GetValue>(filter_size_vnode->value()); + // swap axis 0 and 1 of filter shape, but don't swap twice since some node share same filter_size valuenode + // when the filter_size value is same. + if (filter_size[0] != 1) { + std::swap(filter_size[0], filter_size[1]); + conv2d_backfil->input(3)->cast()->set_value(MakeValue(filter_size)); + } + std::vector depth_conv_backfil_inputs = { + NewValueNode(std::make_shared(kDepthwiseConv2dNativeBackpropFilterOpName)), conv2d_backfil->input(2), + conv2d_backfil->input(3), conv2d_backfil->input(1)}; + auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs); + MS_EXCEPTION_IF_NULL(depth_conv_backfil); + depth_conv_backfil->set_scope(conv2d_backfil->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)}; + std::vector out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0); + if (out_shape.size() != kConv2DAxisNum) { + MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got " + << out_shape.size(); + } + std::swap(out_shape[0], out_shape[1]); + auto shapes = {out_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get()); + return depth_conv_backfil; +} + +void SetCommonAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) { + AnfAlgo::CopyNodeAttr(kAttrKernelSize, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrDilation, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrDataFormat, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrPadList, kAttrPads, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrPadMode, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrPad, conv2d, depth_conv); + AnfAlgo::SetNodeAttr(kAttrMode, MakeValue(3), depth_conv); + AnfAlgo::SetNodeAttr(kAttrChannelMultiplier, MakeValue(1), depth_conv); +} + +void SetConv2DAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) { + SetCommonAttrs(conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrInputNames, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrStride, conv2d, depth_conv); + AnfAlgo::CopyNodeAttr(kAttrOffsetA, conv2d, depth_conv); +} + +void SetConv2DBackpropInputAttrs(const CNodePtr &conv2d_backin, const CNodePtr &depth_conv_backin) { + SetCommonAttrs(conv2d_backin, depth_conv_backin); + auto input_names = std::vector{"input_size", "filter", "dout"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backin); + auto stride = AnfAlgo::GetNodeAttr>(conv2d_backin, kAttrStride); + if (stride.size() == 2) { + stride.insert(stride.begin(), 2, 1); + } + AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backin); +} + +void SetConv2DBackpropFilterAttrs(const CNodePtr &conv2d_backfil, const CNodePtr &depth_conv_backfil) { + SetCommonAttrs(conv2d_backfil, depth_conv_backfil); + auto input_names = std::vector{"input", "filter_size", "dout"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backfil); + 
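// Pad a 2-element (H, W) stride to four elements by prepending 1s for the N and C axes before setting it on the depthwise node. +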
auto stride = AnfAlgo::GetNodeAttr>(conv2d_backfil, kAttrStride); + if (stride.size() == 2) { + stride.insert(stride.begin(), 2, 1); + } + AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backfil); +} +} // namespace + +const BaseRef Conv2DUnifyMindIR::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr W = std::make_shared(); + VectorRef pattern({prim::kPrimConv2D, X, W}); + return pattern; +} + +const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto conv2d = node->cast(); + MS_EXCEPTION_IF_NULL(conv2d); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d, 0); + auto output_shape = AnfAlgo::GetOutputInferShape(conv2d, 0); + if (!NeedUpdate(conv2d, input_shape, output_shape)) { + return nullptr; + } + + if (conv2d->inputs().size() != kConvInputNum) { + MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " + << conv2d->inputs().size() - 1; + } + auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true); + auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose); + SetConv2DAttrs(conv2d, depth_conv); + return depth_conv; +} + +const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const { + VarPtr dout = std::make_shared(); + VarPtr weight = std::make_shared(); + VarPtr input_size = std::make_shared(); + VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size}); + return pattern; +} + +const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto conv2d_backin = node->cast(); + MS_EXCEPTION_IF_NULL(conv2d_backin); + auto input_shape = AnfAlgo::GetOutputInferShape(conv2d_backin, 0); + auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backin, 0); + if (!NeedUpdate(conv2d_backin, input_shape, output_shape)) { + return nullptr; + } + + if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) { + MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " + << conv2d_backin->inputs().size() - 1; + } + auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true); + auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose); + SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin); + return depth_conv_backin; +} + +const BaseRef Conv2DBackpropFilterUnifyMindIR::DefinePattern() const { + VarPtr dout = std::make_shared(); + VarPtr input = std::make_shared(); + VarPtr filter_size = std::make_shared(); + VectorRef pattern({prim::kPrimConv2DBackpropFilter, dout, input, filter_size}); + return pattern; +} + +const AnfNodePtr Conv2DBackpropFilterUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto conv2d_backfil = node->cast(); + MS_EXCEPTION_IF_NULL(conv2d_backfil); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 1); + auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 0); + if (!NeedUpdate(conv2d_backfil, input_shape, output_shape)) { + return nullptr; + } + + auto depth_conv_backfil = CreateDepthwiseConv2DBackpropFilter(graph, conv2d_backfil); + SetConv2DBackpropFilterAttrs(conv2d_backfil, 
depth_conv_backfil); + auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(conv2d_backfil, transpose); + return transpose; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.h new file mode 100644 index 0000000000..98f24a0c97 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class Conv2DUnifyMindIR : public PatternProcessPass { + public: + explicit Conv2DUnifyMindIR(bool multigraph = true) : PatternProcessPass("conv2d_unify_mindir", multigraph) {} + ~Conv2DUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass { + public: + explicit Conv2DBackpropInputUnifyMindIR(bool multigraph = true) + : PatternProcessPass("conv2d_backprop_input_unify_mindir", multigraph) {} + ~Conv2DBackpropInputUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass { + public: + explicit Conv2DBackpropFilterUnifyMindIR(bool multigraph = true) + : PatternProcessPass("conv2d_backprop_filter_unify_mindir", multigraph) {} + ~Conv2DBackpropFilterUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc new file mode 100644 index 0000000000..d848c910ba --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc @@ -0,0 +1,269 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/log_adapter.h" + +constexpr auto kKeepProb = "keep_prob"; +constexpr auto kSeed0 = "Seed0"; +constexpr auto kSeed1 = "Seed1"; +constexpr auto kUint8BitSize = 8; + +namespace mindspore::opt { +constexpr size_t kFloat16Len = 2; // size of float16 +namespace { +AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) { + MS_LOG(INFO) << "GetDropoutNodeInfo start."; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(keep_prob); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) || + !AnfAlgo::HasNodeAttr(kSeed1, cnode)) { + MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1."; + } + *keep_prob = AnfAlgo::GetNodeAttr(node, kKeepProb); + MS_LOG(INFO) << "keep_prob: " << *keep_prob; + // return dropout input. maybe tensor or pre cnode output + return cnode->input(1); +} + +ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) { + MS_LOG(INFO) << "CreateKeepPorbValueNode start."; + MS_EXCEPTION_IF_NULL(func_graph); + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector keep_prob_shape = {}; + ShapeVector shape = {}; + auto keep_prob_tensor = std::make_shared(dtype->type_id(), keep_prob_shape); + MS_EXCEPTION_IF_NULL(keep_prob_tensor); + auto data_ptr = keep_prob_tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + // keep_prob's datatype is same with input data + if (dtype->type_id() == kNumberTypeFloat16) { + float16 half_data = float16(keep_prob); + auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len); + if (ret_code != 0) { + MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; + } + } else { + auto *val = reinterpret_cast(data_ptr); + *val = keep_prob; + } + auto abstract = std::make_shared(dtype, shape); + auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor); + MS_EXCEPTION_IF_NULL(keep_prob_value); + kernel_graph->AddValueNodeToGraph(keep_prob_value); + return keep_prob_value; +} + +std::vector GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) { + MS_LOG(INFO) << "GetInputShape start."; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(dropout_input); + std::vector shapes; + if (dropout_input->isa()) { + MS_LOG(INFO) << "Dropout input from parameter node."; + // single test case + auto dropout_input_value = dropout_input->cast(); + MS_EXCEPTION_IF_NULL(dropout_input_value); + MS_EXCEPTION_IF_NULL(dropout_input_value->Shape()); + auto shape = dropout_input_value->Shape()->cast(); + MS_EXCEPTION_IF_NULL(shape); + return shape->shape(); + } else if (dropout_input->isa()) { + MS_LOG(INFO) << "Dropout input from cnode."; + auto dropout_input_node = dropout_input->cast(); + MS_EXCEPTION_IF_NULL(dropout_input_node); + auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 
0); + std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); + return shapes; + } else { + MS_LOG(ERROR) << "Dropout input is not parameter or cnode."; + return {}; + } +} + +ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector &shape) { + MS_LOG(INFO) << "CreateShapeValueNode start."; + MS_EXCEPTION_IF_NULL(func_graph); + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector dim_values{}; + abstract::AbstractBasePtrList abs{}; + for (const auto &dim : shape) { + dim_values.push_back(MakeValue(dim)); + abs.push_back(std::make_shared(dim)); + } + auto shape_value_tuple = std::make_shared(dim_values); + MS_EXCEPTION_IF_NULL(shape_value_tuple); + auto abstract = std::make_shared(abs); + MS_EXCEPTION_IF_NULL(abstract); + auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple); + MS_EXCEPTION_IF_NULL(shape_value); + kernel_graph->AddValueNodeToGraph(shape_value); + return shape_value; +} +} // namespace + +const BaseRef DropoutUnifyMindIR::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + auto prim = std::make_shared(kDropoutOpName); + auto ref = VectorRef({prim, X}); + return VectorRef({prim::kPrimTupleGetItem, ref, Y}); +} + +const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_cnode); + auto dropout_node = tuple_cnode->input(1); + MS_EXCEPTION_IF_NULL(dropout_node); + float keep_prob = 0; + auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); + auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? 
kFloat16 : kFloat32; + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); + auto shape = GetInputShape(dropout_node, dropout_input); + auto shape_value = CreateShapeValueNode(func_graph, shape); + // CreateDropoutGenMask + auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); + output_size = output_size / kUint8BitSize; + MS_LOG(INFO) << "Output_size: " << output_size; + std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), + shape_value, keep_prob_value}; + CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + ShapeVector dropout_gen_mask_output = {output_size}; + auto gen_mask_abstract = std::make_shared(kUInt8, dropout_gen_mask_output); + MS_EXCEPTION_IF_NULL(gen_mask_abstract); + dropout_gen_mask->set_abstract(gen_mask_abstract); + dropout_gen_mask->set_scope(node->scope()); + + // CreateDropoutDoMask + std::vector dropout_do_mask_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + dropout_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + ShapeVector dropout_do_mask_output = shape; + auto do_mask_abstract = std::make_shared(dropout_dtype, dropout_do_mask_output); + dropout_do_mask->set_abstract(do_mask_abstract); + dropout_do_mask->set_scope(node->scope()); + + return dropout_do_mask; +} + +const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + MS_EXCEPTION_IF_NULL(X); + MS_EXCEPTION_IF_NULL(Y); + auto dropout_prim = std::make_shared(kDropoutOpName); + auto tuple_getitem_prim = prim::kPrimTupleGetItem; + auto dropout_grad_prim = std::make_shared(kDropoutGradOpName); + MS_EXCEPTION_IF_NULL(dropout_prim); + MS_EXCEPTION_IF_NULL(dropout_grad_prim); + auto ref0 = VectorRef({dropout_prim, X}); + auto ref1 = VectorRef({tuple_getitem_prim, ref0, Y}); + return VectorRef({dropout_grad_prim, grad_input_, ref1}); +} + +const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto dropout_grad = node->cast(); + MS_EXCEPTION_IF_NULL(dropout_grad); + auto tuple_getitem = dropout_grad->input(2); + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto tuple_getitem_cnode = tuple_getitem->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); + auto dropout_node = tuple_getitem_cnode->input(1); + MS_EXCEPTION_IF_NULL(dropout_node); + float keep_prob = 0; + auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); + auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? 
kFloat16 : kFloat32; + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); + auto shape = GetInputShape(dropout_node, dropout_input); + auto shape_value = CreateShapeValueNode(func_graph, shape); + // CreateDropoutGenMask + auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); + output_size = output_size / kUint8BitSize; + MS_LOG(INFO) << "Output_size: " << output_size; + std::vector dropout_gen_mask_inputs{NewValueNode(std::make_shared(kDropoutGenMaskOpName)), + shape_value, keep_prob_value}; + CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + ShapeVector dropout_gen_mask_output = {output_size}; + auto gen_mask_abstract = std::make_shared(kUInt8, dropout_gen_mask_output); + MS_EXCEPTION_IF_NULL(gen_mask_abstract); + dropout_gen_mask->set_abstract(gen_mask_abstract); + dropout_gen_mask->set_scope(dropout_node->scope()); + // AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); + + // CreateDropoutDoMask-forward + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(dropout_node); + if (iter != node_users.end()) { + for (auto &node_index : iter->second) { + // Dropout has two outputs, so output node is tuple_getitem + auto tuple_getitem_cnode2 = node_index.first->cast(); + // check if Dropout's first output, which is used by forward, is used. + auto getitem_index = GetValue(tuple_getitem_cnode2->input(2)->cast()->value()); + if (getitem_index == 0) { + std::vector dropout_do_mask1_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + dropout_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask1); + ShapeVector dropout_do_mask1_output = shape; + auto do_mask_abstract1 = std::make_shared(dropout_dtype, dropout_do_mask1_output); + dropout_do_mask1->set_abstract(do_mask_abstract1); + dropout_do_mask1->set_scope(dropout_node->scope()); + (void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1); + break; + } + } + } + + // CreateDropoutDoMask-backward + if (equiv->find(grad_input_) == equiv->end()) { + MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern."; + } + auto grad_input = utils::cast((*equiv)[grad_input_]); + std::vector dropout_do_mask2_inputs{NewValueNode(std::make_shared(kDropoutDoMaskOpName)), + grad_input, dropout_gen_mask, keep_prob_value}; + auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs); + MS_EXCEPTION_IF_NULL(dropout_do_mask2); + ShapeVector dropout_do_mask2_output = shape; + auto do_mask_abstract2 = std::make_shared(dropout_dtype, dropout_do_mask2_output); + dropout_do_mask2->set_abstract(do_mask_abstract2); + dropout_do_mask2->set_scope(node->scope()); + + return dropout_do_mask2; +} +} // namespace mindspore::opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h new file mode 100644 index 0000000000..553796376c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class DropoutUnifyMindIR : public PatternProcessPass { + public: + explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {} + ~DropoutUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class DropoutGradUnifyMindIR : public PatternProcessPass { + public: + explicit DropoutGradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("dropout_grad_unify_mindir", multigraph) { + grad_input_ = std::make_shared(); + } + ~DropoutGradUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr grad_input_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.cc new file mode 100644 index 0000000000..80604e9210 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h" + +#include +#include + +#include "utils/utils.h" +#include "utils/ms_context.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kMaxPoolInputNum = 2; +constexpr size_t kMaxPoolAttrAxisNum = 4; +constexpr size_t kMaxPoolGradInputNum = 4; +constexpr size_t kMaxPoolWithArgmaxOutputNum = 2; + +CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) { + MS_EXCEPTION_IF_NULL(maxpool_grad); + if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { + MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " + << maxpool_grad->inputs().size() - 1; + } + auto maxpool_anf = maxpool_grad->input(2); + MS_EXCEPTION_IF_NULL(maxpool_anf); + return maxpool_anf->cast(); +} + +CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(maxpool); + if (maxpool->inputs().size() != kMaxPoolInputNum) { + MS_LOG(EXCEPTION) << "MaxPool's input number should be " << kMaxPoolInputNum - 1 << ", but got " + << maxpool->inputs().size() - 1; + } + std::vector maxpool_argmax_inputs = {NewValueNode(std::make_shared(kMaxPoolWithArgmaxOpName)), + maxpool->input(1)}; + auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs); + MS_EXCEPTION_IF_NULL(maxpool_argmax); + maxpool_argmax->set_scope(maxpool->scope()); + + // MaxPoolWithArgmax's second output is argmax, whose datatype is uint16 and with same shape as first output + TypeId argmax_dtype = kNumberTypeUInt16; + auto types = {AnfAlgo::GetOutputInferDataType(maxpool, 0), argmax_dtype}; + auto out_shape = AnfAlgo::GetOutputInferShape(maxpool, 0); + auto shapes = {out_shape, out_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_argmax.get()); + return maxpool_argmax; +} + +CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad, + const std::vector &maxpool_argmax_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(maxpool_grad); + if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { + MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " + << maxpool_grad->inputs().size() - 1; + } + // MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are + // {input, grad_input, argmax_output} + std::vector maxpool_grad_argmax_inputs = { + NewValueNode(std::make_shared(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1), + maxpool_grad->input(3), maxpool_argmax_outputs[1]}; + auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs); + MS_EXCEPTION_IF_NULL(maxpool_grad_argmax); + maxpool_grad_argmax->set_scope(maxpool_grad->scope()); + maxpool_grad_argmax->set_abstract(maxpool_grad->abstract()); + return maxpool_grad_argmax; +} + +void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax, + const CNodePtr &maxpool_grad_argmax) { + auto strides = AnfAlgo::GetNodeAttr>(maxpool, kAttrStrides); + auto ksize = AnfAlgo::GetNodeAttr>(maxpool, kAttrKsize); + if (strides.size() != kMaxPoolAttrAxisNum) { + MS_LOG(EXCEPTION) << "MaxPool's attr strides has wrong axis number, should be " << kMaxPoolAttrAxisNum + << ", but got " << strides.size(); + } + if (ksize.size() != kMaxPoolAttrAxisNum) { + 
MS_LOG(EXCEPTION) << "MaxPool's attr ksize has wrong axis number, should be " << kMaxPoolAttrAxisNum << ", but got " + << ksize.size(); + } + // note that strides and ksize change from (1, 1, x, y) to (1, x, y, 1) + for (size_t i = 1; i <= 2; ++i) { + strides[i] = strides[i + 1]; + ksize[i] = ksize[i + 1]; + } + strides[3] = 1; + ksize[3] = 1; + + AnfAlgo::CopyNodeAttrs(maxpool, maxpool_argmax); + AnfAlgo::CopyNodeAttrs(maxpool_grad, maxpool_grad_argmax); + AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_argmax); + AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_grad_argmax); + AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_argmax); + AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_grad_argmax); +} +} // namespace + +const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + VectorRef maxpool({prim::kPrimMaxPool, X}); + VectorRef pattern({prim::kPrimMaxPoolGrad, X, maxpool, Y}); + return pattern; +} + +const AnfNodePtr MaxPool2MaxPoolWithArgmax::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto maxpool_grad = node->cast(); + MS_EXCEPTION_IF_NULL(maxpool_grad); + auto maxpool = GetMaxPool(maxpool_grad); + MS_EXCEPTION_IF_NULL(maxpool); + + auto maxpool_argmax = CreateMaxPoolWithArgmax(graph, maxpool); + std::vector maxpool_argmax_outputs; + CreateMultipleOutputsOfAnfNode(graph, maxpool_argmax, kMaxPoolWithArgmaxOutputNum, &maxpool_argmax_outputs); + auto maxpool_grad_argmax = CreateMaxPoolGradWithArgmax(graph, maxpool_grad, maxpool_argmax_outputs); + SetNodeAttrs(maxpool, maxpool_grad, maxpool_argmax, maxpool_grad_argmax); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(maxpool, maxpool_argmax_outputs[0]); + return maxpool_grad_argmax; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h new file mode 100644 index 0000000000..6956c72365 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MaxPool2MaxPoolWithArgmax : public PatternProcessPass { + public: + explicit MaxPool2MaxPoolWithArgmax(bool multigraph = true) + : PatternProcessPass("maxpool_to_maxpool_with_argmax", multigraph) {} + ~MaxPool2MaxPoolWithArgmax() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.cc new file mode 100644 index 0000000000..5f2cd4722a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" +#include +#include +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "base/core_ops.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4; +bool IsC(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + return in->isa(); + } + return false; +} + +CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) { + MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax); + if (maxpool_grad_with_argmax->inputs().size() != kMaxPoolGradWithArgmaxInputNum) { + MS_LOG(EXCEPTION) << "MaxPoolGradWithArgmax has wrong input size."; + } + auto tuple_getitem0_anf = maxpool_grad_with_argmax->input(3); + MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); + return tuple_getitem0_anf->cast(); +} +} // namespace + +const BaseRef MaxPoolWithArgmaxUnifyMindIR::DefinePattern() const { + VarPtr X = std::make_shared(); + VectorRef pattern({prim::kPrimMaxPoolWithArgmax, X}); + return pattern; +} + +const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto maxpool_with_argmax = node->cast(); + MS_EXCEPTION_IF_NULL(maxpool_with_argmax); + + TypeId argmax_dtype = kNumberTypeUInt16; + auto ksize = AnfAlgo::GetNodeAttr>(maxpool_with_argmax, kAttrKsize); + auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0); + auto argmax_shape = output_shape; + if (argmax_shape.size() != 4) { + MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; + } + argmax_shape[2] = ksize[1] * ksize[2]; + argmax_shape[3] = (output_shape[2] * output_shape[3] + 15) / 16 + 1; + auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype}; + auto shapes = {output_shape, argmax_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get()); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + return maxpool_with_argmax; +} + +const BaseRef MaxPoolGradWithArgmaxUnifyMindIR::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VectorRef maxpool_with_argmax({prim::kPrimMaxPoolWithArgmax, X}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, maxpool_with_argmax, index0}); + VectorRef maxpool_grad_with_argmax({prim::kPrimMaxPoolGradWithArgmax, X, Y, tuple_getitem0}); + return maxpool_grad_with_argmax; +} + +const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto maxpool_grad_with_argmax = node->cast(); + MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax); + auto tuple_getitem0_anf = GetMaxPoolWithArgmax(maxpool_grad_with_argmax); + MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); + + TypeId argmax_dtype = kNumberTypeUInt16; + auto ksize = AnfAlgo::GetNodeAttr>(maxpool_grad_with_argmax, kAttrKsize); + auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0); + if (argmax_shape.size() != 4) { + MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; + } + argmax_shape[3] = (argmax_shape[2] * argmax_shape[3] + 15) / 16 + 1; + argmax_shape[2] = ksize[1] * ksize[2]; 
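+ // Overwrite the argmax getitem node's inferred type and shape in place so later passes see the uint16 bitmask layout computed above.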
+ AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get()); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + return maxpool_grad_with_argmax; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h new file mode 100644 index 0000000000..fc618da3ac --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MaxPoolWithArgmaxUnifyMindIR : public PatternProcessPass { + public: + explicit MaxPoolWithArgmaxUnifyMindIR(bool multigraph = true) + : PatternProcessPass("maxpool_with_argmax_unify_mindir", multigraph) {} + ~MaxPoolWithArgmaxUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class MaxPoolGradWithArgmaxUnifyMindIR : public PatternProcessPass { + public: + explicit MaxPoolGradWithArgmaxUnifyMindIR(bool multigraph = true) + : PatternProcessPass("maxpool_grad_with_argmax_unify_mindir", multigraph) {} + ~MaxPoolGradWithArgmaxUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index c0cb8bc0ee..03b90a8850 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -101,6 +101,10 @@ constexpr size_t kFusedMulApplyMomentumOutputNum = 2; constexpr size_t kSplitInputNum = 2; constexpr size_t kGatherV2DynInputNum = 3; constexpr size_t kUnsortedSegmentSumInputNum = 2; +constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2; +constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3; +constexpr size_t kOneHotOutputNum = 1; +constexpr size_t kOneHotInputNum = 5; enum FusedBatchNormInput { kX = 1, diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 16abae939e..69f8d1f28b 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -32,6 +32,10 @@ #include 
"runtime/device/ascend/ascend_kernel_runtime.h" #include "backend/optimizer/ascend/ascend_backend_optimization.h" #include "backend/optimizer/common/common_backend_optimization.h" +#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" +#include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h" +#include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" +#include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" #include "runtime/device/kernel_adjust.h" #include "runtime/device/ascend/ascend_stream_assign.h" #include "backend/session/anf_runtime_algorithm.h" @@ -423,6 +427,35 @@ void AscendSession::Init(uint32_t device_id) { runtime_instance->CreateContext(); } +void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); + if (save_graphs) { + std::string file_name = "hwopt_d_before_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_name, graph); + DumpIRProto(graph, "before_unify_mindir_hwopt_" + std::to_string(graph->graph_id())); + } + auto optimizer = std::make_shared(); + auto unify_mindir_pm = std::make_shared("unify_mindir_pm"); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + + optimizer->AddPassManager(unify_mindir_pm); + (void)optimizer->Optimize(graph); + graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_name = "hwopt_d_after_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_name, graph); + } +} + GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { MS_LOG(INFO) << "Start"; // construct graph, if successfully, graph_sum_ + 1 @@ -438,6 +471,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); // Update Graph Dynamic Shape Attr UpdateAllGraphDynamicShapeAttr(all_graphs); + for (const auto &graph : all_graphs) { + UnifyMindIR(graph); + } BackendOptimization(all_graphs); // empty graph dont entry to backend if (root_graph->execution_order().empty()) { @@ -1219,7 +1255,6 @@ void AscendSession::IrFusionPass(const NotNull graph, NotNullinsert(graph.get()); - opt::AscendBackendIRFusionOptimization(graph); graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 100882159e..7e33f70bba 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -51,6 +51,7 @@ class AscendSession : public SessionBasic { void SyncStream() override; protected: + void UnifyMindIR(const KernelGraphPtr &graph) override; GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraphImpl(NotNull func_graph) override; GraphId CompileGraphImpl(NotNull func_graph, const std::vector &inputs) override; diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 376b6e6a20..e121b84f7a 100644 --- 
a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -32,6 +32,7 @@ class CPUSession : public SessionBasic { void Init(uint32_t device_id) override { InitExecutor(kCPUDevice, device_id); } protected: + void UnifyMindIR(const KernelGraphPtr &graph) override { return; } void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *, std::map *tensor_to_node) override; GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 3ff72d5852..05ad8b5796 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -35,6 +35,7 @@ class GPUSession : public SessionBasic { void SyncStream() override; protected: + void UnifyMindIR(const KernelGraphPtr &graph) override { return; } GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; void RunGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index b457f65e81..11b72e49d9 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -943,6 +943,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con // Update Graph Dynamic Shape Attr UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); + UnifyMindIR(graph); opt::BackendCommonOptimization(graph); graph->SetInputNodes(); auto input_nodes = graph->input_nodes(); @@ -1610,6 +1611,7 @@ std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInf // set output CreateOutputNode(cnode, graph); graph->SetInputNodes(); + UnifyMindIR(graph); return graph; } diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 9d6b64903b..a6360dfd0f 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -147,6 +147,7 @@ class SessionBasic : public std::enable_shared_from_this { virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *outputs, std::map *tensor_to_node); + virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0; virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; virtual GraphId CompileGraphImpl(NotNull func_graph) { return kInvalidGraphId; } virtual GraphId CompileGraphImpl(NotNull func_graph, const std::vector &inputs) { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 1ae430c8c5..baeee29bcb 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -163,7 +163,9 @@ constexpr auto kBatchToSpaceOpName = "BatchToSpace"; constexpr auto kPadOpName = "Pad"; constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput"; constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter"; -constexpr auto kDepthwiseConv2dNativeName = "DepthwiseConv2dNative"; +constexpr auto kDepthwiseConv2dNativeOpName = "DepthwiseConv2dNative"; +constexpr auto kDepthwiseConv2dNativeBackpropInputOpName = "DepthwiseConv2dNativeBackpropInput"; +constexpr auto kDepthwiseConv2dNativeBackpropFilterOpName = "DepthwiseConv2dNativeBackpropFilter"; 
constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2"; constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2"; constexpr auto kLabelSetOpName = "LabelSet"; @@ -204,6 +206,8 @@ constexpr auto kPaddingOpName = "Padding"; constexpr auto kAvgPoolOpName = "AvgPool"; constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; +constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; +constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; constexpr auto kTensorAddOpName = "TensorAdd"; constexpr auto kCastOpName = "Cast"; constexpr auto kGreaterEqualOpName = "GreaterEqual"; @@ -250,6 +254,13 @@ constexpr auto kMatMulV2OpName = "MatMulV2"; constexpr auto kBroadcastToOpName = "BroadcastTo"; constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2"; +constexpr auto kDropoutOpName = "Dropout"; +constexpr auto kDropoutGradOpName = "DropoutGrad"; +constexpr auto kDropoutGenMaskOpName = "DropoutGenMask"; +constexpr auto kDropoutDoMaskOpName = "DropoutDoMask"; +constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; +constexpr auto kOneHotOpName = "OneHot"; +constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; // Hcom Op Type constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; @@ -272,6 +283,7 @@ constexpr auto kAttrEpsilon = "epsilon"; constexpr auto kAttrFactor = "factor"; constexpr auto kAttrIsRef = "isRef"; constexpr auto kAttrDataShape = "data_shape"; +constexpr auto kAttrDataFormat = "data_format"; constexpr auto kAttrAxis = "axis"; constexpr auto kAttrKeepDims = "keep_dims"; constexpr auto kAttrShapeGamma = "shape_gamma"; @@ -348,6 +360,15 @@ constexpr auto kAttrPynativeNextOpName = "next_op"; constexpr auto kAttrPynativeNextIndex = "next_index"; constexpr auto kAttrCompileInfo = "compile_info"; constexpr auto kAttrFusionType = "fusion_type"; +constexpr auto kAttrStride = "stride"; +constexpr auto kAttrStrides = "strides"; +constexpr auto kAttrKsize = "ksize"; +constexpr auto kAttrKernelSize = "kernel_size"; +constexpr auto kAttrDilation = "dilation"; +constexpr auto kAttrPadMode = "pad_mode"; +constexpr auto kAttrPad = "pad"; +constexpr auto kAttrPadding = "padding"; +constexpr auto kAttrIsGrad = "is_grad"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index ecd27f3583..cd8be927ad 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -134,6 +134,8 @@ inline const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); inline const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); inline const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); +inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared("MaxPoolWithArgmax"); +inline const PrimitivePtr kPrimMaxPoolGradWithArgmax = std::make_shared("MaxPoolGradWithArgmax"); inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); inline const PrimitivePtr kPrimAvgPool = std::make_shared("AvgPool"); inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 0cd630f766..8a8ed81c5a 
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -141,37 +141,21 @@ class Dropout(Cell):
             raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
         Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
         Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
-        self.keep_prob = keep_prob
         seed0, seed1 = _get_graph_seed(0, "dropout")
         self.seed0 = seed0
         self.seed1 = seed1
-        self.dtype = dtype
-        self.get_shape = P.Shape()
-        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
-        self.dropout_do_mask = P.DropoutDoMask()
-        self.cast = P.Cast()
-        self.is_ascend = context.get_context('device_target') in ["Ascend"]
-        self.dropout = P.Dropout(keep_prob)
+        self.keep_prob = keep_prob
+        self.dropout = P.Dropout(keep_prob, seed0, seed1)
 
     def construct(self, x):
         if not self.training:
             return x
 
-        if not self.is_ascend:
-            out, _ = self.dropout(x)
-            return out
-
         if self.keep_prob == 1:
             return x
 
-        shape = self.get_shape(x)
-        dtype = P.DType()(x)
-        if _is_float_dtype(dtype):
-            keep_prob = self.cast(self.keep_prob, dtype)
-        else:
-            keep_prob = self.cast(self.keep_prob, mstype.float16)
-        output = self.dropout_gen_mask(shape, keep_prob)
-        return self.dropout_do_mask(x, output, keep_prob)
+        out, _ = self.dropout(x)
+        return out
 
     def extend_repr(self):
         return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py
index 9f2e0c3b26..7a80b2fae6 100644
--- a/mindspore/nn/layer/conv.py
+++ b/mindspore/nn/layer/conv.py
@@ -19,7 +19,7 @@ from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
-from mindspore.common.initializer import initializer, Initializer
+from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator, Rel, twice
 from mindspore._extends import cell_attr_register
@@ -247,28 +247,8 @@ class Conv2d(_Conv):
                                dilation=self.dilation,
                                group=self.group,
                                data_format=self.format)
-        self._init_depthwise_conv2d()
         self.bias_add = P.BiasAdd()
 
-    def _init_depthwise_conv2d(self):
-        """Initialize depthwise conv2d op"""
-        if context.get_context("device_target") == "Ascend" and self.group > 1:
-            self.dilation = self._dilation
-            Validator.check_equal_int(self.group, self.in_channels, 'group')
-            Validator.check_equal_int(self.group, self.out_channels, 'group')
-            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                                  kernel_size=self.kernel_size,
-                                                  pad_mode=self.pad_mode,
-                                                  pad=self.padding,
-                                                  stride=self.stride,
-                                                  dilation=self.dilation)
-            weight_shape = [1, self.in_channels, *self.kernel_size]
-            if isinstance(self.weight_init, Tensor):
-                self.weight_init = Tensor(self.weight_init.asnumpy().swapaxes(0, 1), self.weight_init.dtype)
-            if isinstance(self.weight_init, Initializer):
-                self.weight_init.shape = weight_shape
-            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')
-
     def construct(self, x):
         output = self.conv2d(x, self.weight)
         if self.has_bias:
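With the depthwise special case removed from nn.Conv2d above, a depthwise convolution is now written as an ordinary grouped convolution in Python, and the lowering to DepthwiseConv2dNative is handled inside the Ascend backend rather than in the layer. A minimal usage sketch (the layer parameters below are illustrative, not taken from the patch):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

channels = 16
# Depthwise convolution expressed as a grouped Conv2d: group == in_channels == out_channels.
# The weight shape is now uniformly [out_channels, in_channels // group, kh, kw], here [16, 1, 3, 3].
net = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3,
                group=channels, weight_init='ones')
x = Tensor(np.ones((1, channels, 8, 8)).astype(np.float32))
print(net(x).shape)  # (1, 16, 8, 8) with the default 'same' pad mode
```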
diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py
index 018d7d194b..9a4fd7cf2c 100644
--- a/mindspore/nn/layer/pooling.py
+++ b/mindspore/nn/layer/pooling.py
@@ -124,16 +124,9 @@ class MaxPool2d(_PoolNd):
                                   strides=self.stride,
                                   padding=self.pad_mode,
                                   data_format=self.format)
-        self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
-                                                         strides=self.stride,
-                                                         padding=self.pad_mode)
-        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def construct(self, x):
-        if self.is_tbe and self.training:
-            out = self.max_pool_with_arg_max(x)[0]
-        else:
-            out = self.max_pool(x)
+        out = self.max_pool(x)
         return out
 
 
@@ -198,22 +191,15 @@ class MaxPool1d(_PoolNd):
         self.max_pool = P.MaxPool(ksize=self.kernel_size,
                                   strides=self.stride,
                                   padding=self.pad_mode)
-        self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
-                                                         strides=self.stride,
-                                                         padding=self.pad_mode)
         self.shape = F.shape
         self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.expand = P.ExpandDims()
         self.squeeze = P.Squeeze(2)
-        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def construct(self, x):
         _shape_check(self.shape(x))
         x = self.expand(x, 2)
-        if self.is_tbe and self.training:
-            output = self.max_pool_with_arg_max(x)[0]
-        else:
-            output = self.max_pool(x)
+        output = self.max_pool(x)
         output = self.squeeze(output)
         return output
 
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index b04d4208da..562f5f1c3d 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -433,27 +433,15 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                     (self.is_ge_backend or self.is_ascend)
 
         # initialize convolution op and Parameter
-        if context.get_context('device_target') == "Ascend" and group > 1:
-            Validator.check_equal_int(group, in_channels, 'group')
-            Validator.check_equal_int(group, out_channels, 'group')
-            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                                kernel_size=self.kernel_size,
-                                                pad_mode=pad_mode,
-                                                pad=padding,
-                                                stride=self.stride,
-                                                dilation=self.dilation)
-            weight_shape = [1, in_channels, *self.kernel_size]
-            channel_axis = 1
-        else:
-            self.conv = P.Conv2D(out_channel=out_channels,
-                                 kernel_size=self.kernel_size,
-                                 pad_mode=pad_mode,
-                                 pad=padding,
-                                 stride=self.stride,
-                                 dilation=self.dilation,
-                                 group=group)
-            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
-            channel_axis = 0
+        self.conv = P.Conv2D(out_channel=out_channels,
+                             kernel_size=self.kernel_size,
+                             pad_mode=pad_mode,
+                             pad=padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=group)
+        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+        channel_axis = 0
         self.channel_axis = channel_axis
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.bias_add = P.BiasAdd()
@@ -651,27 +639,15 @@ class Conv2dBnFoldQuant(Cell):
         self.is_gpu = context.get_context('device_target') == "GPU"
 
         # initialize convolution op and Parameter
-        if context.get_context('device_target') == "Ascend" and group > 1:
-            Validator.check_equal_int(group, in_channels, 'group')
-            Validator.check_equal_int(group, out_channels, 'group')
-            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                                kernel_size=self.kernel_size,
-                                                pad_mode=pad_mode,
-                                                pad=padding,
-                                                stride=self.stride,
-                                                dilation=self.dilation)
-            weight_shape = [1, in_channels, *self.kernel_size]
-            channel_axis = 1
-        else:
-            self.conv = P.Conv2D(out_channel=out_channels,
-                                 kernel_size=self.kernel_size,
-                                 pad_mode=pad_mode,
-                                 pad=padding,
-                                 stride=self.stride,
-                                 dilation=self.dilation,
-                                 group=group)
-            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
-            channel_axis = 0
+        self.conv = P.Conv2D(out_channel=out_channels,
+                             kernel_size=self.kernel_size,
+                             pad_mode=pad_mode,
+                             pad=padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=group)
+        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+        channel_axis = 0
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.bias_add = P.BiasAdd()
         if Validator.check_bool(has_bias):
@@ -830,28 +806,16 @@ class Conv2dBnWithoutFoldQuant(Cell):
         else:
             self.bias = None
         # initialize convolution op and Parameter
-        if context.get_context('device_target') == "Ascend" and group > 1:
-            Validator.check_equal_int(group, in_channels, 'group')
-            Validator.check_equal_int(group, out_channels, 'group')
-            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                                kernel_size=self.kernel_size,
-                                                pad_mode=pad_mode,
-                                                pad=padding,
-                                                stride=self.stride,
-                                                dilation=self.dilation)
-            weight_shape = [1, in_channels, *self.kernel_size]
-            channel_axis = 1
-        else:
-            self.conv = P.Conv2D(out_channel=self.out_channels,
-                                 kernel_size=self.kernel_size,
-                                 mode=1,
-                                 pad_mode=self.pad_mode,
-                                 pad=self.padding,
-                                 stride=self.stride,
-                                 dilation=self.dilation,
-                                 group=self.group)
-            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
-            channel_axis = 0
+        self.conv = P.Conv2D(out_channel=self.out_channels,
+                             kernel_size=self.kernel_size,
+                             mode=1,
+                             pad_mode=self.pad_mode,
+                             pad=self.padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=self.group)
+        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+        channel_axis = 0
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                      max_init=6,
@@ -963,10 +927,11 @@ class Conv2dQuant(Cell):
                              stride=self.stride,
                              dilation=self.dilation,
                              group=self.group)
+        channel_axis = 0
         self.fake_quant_weight = quant_config.weight(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
-                                                     channel_axis=0,
+                                                     channel_axis=channel_axis,
                                                      num_channels=out_channels,
                                                      quant_dtype=quant_dtype)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index ef889eb8b3..2f5c3d9b39 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1574,32 +1574,12 @@ class MaxPoolWithArgmax(_Pool):
 
     def infer_shape(self, x_shape):
         out_shape = _Pool.infer_shape(self, x_shape)
-        _, _, out_h, out_w = out_shape
-        _, kernel_h, kernel_w, _ = self.ksize
-
-        argmax_shape = []
-        if self.is_tbe:
-            for i in range(4):
-                if i == 2:
-                    dim = kernel_h * kernel_w
-                    argmax_shape.append(dim)
-                elif i == 3:
-                    dim = math.ceil(out_h * out_w / 16) + 1
-                    argmax_shape.append(dim)
-                else:
-                    argmax_shape.append(x_shape[i])
-        else:
-            argmax_shape = out_shape
-
-        return out_shape, argmax_shape
+        return out_shape, out_shape
 
     def infer_dtype(self, x_dtype):
-        out_dtype = x_dtype
         validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
-        argmax_dtype = mstype.uint16
-        if self.is_gpu:
-            argmax_dtype = mstype.int32
-        return out_dtype, argmax_dtype
+        argmax_dtype = mstype.int32
+        return x_dtype, argmax_dtype
 
 
 class AvgPool(_Pool):
@@ -6070,7 +6050,9 @@ class Dropout(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, keep_prob=0.5):
+    def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0):
+        self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name)
+        self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
 
     def infer_shape(self, x_shape):
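The operator changes above touch two user-visible contracts: MaxPoolWithArgmax now infers the argmax output with the same shape as the pooled output and int32 dtype on every backend (replacing the TBE-specific uint16 layout), and the Dropout primitive now carries the graph seeds as Seed0/Seed1 attributes. A short sketch of both, assuming a backend that provides kernels for these primitives:

```python
import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor

x = Tensor(np.random.randn(128, 32, 32, 64).astype(np.float32))

# MaxPoolWithArgmax: the argmax output now mirrors the pooled output shape and is int32.
pool = P.MaxPoolWithArgmax(ksize=2, strides=2)
out, argmax = pool(x)
# Expected per the new infer_shape/infer_dtype:
#   out:    (128, 32, 16, 32), float32
#   argmax: (128, 32, 16, 32), int32

# Dropout primitive: the graph seeds are now attributes of the primitive itself.
drop = P.Dropout(keep_prob=0.9, Seed0=1, Seed1=2)
y, mask = drop(x)
```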
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index a6868f3b11..9416febc7c 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -1615,7 +1615,7 @@ test_case_nn_ops = [
     ('MaxPoolWithArgmax', {
         'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
         'desc_inputs': [[128, 32, 32, 64]],
-        'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}),
+        'desc_bprop': [[128, 32, 16, 32], ([128, 32, 16, 32], {'dtype': np.int32})]}),
     ('SoftmaxCrossEntropyWithLogits', {
         'block': P.SoftmaxCrossEntropyWithLogits(),
         'desc_inputs': [[1, 10], [1, 10]],
diff --git a/tests/ut/python/parallel/test_matmul_dropout.py b/tests/ut/python/parallel/test_matmul_dropout.py
index 98f955935e..0bea7d34f0 100644
--- a/tests/ut/python/parallel/test_matmul_dropout.py
+++ b/tests/ut/python/parallel/test_matmul_dropout.py
@@ -18,7 +18,11 @@ import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
+import mindspore.common.dtype as mstype
+from mindspore.common.seed import _get_graph_seed
 from mindspore.common.api import _executor
+from mindspore._checkparam import Validator
+from mindspore.ops.primitive import constexpr
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -47,13 +51,61 @@ class GradWrap(nn.Cell):
         return grad_all(self.network)(x, y, b)
 
 
+@constexpr
+def _is_float_dtype(dtype):
+    if dtype in [mstype.float32, mstype.float16]:
+        return True
+    return False
+
+class Dropout(nn.Cell):
+    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
+        super(Dropout, self).__init__()
+        if keep_prob <= 0 or keep_prob > 1:
+            raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
+        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
+        Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
+        self.keep_prob = keep_prob
+        seed0, seed1 = _get_graph_seed(0, "dropout")
+        self.seed0 = seed0
+        self.seed1 = seed1
+        self.dtype = dtype
+        self.get_shape = P.Shape()
+        self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
+        self.dropout_do_mask = P.DropoutDoMask()
+        self.cast = P.Cast()
+        self.is_gpu = context.get_context('device_target') in ["GPU"]
+        self.dropout = P.Dropout(keep_prob)
+
+    def construct(self, x):
+        if not self.training:
+            return x
+
+        if self.is_gpu:
+            out, _ = self.dropout(x)
+            return out
+
+        if self.keep_prob == 1:
+            return x
+
+        shape = self.get_shape(x)
+        dtype = P.DType()(x)
+        if _is_float_dtype(dtype):
+            keep_prob = self.cast(self.keep_prob, dtype)
+        else:
+            keep_prob = self.cast(self.keep_prob, mstype.float16)
+        output = self.dropout_gen_mask(shape, keep_prob)
+        return self.dropout_do_mask(x, output, keep_prob)
+
+    def extend_repr(self):
+        return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
+
 # model_parallel test
 def test_two_matmul_dropout():
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2, strategy3):
             super().__init__()
             self.matmul1 = P.MatMul().shard(strategy1)
-            self.dropout = nn.Dropout()
+            self.dropout = Dropout()
             self.dropout.dropout_do_mask.shard(strategy2)
             self.dropout.dropout_gen_mask.shard(strategy2)
             self.matmul2 = P.MatMul().shard(strategy3)
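The parallel test keeps a private copy of the old GenMask/DoMask based Dropout because its shard strategies attach to those two primitives, which the simplified nn.Dropout no longer exposes. For reference, the decomposition it preserves looks roughly like this (an illustrative sketch; DropoutGenMask and DropoutDoMask need a backend that implements them):

```python
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops.operations as P
from mindspore import Tensor

x = Tensor(np.ones((128, 64)).astype(np.float32))
keep_prob = Tensor(0.9, mstype.float32)

gen_mask = P.DropoutGenMask(Seed0=11, Seed1=3)  # generates a bit-packed random mask for a given shape
do_mask = P.DropoutDoMask()                     # applies the mask to x and rescales by 1 / keep_prob

mask = gen_mask(P.Shape()(x), keep_prob)
out = do_mask(x, mask, keep_prob)
```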
diff --git a/tests/ut/python/parallel/test_one_dev.py b/tests/ut/python/parallel/test_one_dev.py
index 764470531a..e6bec95daf 100644
--- a/tests/ut/python/parallel/test_one_dev.py
+++ b/tests/ut/python/parallel/test_one_dev.py
@@ -19,11 +19,14 @@ import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
+import mindspore.common.dtype as mstype
 from mindspore.common.api import _executor
 from mindspore.common.parameter import Parameter
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.nn.loss.loss import _Loss
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import _selected_ops
 from mindspore.parallel._utils import _reset_op_id
 from mindspore.train import Model
 from mindspore.context import ParallelMode
@@ -66,6 +69,33 @@ class AllToAllNet(nn.Cell):
         return x
 
 
+class SoftmaxCrossEntropyWithLogits(_Loss):
+    def __init__(self,
+                 sparse=False,
+                 reduction='none'):
+        super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
+        self.sparse = sparse
+        self.reduction = reduction
+        self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
+        self.one_hot = P.OneHot()
+        self.on_value = Tensor(1.0, mstype.float32)
+        self.off_value = Tensor(0., mstype.float32)
+        self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]
+
+        if self.is_cpugpu:
+            self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()
+
+    def construct(self, logits, labels):
+        if self.is_cpugpu and self.sparse and self.reduction == 'mean':
+            x = self.sparse_softmax_cross_entropy(logits, labels)
+            return x
+
+        if self.sparse:
+            labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
+        x = self.softmax_cross_entropy(logits, labels)[0]
+        return self.get_loss(x)
+
+
 def all_to_all_net():
     return AllToAllNet()
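Finally, test_one_dev.py inlines the softmax cross-entropy loss it previously imported from mindspore.nn.loss, presumably so the test no longer depends on how the public layer chooses between the fused sparse kernel (CPU/GPU with reduction='mean') and the OneHot plus SoftmaxCrossEntropyWithLogits pair. A usage sketch, assuming the public nn.SoftmaxCrossEntropyWithLogits keeps the same sparse/reduction interface as the helper above:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
logits = Tensor(np.random.randn(32, 10).astype(np.float32))
labels = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))
loss = loss_fn(logits, labels)  # scalar mean cross-entropy over the batch
print(loss)
```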