From: @yuchaojie Reviewed-by: Signed-off-by: tags/v1.1.0
| @@ -1085,7 +1085,7 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode) | |||
| {kTensorAddOpName, "ElemWise"}, | |||
| {kConv2DBackpropInputOpName, "Conv2d_backprop_input"}, | |||
| {kConv2DBackpropFilterOpName, "Conv2d_backprop_filter"}, | |||
| {kDepthwiseConv2dNativeName, "DepthwiseConvolution"}, | |||
| {kDepthwiseConv2dNativeOpName, "DepthwiseConvolution"}, | |||
| {kAddNOpName, "ElemWise"}, | |||
| {kReluGradV2OpName, "ElemWise"}, | |||
| {kRealDivOpName, "ElemWise"}}; | |||
| @@ -0,0 +1,320 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kConv2DBackpropInputNum = 4; | |||
| constexpr size_t kConv2DAxisNum = 4; | |||
| constexpr auto kAttrOffsetA = "offset_a"; | |||
| constexpr auto kAttrPadList = "pad_list"; | |||
| constexpr auto kAttrPads = "pads"; | |||
| constexpr auto kAttrMode = "mode"; | |||
| constexpr auto kAttrChannelMultiplier = "channel_multiplier"; | |||
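| // A Conv2D-family node needs rewriting only when it is a group convolution: group > 1, NCHW layout, | |||
| // and group == in_channel == out_channel (i.e. effectively a depthwise convolution). | |||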
| bool NeedUpdate(const CNodePtr &conv2d, const std::vector<size_t> &in_shape, const std::vector<size_t> &out_shape) { | |||
| MS_EXCEPTION_IF_NULL(conv2d); | |||
| auto group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(conv2d, kAttrGroup)); | |||
| if (group == 1) { | |||
| return false; | |||
| } | |||
| auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrDataFormat); | |||
| if (data_format != "NCHW") { | |||
| MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format; | |||
| } | |||
| if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size() | |||
| << "output axis num: " << out_shape.size(); | |||
| } | |||
| auto in_channel = in_shape[1]; | |||
| auto out_channel = out_shape[1]; | |||
| if (group != in_channel || group != out_channel) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's attr group should be equal to in_channel and out_channel when group > 1, but got " | |||
| << "group: " << group << " in_channel: " << in_channel << " out_channel: " << out_channel; | |||
| } | |||
| return true; | |||
| } | |||
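| // Packs the permutation into a tuple ValueNode so it can be fed to Transpose as its second ("perm") input. | |||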
| ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &perm) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<ValuePtr> axis_values{}; | |||
| abstract::AbstractBasePtrList abs{}; | |||
| for (const auto &axis : perm) { | |||
| axis_values.push_back(MakeValue(axis)); | |||
| abs.push_back(std::make_shared<abstract::AbstractScalar>(axis)); | |||
| } | |||
| auto perm_value_tuple = std::make_shared<ValueTuple>(axis_values); | |||
| MS_EXCEPTION_IF_NULL(perm_value_tuple); | |||
| auto abstract = std::make_shared<abstract::AbstractTuple>(abs); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| auto perm_value = kernel_graph->NewValueNode(abstract, perm_value_tuple); | |||
| MS_EXCEPTION_IF_NULL(perm_value); | |||
| kernel_graph->AddValueNodeToGraph(perm_value); | |||
| return perm_value; | |||
| } | |||
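| // Creates Transpose(input_node, perm = {1, 0, 2, 3}), i.e. swaps axes 0 and 1. When need_trans_output is true, | |||
| // the inferred output shape is the swapped input shape; otherwise the node reuses conv2d's abstract. | |||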
| CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, const AnfNodePtr &input_node, | |||
| bool need_trans_output) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(conv2d); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| auto perm = std::vector<int64_t>{1, 0, 2, 3}; | |||
| std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node, | |||
| CreatePermValueNode(graph, perm)}; | |||
| auto transpose = graph->NewCNode(transpose_inputs); | |||
| MS_EXCEPTION_IF_NULL(transpose); | |||
| transpose->set_scope(conv2d->scope()); | |||
| if (need_trans_output) { | |||
| auto types = {AnfAlgo::GetOutputInferDataType(input_node, 0)}; | |||
| auto out_shape = AnfAlgo::GetOutputInferShape(input_node, 0); | |||
| if (out_shape.size() != kConv2DAxisNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's output axis number should be " << kConv2DAxisNum << ", but got " | |||
| << out_shape.size(); | |||
| } | |||
| std::swap(out_shape[0], out_shape[1]); | |||
| auto shapes = {out_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, transpose.get()); | |||
| } else { | |||
| transpose->set_abstract(conv2d->abstract()); | |||
| } | |||
| auto input_names = std::vector<std::string>{"x", "perm"}; | |||
| auto output_names = std::vector<std::string>{"output"}; | |||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose); | |||
| AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose); | |||
| return transpose; | |||
| } | |||
| CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(conv2d); | |||
| if (conv2d->inputs().size() != kConvInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " | |||
| << conv2d->inputs().size() - 1; | |||
| } | |||
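| // DepthwiseConv2dNative takes the original data input and the transposed weight produced by CreateTranspose. | |||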
| std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)), | |||
| conv2d->input(1), transpose}; | |||
| auto depth_conv = graph->NewCNode(depth_conv_inputs); | |||
| MS_EXCEPTION_IF_NULL(depth_conv); | |||
| depth_conv->set_abstract(conv2d->abstract()); | |||
| depth_conv->set_scope(conv2d->scope()); | |||
| return depth_conv; | |||
| } | |||
| CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin, | |||
| const CNodePtr &transpose) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(conv2d_backin); | |||
| if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " | |||
| << conv2d_backin->inputs().size() - 1; | |||
| } | |||
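| // Reorder inputs: Conv2DBackpropInput takes {dout, filter, input_size}, while DepthwiseConv2dNativeBackpropInput | |||
| // expects {input_size, filter, dout}; the filter here is the transposed weight. | |||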
| std::vector<AnfNodePtr> depth_conv_backin_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3), | |||
| transpose, conv2d_backin->input(1)}; | |||
| auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs); | |||
| MS_EXCEPTION_IF_NULL(depth_conv_backin); | |||
| depth_conv_backin->set_abstract(conv2d_backin->abstract()); | |||
| depth_conv_backin->set_scope(conv2d_backin->scope()); | |||
| return depth_conv_backin; | |||
| } | |||
| CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(conv2d_backfil); | |||
| if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " | |||
| << conv2d_backfil->inputs().size() - 1; | |||
| } | |||
| auto filter_size_node = conv2d_backfil->input(3); | |||
| MS_EXCEPTION_IF_NULL(filter_size_node); | |||
| auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(filter_size_vnode); | |||
| auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value()); | |||
| // Swap axis 0 and 1 of the filter shape, but don't swap twice, since some nodes share the same filter_size | |||
| // ValueNode when their filter_size values are equal. | |||
| if (filter_size[0] != 1) { | |||
| std::swap(filter_size[0], filter_size[1]); | |||
| conv2d_backfil->input(3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size)); | |||
| } | |||
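| // Reorder inputs: Conv2DBackpropFilter takes {dout, input, filter_size}, while DepthwiseConv2dNativeBackpropFilter | |||
| // expects {input, filter_size, dout}. | |||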
| std::vector<AnfNodePtr> depth_conv_backfil_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)), conv2d_backfil->input(2), | |||
| conv2d_backfil->input(3), conv2d_backfil->input(1)}; | |||
| auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs); | |||
| MS_EXCEPTION_IF_NULL(depth_conv_backfil); | |||
| depth_conv_backfil->set_scope(conv2d_backfil->scope()); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)}; | |||
| std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0); | |||
| if (out_shape.size() != kConv2DAxisNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got " | |||
| << out_shape.size(); | |||
| } | |||
| std::swap(out_shape[0], out_shape[1]); | |||
| auto shapes = {out_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get()); | |||
| return depth_conv_backfil; | |||
| } | |||
| void SetCommonAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) { | |||
| AnfAlgo::CopyNodeAttr(kAttrKernelSize, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrDilation, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrDataFormat, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrPadList, kAttrPads, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrPadMode, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrPad, conv2d, depth_conv); | |||
| AnfAlgo::SetNodeAttr(kAttrMode, MakeValue(3), depth_conv); | |||
| AnfAlgo::SetNodeAttr(kAttrChannelMultiplier, MakeValue(1), depth_conv); | |||
| } | |||
| void SetConv2DAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) { | |||
| SetCommonAttrs(conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrInputNames, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrStride, conv2d, depth_conv); | |||
| AnfAlgo::CopyNodeAttr(kAttrOffsetA, conv2d, depth_conv); | |||
| } | |||
| void SetConv2DBackpropInputAttrs(const CNodePtr &conv2d_backin, const CNodePtr &depth_conv_backin) { | |||
| SetCommonAttrs(conv2d_backin, depth_conv_backin); | |||
| auto input_names = std::vector<std::string>{"input_size", "filter", "dout"}; | |||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backin); | |||
| auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(conv2d_backin, kAttrStride); | |||
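| // Expand a 2-D stride (h, w) to the 4-D form (1, 1, h, w) before setting it on the depthwise backprop node. | |||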
| if (stride.size() == 2) { | |||
| stride.insert(stride.begin(), 2, 1); | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backin); | |||
| } | |||
| void SetConv2DBackpropFilterAttrs(const CNodePtr &conv2d_backfil, const CNodePtr &depth_conv_backfil) { | |||
| SetCommonAttrs(conv2d_backfil, depth_conv_backfil); | |||
| auto input_names = std::vector<std::string>{"input", "filter_size", "dout"}; | |||
| AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backfil); | |||
| auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(conv2d_backfil, kAttrStride); | |||
| if (stride.size() == 2) { | |||
| stride.insert(stride.begin(), 2, 1); | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backfil); | |||
| } | |||
| } // namespace | |||
| const BaseRef Conv2DUnifyMindIR::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr W = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimConv2D, X, W}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto conv2d = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv2d); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d, 0); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(conv2d, 0); | |||
| if (!NeedUpdate(conv2d, input_shape, output_shape)) { | |||
| return nullptr; | |||
| } | |||
| if (conv2d->inputs().size() != kConvInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " | |||
| << conv2d->inputs().size() - 1; | |||
| } | |||
| auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true); | |||
| auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose); | |||
| SetConv2DAttrs(conv2d, depth_conv); | |||
| return depth_conv; | |||
| } | |||
| const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const { | |||
| VarPtr dout = std::make_shared<Var>(); | |||
| VarPtr weight = std::make_shared<Var>(); | |||
| VarPtr input_size = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto conv2d_backin = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv2d_backin); | |||
| auto input_shape = AnfAlgo::GetOutputInferShape(conv2d_backin, 0); | |||
| auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backin, 0); | |||
| if (!NeedUpdate(conv2d_backin, input_shape, output_shape)) { | |||
| return nullptr; | |||
| } | |||
| if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) { | |||
| MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got " | |||
| << conv2d_backin->inputs().size() - 1; | |||
| } | |||
| auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true); | |||
| auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose); | |||
| SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin); | |||
| return depth_conv_backin; | |||
| } | |||
| const BaseRef Conv2DBackpropFilterUnifyMindIR::DefinePattern() const { | |||
| VarPtr dout = std::make_shared<Var>(); | |||
| VarPtr input = std::make_shared<Var>(); | |||
| VarPtr filter_size = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimConv2DBackpropFilter, dout, input, filter_size}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr Conv2DBackpropFilterUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto conv2d_backfil = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(conv2d_backfil); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 1); | |||
| auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 0); | |||
| if (!NeedUpdate(conv2d_backfil, input_shape, output_shape)) { | |||
| return nullptr; | |||
| } | |||
| auto depth_conv_backfil = CreateDepthwiseConv2DBackpropFilter(graph, conv2d_backfil); | |||
| SetConv2DBackpropFilterAttrs(conv2d_backfil, depth_conv_backfil); | |||
| auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| (void)manager->Replace(conv2d_backfil, transpose); | |||
| return transpose; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class Conv2DUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit Conv2DUnifyMindIR(bool multigraph = true) : PatternProcessPass("conv2d_unify_mindir", multigraph) {} | |||
| ~Conv2DUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit Conv2DBackpropInputUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("conv2d_backprop_input_unify_mindir", multigraph) {} | |||
| ~Conv2DBackpropInputUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit Conv2DBackpropFilterUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("conv2d_backprop_filter_unify_mindir", multigraph) {} | |||
| ~Conv2DBackpropFilterUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_ | |||
| @@ -0,0 +1,269 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <numeric> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/log_adapter.h" | |||
| constexpr auto kKeepProb = "keep_prob"; | |||
| constexpr auto kSeed0 = "Seed0"; | |||
| constexpr auto kSeed1 = "Seed1"; | |||
| constexpr auto kUint8BitSize = 8; | |||
| namespace mindspore::opt { | |||
| constexpr size_t kFloat16Len = 2; // size of float16 | |||
| namespace { | |||
| AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) { | |||
| MS_LOG(INFO) << "GetDropoutNodeInfo start."; | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(keep_prob); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) || | |||
| !AnfAlgo::HasNodeAttr(kSeed1, cnode)) { | |||
| MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1."; | |||
| } | |||
| *keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb); | |||
| MS_LOG(INFO) << "keep_prob: " << *keep_prob; | |||
| // Return Dropout's input, which may be a parameter tensor or the output of a preceding cnode. | |||
| return cnode->input(1); | |||
| } | |||
| ValueNodePtr CreateKeepProbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) { | |||
| MS_LOG(INFO) << "CreateKeepProbValueNode start."; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<int64_t> keep_prob_shape = {}; | |||
| ShapeVector shape = {}; | |||
| auto keep_prob_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), keep_prob_shape); | |||
| MS_EXCEPTION_IF_NULL(keep_prob_tensor); | |||
| auto data_ptr = keep_prob_tensor->data_c(); | |||
| MS_EXCEPTION_IF_NULL(data_ptr); | |||
| // keep_prob's datatype is the same as the input data's | |||
| if (dtype->type_id() == kNumberTypeFloat16) { | |||
| float16 half_data = float16(keep_prob); | |||
| auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len); | |||
| if (ret_code != 0) { | |||
| MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; | |||
| } | |||
| } else { | |||
| auto *val = reinterpret_cast<float *>(data_ptr); | |||
| *val = keep_prob; | |||
| } | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(dtype, shape); | |||
| auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor); | |||
| MS_EXCEPTION_IF_NULL(keep_prob_value); | |||
| kernel_graph->AddValueNodeToGraph(keep_prob_value); | |||
| return keep_prob_value; | |||
| } | |||
| std::vector<int64_t> GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) { | |||
| MS_LOG(INFO) << "GetInputShape start."; | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(dropout_input); | |||
| std::vector<int64_t> shapes; | |||
| if (dropout_input->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "Dropout input from parameter node."; | |||
| // This branch is typically taken in single-op/unit-test graphs where the Dropout input is a Parameter. | |||
| auto dropout_input_value = dropout_input->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(dropout_input_value); | |||
| MS_EXCEPTION_IF_NULL(dropout_input_value->Shape()); | |||
| auto shape = dropout_input_value->Shape()->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| return shape->shape(); | |||
| } else if (dropout_input->isa<CNode>()) { | |||
| MS_LOG(INFO) << "Dropout input from cnode."; | |||
| auto dropout_input_node = dropout_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(dropout_input_node); | |||
| auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); | |||
| std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); | |||
| return shapes; | |||
| } else { | |||
| MS_LOG(ERROR) << "Dropout input is not parameter or cnode."; | |||
| return {}; | |||
| } | |||
| } | |||
| ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape) { | |||
| MS_LOG(INFO) << "CreateShapeValueNode start."; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<ValuePtr> dim_values{}; | |||
| abstract::AbstractBasePtrList abs{}; | |||
| for (const auto &dim : shape) { | |||
| dim_values.push_back(MakeValue(dim)); | |||
| abs.push_back(std::make_shared<abstract::AbstractScalar>(dim)); | |||
| } | |||
| auto shape_value_tuple = std::make_shared<ValueTuple>(dim_values); | |||
| MS_EXCEPTION_IF_NULL(shape_value_tuple); | |||
| auto abstract = std::make_shared<abstract::AbstractTuple>(abs); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple); | |||
| MS_EXCEPTION_IF_NULL(shape_value); | |||
| kernel_graph->AddValueNodeToGraph(shape_value); | |||
| return shape_value; | |||
| } | |||
| } // namespace | |||
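| // Matches TupleGetItem(Dropout(x), i) and replaces it with a DropoutGenMask + DropoutDoMask pair. | |||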
| const BaseRef DropoutUnifyMindIR::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Y = std::make_shared<Var>(); | |||
| auto prim = std::make_shared<Primitive>(kDropoutOpName); | |||
| auto ref = VectorRef({prim, X}); | |||
| return VectorRef({prim::kPrimTupleGetItem, ref, Y}); | |||
| } | |||
| const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto tuple_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_cnode); | |||
| auto dropout_node = tuple_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(dropout_node); | |||
| float keep_prob = 0; | |||
| auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); | |||
| auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32; | |||
| auto keep_prob_value = CreateKeepProbValueNode(func_graph, keep_prob, dropout_dtype); | |||
| auto shape = GetInputShape(dropout_node, dropout_input); | |||
| auto shape_value = CreateShapeValueNode(func_graph, shape); | |||
| // CreateDropoutGenMask | |||
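| // DropoutGenMask emits one mask bit per input element, so the uint8 output length is the element count / 8. | |||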
| auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); | |||
| output_size = output_size / kUint8BitSize; | |||
| MS_LOG(INFO) << "Output_size: " << output_size; | |||
| std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), | |||
| shape_value, keep_prob_value}; | |||
| CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); | |||
| MS_EXCEPTION_IF_NULL(dropout_gen_mask); | |||
| AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); | |||
| ShapeVector dropout_gen_mask_output = {output_size}; | |||
| auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output); | |||
| MS_EXCEPTION_IF_NULL(gen_mask_abstract); | |||
| dropout_gen_mask->set_abstract(gen_mask_abstract); | |||
| dropout_gen_mask->set_scope(node->scope()); | |||
| // CreateDropoutDoMask | |||
| std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), | |||
| dropout_input, dropout_gen_mask, keep_prob_value}; | |||
| auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); | |||
| MS_EXCEPTION_IF_NULL(dropout_do_mask); | |||
| ShapeVector dropout_do_mask_output = shape; | |||
| auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask_output); | |||
| dropout_do_mask->set_abstract(do_mask_abstract); | |||
| dropout_do_mask->set_scope(node->scope()); | |||
| return dropout_do_mask; | |||
| } | |||
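| // Matches DropoutGrad(dy, TupleGetItem(Dropout(x), i)) so the forward and backward DropoutDoMask nodes can share | |||
| // a single DropoutGenMask. | |||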
| const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Y = std::make_shared<Var>(); | |||
| MS_EXCEPTION_IF_NULL(X); | |||
| MS_EXCEPTION_IF_NULL(Y); | |||
| auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName); | |||
| auto tuple_getitem_prim = prim::kPrimTupleGetItem; | |||
| auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName); | |||
| MS_EXCEPTION_IF_NULL(dropout_prim); | |||
| MS_EXCEPTION_IF_NULL(dropout_grad_prim); | |||
| auto ref0 = VectorRef({dropout_prim, X}); | |||
| auto ref1 = VectorRef({tuple_getitem_prim, ref0, Y}); | |||
| return VectorRef({dropout_grad_prim, grad_input_, ref1}); | |||
| } | |||
| const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto dropout_grad = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(dropout_grad); | |||
| auto tuple_getitem = dropout_grad->input(2); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); | |||
| auto dropout_node = tuple_getitem_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(dropout_node); | |||
| float keep_prob = 0; | |||
| auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); | |||
| auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32; | |||
| auto keep_prob_value = CreateKeepProbValueNode(func_graph, keep_prob, dropout_dtype); | |||
| auto shape = GetInputShape(dropout_node, dropout_input); | |||
| auto shape_value = CreateShapeValueNode(func_graph, shape); | |||
| // CreateDropoutGenMask | |||
| auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); | |||
| output_size = output_size / kUint8BitSize; | |||
| MS_LOG(INFO) << "Output_size: " << output_size; | |||
| std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), | |||
| shape_value, keep_prob_value}; | |||
| CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); | |||
| MS_EXCEPTION_IF_NULL(dropout_gen_mask); | |||
| AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); | |||
| ShapeVector dropout_gen_mask_output = {output_size}; | |||
| auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output); | |||
| MS_EXCEPTION_IF_NULL(gen_mask_abstract); | |||
| dropout_gen_mask->set_abstract(gen_mask_abstract); | |||
| dropout_gen_mask->set_scope(dropout_node->scope()); | |||
| // CreateDropoutDoMask-forward | |||
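| // If Dropout's first output (the forward result) is used, replace that TupleGetItem with a DropoutDoMask that | |||
| // reuses the mask generated above. | |||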
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(dropout_node); | |||
| if (iter != node_users.end()) { | |||
| for (auto &node_index : iter->second) { | |||
| // Dropout has two outputs, so its users are TupleGetItem nodes. | |||
| auto tuple_getitem_cnode2 = node_index.first->cast<CNodePtr>(); | |||
| // Check whether Dropout's first output, the one consumed by the forward pass, is actually used. | |||
| auto getitem_index = GetValue<int64_t>(tuple_getitem_cnode2->input(2)->cast<ValueNodePtr>()->value()); | |||
| if (getitem_index == 0) { | |||
| std::vector<AnfNodePtr> dropout_do_mask1_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), | |||
| dropout_input, dropout_gen_mask, keep_prob_value}; | |||
| auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); | |||
| MS_EXCEPTION_IF_NULL(dropout_do_mask1); | |||
| ShapeVector dropout_do_mask1_output = shape; | |||
| auto do_mask_abstract1 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask1_output); | |||
| dropout_do_mask1->set_abstract(do_mask_abstract1); | |||
| dropout_do_mask1->set_scope(dropout_node->scope()); | |||
| (void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| // CreateDropoutDoMask-backward | |||
| if (equiv->find(grad_input_) == equiv->end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern."; | |||
| } | |||
| auto grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]); | |||
| std::vector<AnfNodePtr> dropout_do_mask2_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), | |||
| grad_input, dropout_gen_mask, keep_prob_value}; | |||
| auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs); | |||
| MS_EXCEPTION_IF_NULL(dropout_do_mask2); | |||
| ShapeVector dropout_do_mask2_output = shape; | |||
| auto do_mask_abstract2 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask2_output); | |||
| dropout_do_mask2->set_abstract(do_mask_abstract2); | |||
| dropout_do_mask2->set_scope(node->scope()); | |||
| return dropout_do_mask2; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class DropoutUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {} | |||
| ~DropoutUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class DropoutGradUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit DropoutGradUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("dropout_grad_unify_mindir", multigraph) { | |||
| grad_input_ = std::make_shared<Var>(); | |||
| } | |||
| ~DropoutGradUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| VarPtr grad_input_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ | |||
| @@ -0,0 +1,148 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kMaxPoolInputNum = 2; | |||
| constexpr size_t kMaxPoolAttrAxisNum = 4; | |||
| constexpr size_t kMaxPoolGradInputNum = 4; | |||
| constexpr size_t kMaxPoolWithArgmaxOutputNum = 2; | |||
| CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) { | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad); | |||
| if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " | |||
| << maxpool_grad->inputs().size() - 1; | |||
| } | |||
| auto maxpool_anf = maxpool_grad->input(2); | |||
| MS_EXCEPTION_IF_NULL(maxpool_anf); | |||
| return maxpool_anf->cast<CNodePtr>(); | |||
| } | |||
| CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(maxpool); | |||
| if (maxpool->inputs().size() != kMaxPoolInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPool's input number should be " << kMaxPoolInputNum - 1 << ", but got " | |||
| << maxpool->inputs().size() - 1; | |||
| } | |||
| std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)), | |||
| maxpool->input(1)}; | |||
| auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs); | |||
| MS_EXCEPTION_IF_NULL(maxpool_argmax); | |||
| maxpool_argmax->set_scope(maxpool->scope()); | |||
| // MaxPoolWithArgmax's second output is argmax, whose datatype is uint16 and whose shape matches the first output | |||
| TypeId argmax_dtype = kNumberTypeUInt16; | |||
| auto types = {AnfAlgo::GetOutputInferDataType(maxpool, 0), argmax_dtype}; | |||
| auto out_shape = AnfAlgo::GetOutputInferShape(maxpool, 0); | |||
| auto shapes = {out_shape, out_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_argmax.get()); | |||
| return maxpool_argmax; | |||
| } | |||
| CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad, | |||
| const std::vector<AnfNodePtr> &maxpool_argmax_outputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad); | |||
| if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got " | |||
| << maxpool_grad->inputs().size() - 1; | |||
| } | |||
| // MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are | |||
| // {input, grad_input, argmax_output} | |||
| std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1), | |||
| maxpool_grad->input(3), maxpool_argmax_outputs[1]}; | |||
| auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad_argmax); | |||
| maxpool_grad_argmax->set_scope(maxpool_grad->scope()); | |||
| maxpool_grad_argmax->set_abstract(maxpool_grad->abstract()); | |||
| return maxpool_grad_argmax; | |||
| } | |||
| void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax, | |||
| const CNodePtr &maxpool_grad_argmax) { | |||
| auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrStrides); | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrKsize); | |||
| if (strides.size() != kMaxPoolAttrAxisNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPool's attr strides has wrong axis number, should be " << kMaxPoolAttrAxisNum | |||
| << ", but got " << strides.size(); | |||
| } | |||
| if (ksize.size() != kMaxPoolAttrAxisNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPool's attr ksize has wrong axis number, should be " << kMaxPoolAttrAxisNum << ", but got " | |||
| << ksize.size(); | |||
| } | |||
| // note that strides and ksize change from (1, 1, x, y) to (1, x, y, 1) | |||
| for (size_t i = 1; i <= 2; ++i) { | |||
| strides[i] = strides[i + 1]; | |||
| ksize[i] = ksize[i + 1]; | |||
| } | |||
| strides[3] = 1; | |||
| ksize[3] = 1; | |||
| AnfAlgo::CopyNodeAttrs(maxpool, maxpool_argmax); | |||
| AnfAlgo::CopyNodeAttrs(maxpool_grad, maxpool_grad_argmax); | |||
| AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_argmax); | |||
| AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_grad_argmax); | |||
| AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_argmax); | |||
| AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_grad_argmax); | |||
| } | |||
| } // namespace | |||
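| // Matches MaxPoolGrad(x, MaxPool(x), dy) and rewrites the pair into MaxPoolWithArgmax / MaxPoolGradWithArgmax, | |||
| // feeding the argmax output into the gradient op. | |||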
| const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Y = std::make_shared<Var>(); | |||
| VectorRef maxpool({prim::kPrimMaxPool, X}); | |||
| VectorRef pattern({prim::kPrimMaxPoolGrad, X, maxpool, Y}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr MaxPool2MaxPoolWithArgmax::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto maxpool_grad = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad); | |||
| auto maxpool = GetMaxPool(maxpool_grad); | |||
| MS_EXCEPTION_IF_NULL(maxpool); | |||
| auto maxpool_argmax = CreateMaxPoolWithArgmax(graph, maxpool); | |||
| std::vector<AnfNodePtr> maxpool_argmax_outputs; | |||
| CreateMultipleOutputsOfAnfNode(graph, maxpool_argmax, kMaxPoolWithArgmaxOutputNum, &maxpool_argmax_outputs); | |||
| auto maxpool_grad_argmax = CreateMaxPoolGradWithArgmax(graph, maxpool_grad, maxpool_argmax_outputs); | |||
| SetNodeAttrs(maxpool, maxpool_grad, maxpool_argmax, maxpool_grad_argmax); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| (void)manager->Replace(maxpool, maxpool_argmax_outputs[0]); | |||
| return maxpool_grad_argmax; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class MaxPool2MaxPoolWithArgmax : public PatternProcessPass { | |||
| public: | |||
| explicit MaxPool2MaxPoolWithArgmax(bool multigraph = true) | |||
| : PatternProcessPass("maxpool_to_maxpool_with_argmax", multigraph) {} | |||
| ~MaxPool2MaxPoolWithArgmax() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_ | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "base/core_ops.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4; | |||
| bool IsC(const BaseRef &n) { | |||
| if (utils::isa<AnfNodePtr>(n)) { | |||
| AnfNodePtr in = utils::cast<AnfNodePtr>(n); | |||
| MS_EXCEPTION_IF_NULL(in); | |||
| return in->isa<ValueNode>(); | |||
| } | |||
| return false; | |||
| } | |||
| CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) { | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax); | |||
| if (maxpool_grad_with_argmax->inputs().size() != kMaxPoolGradWithArgmaxInputNum) { | |||
| MS_LOG(EXCEPTION) << "MaxPoolGradWithArgmax has wrong input size."; | |||
| } | |||
| auto tuple_getitem0_anf = maxpool_grad_with_argmax->input(3); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); | |||
| return tuple_getitem0_anf->cast<CNodePtr>(); | |||
| } | |||
| } // namespace | |||
| const BaseRef MaxPoolWithArgmaxUnifyMindIR::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VectorRef pattern({prim::kPrimMaxPoolWithArgmax, X}); | |||
| return pattern; | |||
| } | |||
| const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto maxpool_with_argmax = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(maxpool_with_argmax); | |||
| TypeId argmax_dtype = kNumberTypeUInt16; | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_with_argmax, kAttrKsize); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0); | |||
| auto argmax_shape = output_shape; | |||
| if (argmax_shape.size() != 4) { | |||
| MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; | |||
| } | |||
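| // Adjust the argmax output shape: dim[2] becomes ksize_h * ksize_w and dim[3] becomes | |||
| // (out_h * out_w + 15) / 16 + 1 uint16 words, matching the layout used by the Ascend kernel. | |||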
| argmax_shape[2] = ksize[1] * ksize[2]; | |||
| argmax_shape[3] = (output_shape[2] * output_shape[3] + 15) / 16 + 1; | |||
| auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype}; | |||
| auto shapes = {output_shape, argmax_shape}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get()); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| return maxpool_with_argmax; | |||
| } | |||
| const BaseRef MaxPoolGradWithArgmaxUnifyMindIR::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Y = std::make_shared<Var>(); | |||
| VarPtr index0 = std::make_shared<CondVar>(IsC); | |||
| VectorRef maxpool_with_argmax({prim::kPrimMaxPoolWithArgmax, X}); | |||
| VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, maxpool_with_argmax, index0}); | |||
| VectorRef maxpool_grad_with_argmax({prim::kPrimMaxPoolGradWithArgmax, X, Y, tuple_getitem0}); | |||
| return maxpool_grad_with_argmax; | |||
| } | |||
| const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &equiv) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto maxpool_grad_with_argmax = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax); | |||
| auto tuple_getitem0_anf = GetMaxPoolWithArgmax(maxpool_grad_with_argmax); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); | |||
| TypeId argmax_dtype = kNumberTypeUInt16; | |||
| auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_grad_with_argmax, kAttrKsize); | |||
| auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0); | |||
| if (argmax_shape.size() != 4) { | |||
| MS_LOG(DEBUG) << "argmax's infer shape size not equal 4"; | |||
| } | |||
| argmax_shape[3] = (argmax_shape[2] * argmax_shape[3] + 15) / 16 + 1; | |||
| argmax_shape[2] = ksize[1] * ksize[2]; | |||
| AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get()); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| return maxpool_grad_with_argmax; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ | |||
| #include <memory> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class MaxPoolWithArgmaxUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit MaxPoolWithArgmaxUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("maxpool_with_argmax_unify_mindir", multigraph) {} | |||
| ~MaxPoolWithArgmaxUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| class MaxPoolGradWithArgmaxUnifyMindIR : public PatternProcessPass { | |||
| public: | |||
| explicit MaxPoolGradWithArgmaxUnifyMindIR(bool multigraph = true) | |||
| : PatternProcessPass("maxpool_grad_with_argmax_unify_mindir", multigraph) {} | |||
| ~MaxPoolGradWithArgmaxUnifyMindIR() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ | |||
| @@ -101,6 +101,10 @@ constexpr size_t kFusedMulApplyMomentumOutputNum = 2; | |||
| constexpr size_t kSplitInputNum = 2; | |||
| constexpr size_t kGatherV2DynInputNum = 3; | |||
| constexpr size_t kUnsortedSegmentSumInputNum = 2; | |||
| constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2; | |||
| constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3; | |||
| constexpr size_t kOneHotOutputNum = 1; | |||
| constexpr size_t kOneHotInputNum = 5; | |||
| enum FusedBatchNormInput { | |||
| kX = 1, | |||
| @@ -32,6 +32,10 @@ | |||
| #include "runtime/device/ascend/ascend_kernel_runtime.h" | |||
| #include "backend/optimizer/ascend/ascend_backend_optimization.h" | |||
| #include "backend/optimizer/common/common_backend_optimization.h" | |||
| #include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h" | |||
| #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" | |||
| #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" | |||
| #include "runtime/device/kernel_adjust.h" | |||
| #include "runtime/device/ascend/ascend_stream_assign.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| @@ -423,6 +427,35 @@ void AscendSession::Init(uint32_t device_id) { | |||
| runtime_instance->CreateContext(); | |||
| } | |||
| void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| if (save_graphs) { | |||
| std::string file_name = "hwopt_d_before_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir"; | |||
| DumpIR(file_name, graph); | |||
| DumpIRProto(graph, "before_unify_mindir_hwopt_" + std::to_string(graph->graph_id())); | |||
| } | |||
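| // Lower framework-level ops (Dropout, MaxPool/MaxPoolGrad, group Conv2D and its backprops) to forms supported by | |||
| // Ascend kernels before the backend optimizations run. | |||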
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm"); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>()); | |||
| unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>()); | |||
| optimizer->AddPassManager(unify_mindir_pm); | |||
| (void)optimizer->Optimize(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| if (save_graphs) { | |||
| std::string file_name = "hwopt_d_after_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir"; | |||
| DumpIR(file_name, graph); | |||
| } | |||
| } | |||
| GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| MS_LOG(INFO) << "Start"; | |||
| // construct graph, if successfully, graph_sum_ + 1 | |||
| @@ -438,6 +471,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); | |||
| // Update Graph Dynamic Shape Attr | |||
| UpdateAllGraphDynamicShapeAttr(all_graphs); | |||
| for (const auto &graph : all_graphs) { | |||
| UnifyMindIR(graph); | |||
| } | |||
| BackendOptimization(all_graphs); | |||
| // empty graph dont entry to backend | |||
| if (root_graph->execution_order().empty()) { | |||
| @@ -1219,7 +1255,6 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| opt::AscendBackendIRFusionOptimization(graph); | |||
| graph->SetExecOrderByDefault(); | |||
| @@ -51,6 +51,7 @@ class AscendSession : public SessionBasic { | |||
| void SyncStream() override; | |||
| protected: | |||
| void UnifyMindIR(const KernelGraphPtr &graph) override; | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override; | |||
| @@ -32,6 +32,7 @@ class CPUSession : public SessionBasic { | |||
| void Init(uint32_t device_id) override { InitExecutor(kCPUDevice, device_id); } | |||
| protected: | |||
| void UnifyMindIR(const KernelGraphPtr &graph) override { return; } | |||
| void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override; | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| @@ -35,6 +35,7 @@ class GPUSession : public SessionBasic { | |||
| void SyncStream() override; | |||
| protected: | |||
| void UnifyMindIR(const KernelGraphPtr &graph) override { return; } | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| @@ -943,6 +943,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||
| // Update Graph Dynamic Shape Attr | |||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | |||
| UnifyMindIR(graph); | |||
| opt::BackendCommonOptimization(graph); | |||
| graph->SetInputNodes(); | |||
| auto input_nodes = graph->input_nodes(); | |||
| @@ -1610,6 +1611,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf | |||
| // set output | |||
| CreateOutputNode(cnode, graph); | |||
| graph->SetInputNodes(); | |||
| UnifyMindIR(graph); | |||
| return graph; | |||
| } | |||
| @@ -147,6 +147,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, | |||
| VectorRef *outputs, | |||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node); | |||
| virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0; | |||
| virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; } | |||
| virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) { | |||
| @@ -163,7 +163,9 @@ constexpr auto kBatchToSpaceOpName = "BatchToSpace"; | |||
| constexpr auto kPadOpName = "Pad"; | |||
| constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput"; | |||
| constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter"; | |||
| constexpr auto kDepthwiseConv2dNativeName = "DepthwiseConv2dNative"; | |||
| constexpr auto kDepthwiseConv2dNativeOpName = "DepthwiseConv2dNative"; | |||
| constexpr auto kDepthwiseConv2dNativeBackpropInputOpName = "DepthwiseConv2dNativeBackpropInput"; | |||
| constexpr auto kDepthwiseConv2dNativeBackpropFilterOpName = "DepthwiseConv2dNativeBackpropFilter"; | |||
| constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2"; | |||
| constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2"; | |||
| constexpr auto kLabelSetOpName = "LabelSet"; | |||
| @@ -204,6 +206,8 @@ constexpr auto kPaddingOpName = "Padding"; | |||
| constexpr auto kAvgPoolOpName = "AvgPool"; | |||
| constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; | |||
| constexpr auto kmaxPoolGradOpName = "MaxPoolGrad"; | |||
| constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax"; | |||
| constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax"; | |||
| constexpr auto kTensorAddOpName = "TensorAdd"; | |||
| constexpr auto kCastOpName = "Cast"; | |||
| constexpr auto kGreaterEqualOpName = "GreaterEqual"; | |||
| @@ -250,6 +254,13 @@ constexpr auto kMatMulV2OpName = "MatMulV2"; | |||
| constexpr auto kBroadcastToOpName = "BroadcastTo"; | |||
| constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; | |||
| constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2"; | |||
| constexpr auto kDropoutOpName = "Dropout"; | |||
| constexpr auto kDropoutGradOpName = "DropoutGrad"; | |||
| constexpr auto kDropoutGenMaskOpName = "DropoutGenMask"; | |||
| constexpr auto kDropoutDoMaskOpName = "DropoutDoMask"; | |||
| constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; | |||
| constexpr auto kOneHotOpName = "OneHot"; | |||
| constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; | |||
| // Hcom Op Type | |||
| constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | |||
| @@ -272,6 +283,7 @@ constexpr auto kAttrEpsilon = "epsilon"; | |||
| constexpr auto kAttrFactor = "factor"; | |||
| constexpr auto kAttrIsRef = "isRef"; | |||
| constexpr auto kAttrDataShape = "data_shape"; | |||
| constexpr auto kAttrDataFormat = "data_format"; | |||
| constexpr auto kAttrAxis = "axis"; | |||
| constexpr auto kAttrKeepDims = "keep_dims"; | |||
| constexpr auto kAttrShapeGamma = "shape_gamma"; | |||
| @@ -348,6 +360,15 @@ constexpr auto kAttrPynativeNextOpName = "next_op"; | |||
| constexpr auto kAttrPynativeNextIndex = "next_index"; | |||
| constexpr auto kAttrCompileInfo = "compile_info"; | |||
| constexpr auto kAttrFusionType = "fusion_type"; | |||
| constexpr auto kAttrStride = "stride"; | |||
| constexpr auto kAttrStrides = "strides"; | |||
| constexpr auto kAttrKsize = "ksize"; | |||
| constexpr auto kAttrKernelSize = "kernel_size"; | |||
| constexpr auto kAttrDilation = "dilation"; | |||
| constexpr auto kAttrPadMode = "pad_mode"; | |||
| constexpr auto kAttrPad = "pad"; | |||
| constexpr auto kAttrPadding = "padding"; | |||
| constexpr auto kAttrIsGrad = "is_grad"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||
| @@ -134,6 +134,8 @@ inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling"); | |||
| inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad"); | |||
| inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool"); | |||
| inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad"); | |||
| inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax"); | |||
| inline const PrimitivePtr kPrimMaxPoolGradWithArgmax = std::make_shared<Primitive>("MaxPoolGradWithArgmax"); | |||
| inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp"); | |||
| inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool"); | |||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| @@ -141,37 +141,21 @@ class Dropout(Cell): | |||
| raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) | |||
| Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||
| Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) | |||
| self.keep_prob = keep_prob | |||
| seed0, seed1 = _get_graph_seed(0, "dropout") | |||
| self.seed0 = seed0 | |||
| self.seed1 = seed1 | |||
| self.dtype = dtype | |||
| self.get_shape = P.Shape() | |||
| self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) | |||
| self.dropout_do_mask = P.DropoutDoMask() | |||
| self.cast = P.Cast() | |||
| self.is_ascend = context.get_context('device_target') in ["Ascend"] | |||
| self.dropout = P.Dropout(keep_prob) | |||
| self.keep_prob = keep_prob | |||
| self.dropout = P.Dropout(keep_prob, seed0, seed1) | |||
| def construct(self, x): | |||
| if not self.training: | |||
| return x | |||
| if not self.is_ascend: | |||
| out, _ = self.dropout(x) | |||
| return out | |||
| if self.keep_prob == 1: | |||
| return x | |||
| shape = self.get_shape(x) | |||
| dtype = P.DType()(x) | |||
| if _is_float_dtype(dtype): | |||
| keep_prob = self.cast(self.keep_prob, dtype) | |||
| else: | |||
| keep_prob = self.cast(self.keep_prob, mstype.float16) | |||
| output = self.dropout_gen_mask(shape, keep_prob) | |||
| return self.dropout_do_mask(x, output, keep_prob) | |||
| out, _ = self.dropout(x) | |||
| return out | |||
| def extend_repr(self): | |||
| return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype) | |||
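The reworked nn.Dropout delegates mask generation entirely to the P.Dropout primitive, which now carries the graph seeds, so the Ascend-only DropoutGenMask/DropoutDoMask branch disappears from the cell. A minimal usage sketch of the unified path (the input shape and keep_prob below are illustrative assumptions, not taken from the patch):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# Same code path on every backend; any backend-specific mask ops are now
# introduced later by graph passes rather than inside the cell.
net = nn.Dropout(keep_prob=0.8)
net.set_train()                      # dropout only applies in training mode
x = Tensor(np.ones((2, 4)), mstype.float32)
y = net(x)                           # kept entries are scaled by 1/keep_prob
print(y.shape)                       # (2, 4)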
| @@ -19,7 +19,7 @@ from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer, Initializer | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._checkparam import Validator, Rel, twice | |||
| from mindspore._extends import cell_attr_register | |||
| @@ -247,28 +247,8 @@ class Conv2d(_Conv): | |||
| dilation=self.dilation, | |||
| group=self.group, | |||
| data_format=self.format) | |||
| self._init_depthwise_conv2d() | |||
| self.bias_add = P.BiasAdd() | |||
| def _init_depthwise_conv2d(self): | |||
| """Initialize depthwise conv2d op""" | |||
| if context.get_context("device_target") == "Ascend" and self.group > 1: | |||
| self.dilation = self._dilation | |||
| Validator.check_equal_int(self.group, self.in_channels, 'group') | |||
| Validator.check_equal_int(self.group, self.out_channels, 'group') | |||
| self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation) | |||
| weight_shape = [1, self.in_channels, *self.kernel_size] | |||
| if isinstance(self.weight_init, Tensor): | |||
| self.weight_init = Tensor(self.weight_init.asnumpy().swapaxes(0, 1), self.weight_init.dtype) | |||
| if isinstance(self.weight_init, Initializer): | |||
| self.weight_init.shape = weight_shape | |||
| self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight') | |||
| def construct(self, x): | |||
| output = self.conv2d(x, self.weight) | |||
| if self.has_bias: | |||
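With _init_depthwise_conv2d removed, a depthwise-style convolution is expressed as an ordinary P.Conv2D with group > 1, and the lowering to DepthwiseConv2dNative on Ascend is left to the Conv2DUnifyMindIR pass added above. A hedged sketch of the user-facing behaviour (channel counts and input size are made up):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# group == in_channels == out_channels: the frontend keeps P.Conv2D and the
# backend pass rewrites it to DepthwiseConv2dNative where required.
net = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, group=8, pad_mode='same')
x = Tensor(np.ones((1, 8, 16, 16)), mstype.float32)
print(net(x).shape)                  # (1, 8, 16, 16)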
| @@ -124,16 +124,9 @@ class MaxPool2d(_PoolNd): | |||
| strides=self.stride, | |||
| padding=self.pad_mode, | |||
| data_format=self.format) | |||
| self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size, | |||
| strides=self.stride, | |||
| padding=self.pad_mode) | |||
| self.is_tbe = context.get_context("device_target") == "Ascend" | |||
| def construct(self, x): | |||
| if self.is_tbe and self.training: | |||
| out = self.max_pool_with_arg_max(x)[0] | |||
| else: | |||
| out = self.max_pool(x) | |||
| out = self.max_pool(x) | |||
| return out | |||
| @@ -198,22 +191,15 @@ class MaxPool1d(_PoolNd): | |||
| self.max_pool = P.MaxPool(ksize=self.kernel_size, | |||
| strides=self.stride, | |||
| padding=self.pad_mode) | |||
| self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size, | |||
| strides=self.stride, | |||
| padding=self.pad_mode) | |||
| self.shape = F.shape | |||
| self.reduce_mean = P.ReduceMean(keep_dims=True) | |||
| self.expand = P.ExpandDims() | |||
| self.squeeze = P.Squeeze(2) | |||
| self.is_tbe = context.get_context("device_target") == "Ascend" | |||
| def construct(self, x): | |||
| _shape_check(self.shape(x)) | |||
| x = self.expand(x, 2) | |||
| if self.is_tbe and self.training: | |||
| output = self.max_pool_with_arg_max(x)[0] | |||
| else: | |||
| output = self.max_pool(x) | |||
| output = self.max_pool(x) | |||
| output = self.squeeze(output) | |||
| return output | |||
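MaxPool2d and MaxPool1d now always construct P.MaxPool; the training-time switch to MaxPoolWithArgmax on Ascend moves into the MaxPool2MaxPoolWithArgmax backend pass instead of living in the cell. Illustrative usage (the input size is an assumption):

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# No branching on device_target or training mode inside the cell any more.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = Tensor(np.ones((1, 3, 8, 8)), mstype.float32)
print(pool(x).shape)                 # (1, 3, 4, 4)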
| @@ -433,27 +433,15 @@ class Conv2dBnFoldQuantOneConv(Cell): | |||
| (self.is_ge_backend or self.is_ascend) | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| Validator.check_equal_int(group, in_channels, 'group') | |||
| Validator.check_equal_int(group, out_channels, 'group') | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation) | |||
| weight_shape = [1, in_channels, *self.kernel_size] | |||
| channel_axis = 1 | |||
| else: | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.channel_axis = channel_axis | |||
| self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') | |||
| self.bias_add = P.BiasAdd() | |||
| @@ -651,27 +639,15 @@ class Conv2dBnFoldQuant(Cell): | |||
| self.is_gpu = context.get_context('device_target') == "GPU" | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| Validator.check_equal_int(group, in_channels, 'group') | |||
| Validator.check_equal_int(group, out_channels, 'group') | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation) | |||
| weight_shape = [1, in_channels, *self.kernel_size] | |||
| channel_axis = 1 | |||
| else: | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.conv = P.Conv2D(out_channel=out_channels, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') | |||
| self.bias_add = P.BiasAdd() | |||
| if Validator.check_bool(has_bias): | |||
| @@ -830,28 +806,16 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||
| else: | |||
| self.bias = None | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| Validator.check_equal_int(group, in_channels, 'group') | |||
| Validator.check_equal_int(group, out_channels, 'group') | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation) | |||
| weight_shape = [1, in_channels, *self.kernel_size] | |||
| channel_axis = 1 | |||
| else: | |||
| self.conv = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.conv = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') | |||
| self.fake_quant_weight = quant_config.weight(min_init=-6, | |||
| max_init=6, | |||
| @@ -963,10 +927,11 @@ class Conv2dQuant(Cell): | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| channel_axis = 0 | |||
| self.fake_quant_weight = quant_config.weight(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| channel_axis=0, | |||
| channel_axis=channel_axis, | |||
| num_channels=out_channels, | |||
| quant_dtype=quant_dtype) | |||
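Because the depthwise branch is gone, the quantized conv wrappers always build weights in OIHW layout, so the per-channel fake-quant axis is a constant 0 that the diff merely threads through a channel_axis variable. A small plain-Python illustration (the concrete numbers are hypothetical):

# OIHW weight layout: per-channel quantization statistics run along axis 0,
# i.e. over the output channels.
out_channels, in_channels, group, kernel_size = 16, 16, 1, (3, 3)
weight_shape = [out_channels, in_channels // group, *kernel_size]
channel_axis = 0
assert weight_shape[channel_axis] == out_channels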
| @@ -1574,32 +1574,12 @@ class MaxPoolWithArgmax(_Pool): | |||
| def infer_shape(self, x_shape): | |||
| out_shape = _Pool.infer_shape(self, x_shape) | |||
| _, _, out_h, out_w = out_shape | |||
| _, kernel_h, kernel_w, _ = self.ksize | |||
| argmax_shape = [] | |||
| if self.is_tbe: | |||
| for i in range(4): | |||
| if i == 2: | |||
| dim = kernel_h * kernel_w | |||
| argmax_shape.append(dim) | |||
| elif i == 3: | |||
| dim = math.ceil(out_h * out_w / 16) + 1 | |||
| argmax_shape.append(dim) | |||
| else: | |||
| argmax_shape.append(x_shape[i]) | |||
| else: | |||
| argmax_shape = out_shape | |||
| return out_shape, argmax_shape | |||
| return out_shape, out_shape | |||
| def infer_dtype(self, x_dtype): | |||
| out_dtype = x_dtype | |||
| validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name) | |||
| argmax_dtype = mstype.uint16 | |||
| if self.is_gpu: | |||
| argmax_dtype = mstype.int32 | |||
| return out_dtype, argmax_dtype | |||
| argmax_dtype = mstype.int32 | |||
| return x_dtype, argmax_dtype | |||
| class AvgPool(_Pool): | |||
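MaxPoolWithArgmax now reports the argmax output with the same shape as the pooled output and an int32 dtype on every backend, replacing the TBE-specific packed uint16 layout. A hedged check of the new contract, assuming a backend that provides the kernel (GPU or Ascend):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

# Both outputs share the pooled shape; the indices are plain int32.
op = P.MaxPoolWithArgmax(ksize=2, strides=2)
out, argmax = op(Tensor(np.ones((1, 3, 8, 8)), mstype.float32))
print(out.shape, argmax.shape)       # (1, 3, 4, 4) (1, 3, 4, 4)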
| @@ -6070,7 +6050,9 @@ class Dropout(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, keep_prob=0.5): | |||
| def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0): | |||
| self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name) | |||
| self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||
| def infer_shape(self, x_shape): | |||
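P.Dropout now accepts the two graph seeds directly, so nn.Dropout can forward what _get_graph_seed returns instead of wiring the seeds into DropoutGenMask. A sketch of the constructor contract (keyword spellings follow the definition above; the keep_prob value is arbitrary):

from mindspore.common.seed import _get_graph_seed
from mindspore.ops import operations as P

# Seeds come from the global graph-seed helper and are stored on the
# primitive, mirroring what nn.Dropout now does internally.
seed0, seed1 = _get_graph_seed(0, "dropout")
dropout = P.Dropout(keep_prob=0.9, Seed0=seed0, Seed1=seed1)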
| @@ -1615,7 +1615,7 @@ test_case_nn_ops = [ | |||
| ('MaxPoolWithArgmax', { | |||
| 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), | |||
| 'desc_inputs': [[128, 32, 32, 64]], | |||
| 'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}), | |||
| 'desc_bprop': [[128, 32, 16, 32], ([128, 32, 16, 32], {'dtype': np.int32})]}), | |||
| ('SoftmaxCrossEntropyWithLogits', { | |||
| 'block': P.SoftmaxCrossEntropyWithLogits(), | |||
| 'desc_inputs': [[1, 10], [1, 10]], | |||
| @@ -18,7 +18,11 @@ import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.seed import _get_graph_seed | |||
| from mindspore.common.api import _executor | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| @@ -47,13 +51,61 @@ class GradWrap(nn.Cell): | |||
| return grad_all(self.network)(x, y, b) | |||
| @constexpr | |||
| def _is_float_dtype(dtype): | |||
| if dtype in [mstype.float32, mstype.float16]: | |||
| return True | |||
| return False | |||
| class Dropout(nn.Cell): | |||
| def __init__(self, keep_prob=0.5, dtype=mstype.float32): | |||
| super(Dropout, self).__init__() | |||
| if keep_prob <= 0 or keep_prob > 1: | |||
| raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) | |||
| Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||
| Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) | |||
| self.keep_prob = keep_prob | |||
| seed0, seed1 = _get_graph_seed(0, "dropout") | |||
| self.seed0 = seed0 | |||
| self.seed1 = seed1 | |||
| self.dtype = dtype | |||
| self.get_shape = P.Shape() | |||
| self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) | |||
| self.dropout_do_mask = P.DropoutDoMask() | |||
| self.cast = P.Cast() | |||
| self.is_gpu = context.get_context('device_target') in ["GPU"] | |||
| self.dropout = P.Dropout(keep_prob) | |||
| def construct(self, x): | |||
| if not self.training: | |||
| return x | |||
| if self.is_gpu: | |||
| out, _ = self.dropout(x) | |||
| return out | |||
| if self.keep_prob == 1: | |||
| return x | |||
| shape = self.get_shape(x) | |||
| dtype = P.DType()(x) | |||
| if _is_float_dtype(dtype): | |||
| keep_prob = self.cast(self.keep_prob, dtype) | |||
| else: | |||
| keep_prob = self.cast(self.keep_prob, mstype.float16) | |||
| output = self.dropout_gen_mask(shape, keep_prob) | |||
| return self.dropout_do_mask(x, output, keep_prob) | |||
| def extend_repr(self): | |||
| return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype) | |||
| # model_parallel test | |||
| def test_two_matmul_dropout(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy2, strategy3): | |||
| super().__init__() | |||
| self.matmul1 = P.MatMul().shard(strategy1) | |||
| self.dropout = nn.Dropout() | |||
| self.dropout = Dropout() | |||
| self.dropout.dropout_do_mask.shard(strategy2) | |||
| self.dropout.dropout_gen_mask.shard(strategy2) | |||
| self.matmul2 = P.MatMul().shard(strategy3) | |||
| @@ -19,11 +19,14 @@ import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import _selected_ops | |||
| from mindspore.parallel._utils import _reset_op_id | |||
| from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| @@ -66,6 +69,33 @@ class AllToAllNet(nn.Cell): | |||
| return x | |||
| class SoftmaxCrossEntropyWithLogits(_Loss): | |||
| def __init__(self, | |||
| sparse=False, | |||
| reduction='none'): | |||
| super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) | |||
| self.sparse = sparse | |||
| self.reduction = reduction | |||
| self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits() | |||
| self.one_hot = P.OneHot() | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0., mstype.float32) | |||
| self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] | |||
| if self.is_cpugpu: | |||
| self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() | |||
| def construct(self, logits, labels): | |||
| if self.is_cpugpu and self.sparse and self.reduction == 'mean': | |||
| x = self.sparse_softmax_cross_entropy(logits, labels) | |||
| return x | |||
| if self.sparse: | |||
| labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value) | |||
| x = self.softmax_cross_entropy(logits, labels)[0] | |||
| return self.get_loss(x) | |||
| def all_to_all_net(): | |||
| return AllToAllNet() | |||
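The local loss class mirrors the library version: with sparse labels and mean reduction on CPU/GPU it calls the fused SparseSoftmaxCrossEntropyWithLogits primitive, otherwise it one-hot encodes the labels before the dense op. A minimal hedged sketch of calling it (batch size, class count, and backend kernel support are assumptions):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# logits: (batch, classes); labels: (batch,) of int class indices.
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
logits = Tensor(np.random.rand(4, 10).astype(np.float32))
labels = Tensor(np.array([1, 0, 3, 7]), mstype.int32)
loss = loss_fn(logits, labels)       # per-sample loss, shape (4,)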