
!9723 add ir passes to unify mindir

From: @yuchaojie
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago (commit e10dbb4a7f)
26 changed files with 1226 additions and 152 deletions
  1. mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc (+1, -1)
  2. mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc (+320, -0)
  3. mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.h (+51, -0)
  4. mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc (+269, -0)
  5. mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h (+47, -0)
  6. mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.cc (+148, -0)
  7. mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h (+34, -0)
  8. mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.cc (+115, -0)
  9. mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h (+43, -0)
  10. mindspore/ccsrc/backend/optimizer/common/helper.h (+4, -0)
  11. mindspore/ccsrc/backend/session/ascend_session.cc (+36, -1)
  12. mindspore/ccsrc/backend/session/ascend_session.h (+1, -0)
  13. mindspore/ccsrc/backend/session/cpu_session.h (+1, -0)
  14. mindspore/ccsrc/backend/session/gpu_session.h (+1, -0)
  15. mindspore/ccsrc/backend/session/session_basic.cc (+2, -0)
  16. mindspore/ccsrc/backend/session/session_basic.h (+1, -0)
  17. mindspore/ccsrc/utils/utils.h (+22, -1)
  18. mindspore/core/base/core_ops.h (+2, -0)
  19. mindspore/nn/layer/basic.py (+4, -20)
  20. mindspore/nn/layer/conv.py (+1, -21)
  21. mindspore/nn/layer/pooling.py (+2, -16)
  22. mindspore/nn/layer/quant.py (+30, -65)
  23. mindspore/ops/operations/nn_ops.py (+6, -24)
  24. tests/ut/python/ops/test_ops.py (+1, -1)
  25. tests/ut/python/parallel/test_matmul_dropout.py (+53, -1)
  26. tests/ut/python/parallel/test_one_dev.py (+31, -1)

mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc (+1, -1)

@@ -1085,7 +1085,7 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode)
{kTensorAddOpName, "ElemWise"},
{kConv2DBackpropInputOpName, "Conv2d_backprop_input"},
{kConv2DBackpropFilterOpName, "Conv2d_backprop_filter"},
- {kDepthwiseConv2dNativeName, "DepthwiseConvolution"},
+ {kDepthwiseConv2dNativeOpName, "DepthwiseConvolution"},
{kAddNOpName, "ElemWise"},
{kReluGradV2OpName, "ElemWise"},
{kRealDivOpName, "ElemWise"}};


mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc (+320, -0)

@@ -0,0 +1,320 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h"

#include <vector>
#include <string>
#include <memory>
#include <utility>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kConv2DBackpropInputNum = 4;
constexpr size_t kConv2DAxisNum = 4;
constexpr auto kAttrOffsetA = "offset_a";
constexpr auto kAttrPadList = "pad_list";
constexpr auto kAttrPads = "pads";
constexpr auto kAttrMode = "mode";
constexpr auto kAttrChannelMultiplier = "channel_multiplier";

bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
MS_EXCEPTION_IF_NULL(conv2d);
auto group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(conv2d, kAttrGroup));
if (group == 1) {
return false;
}
auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrDataFormat);
if (data_format != "NCHW") {
MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format;
}
if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()
<< "output axis num: " << out_shape.size();
}
auto in_channel = in_shape[1];
auto out_channel = out_shape[1];
if (group != in_channel || group != out_channel) {
MS_LOG(EXCEPTION) << "Conv2D's attr group should be equal to in_channel and out_channel when group > 1, but got "
<< "group: " << group << " in_channel: " << in_channel << " out_channel: " << out_channel;
}
return true;
}

ValueNodePtr CreatePermValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &perm) {
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<ValuePtr> axis_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &axis : perm) {
axis_values.push_back(MakeValue(axis));
abs.push_back(std::make_shared<abstract::AbstractScalar>(axis));
}
auto perm_value_tuple = std::make_shared<ValueTuple>(axis_values);
MS_EXCEPTION_IF_NULL(perm_value_tuple);
auto abstract = std::make_shared<abstract::AbstractTuple>(abs);
MS_EXCEPTION_IF_NULL(abstract);
auto perm_value = kernel_graph->NewValueNode(abstract, perm_value_tuple);
MS_EXCEPTION_IF_NULL(perm_value);
kernel_graph->AddValueNodeToGraph(perm_value);
return perm_value;
}

CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, const AnfNodePtr &input_node,
bool need_trans_output) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
MS_EXCEPTION_IF_NULL(input_node);
auto perm = std::vector<int64_t>{1, 0, 2, 3};
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
CreatePermValueNode(graph, perm)};
auto transpose = graph->NewCNode(transpose_inputs);
MS_EXCEPTION_IF_NULL(transpose);
transpose->set_scope(conv2d->scope());

if (need_trans_output) {
auto types = {AnfAlgo::GetOutputInferDataType(input_node, 0)};
auto out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2D's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, transpose.get());
} else {
transpose->set_abstract(conv2d->abstract());
}

auto input_names = std::vector<std::string>{"x", "perm"};
auto output_names = std::vector<std::string>{"output"};
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
return transpose;
}

CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
if (conv2d->inputs().size() != kConvInputNum) {
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
<< conv2d->inputs().size() - 1;
}
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(1), transpose};
auto depth_conv = graph->NewCNode(depth_conv_inputs);
MS_EXCEPTION_IF_NULL(depth_conv);
depth_conv->set_abstract(conv2d->abstract());
depth_conv->set_scope(conv2d->scope());
return depth_conv;
}

CNodePtr CreateDepthwiseConv2DBackpropInput(const FuncGraphPtr &graph, const CNodePtr &conv2d_backin,
const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backin);
if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
<< conv2d_backin->inputs().size() - 1;
}
std::vector<AnfNodePtr> depth_conv_backin_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropInputOpName)), conv2d_backin->input(3),
transpose, conv2d_backin->input(1)};
auto depth_conv_backin = graph->NewCNode(depth_conv_backin_inputs);
MS_EXCEPTION_IF_NULL(depth_conv_backin);
depth_conv_backin->set_abstract(conv2d_backin->abstract());
depth_conv_backin->set_scope(conv2d_backin->scope());
return depth_conv_backin;
}

CNodePtr CreateDepthwiseConv2DBackpropFilter(const FuncGraphPtr &graph, const CNodePtr &conv2d_backfil) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d_backfil);
if (conv2d_backfil->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
<< conv2d_backfil->inputs().size() - 1;
}
auto filter_size_node = conv2d_backfil->input(3);
MS_EXCEPTION_IF_NULL(filter_size_node);
auto filter_size_vnode = filter_size_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(filter_size_vnode);
auto filter_size = GetValue<std::vector<int64_t>>(filter_size_vnode->value());
// Swap axes 0 and 1 of the filter shape, but don't swap twice: nodes with the same filter_size value may share
// the same filter_size ValueNode.
if (filter_size[0] != 1) {
std::swap(filter_size[0], filter_size[1]);
conv2d_backfil->input(3)->cast<ValueNodePtr>()->set_value(MakeValue(filter_size));
}
std::vector<AnfNodePtr> depth_conv_backfil_inputs = {
NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeBackpropFilterOpName)), conv2d_backfil->input(2),
conv2d_backfil->input(3), conv2d_backfil->input(1)};
auto depth_conv_backfil = graph->NewCNode(depth_conv_backfil_inputs);
MS_EXCEPTION_IF_NULL(depth_conv_backfil);
depth_conv_backfil->set_scope(conv2d_backfil->scope());

auto types = {AnfAlgo::GetOutputInferDataType(conv2d_backfil, 0)};
std::vector<size_t> out_shape = AnfAlgo::GetOutputInferShape(conv2d_backfil, 0);
if (out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropFilter's output axis number should be " << kConv2DAxisNum << ", but got "
<< out_shape.size();
}
std::swap(out_shape[0], out_shape[1]);
auto shapes = {out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, depth_conv_backfil.get());
return depth_conv_backfil;
}

void SetCommonAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) {
AnfAlgo::CopyNodeAttr(kAttrKernelSize, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrDilation, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrDataFormat, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrPadList, kAttrPads, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrPadMode, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrPad, conv2d, depth_conv);
AnfAlgo::SetNodeAttr(kAttrMode, MakeValue(3), depth_conv);
AnfAlgo::SetNodeAttr(kAttrChannelMultiplier, MakeValue(1), depth_conv);
}

void SetConv2DAttrs(const CNodePtr &conv2d, const CNodePtr &depth_conv) {
SetCommonAttrs(conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrInputNames, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrStride, conv2d, depth_conv);
AnfAlgo::CopyNodeAttr(kAttrOffsetA, conv2d, depth_conv);
}

void SetConv2DBackpropInputAttrs(const CNodePtr &conv2d_backin, const CNodePtr &depth_conv_backin) {
SetCommonAttrs(conv2d_backin, depth_conv_backin);
auto input_names = std::vector<std::string>{"input_size", "filter", "dout"};
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backin);
auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(conv2d_backin, kAttrStride);
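// Expand a 2-element stride (stride_h, stride_w) into the 4-element NCHW form (1, 1, stride_h, stride_w).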
if (stride.size() == 2) {
stride.insert(stride.begin(), 2, 1);
}
AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backin);
}

void SetConv2DBackpropFilterAttrs(const CNodePtr &conv2d_backfil, const CNodePtr &depth_conv_backfil) {
SetCommonAttrs(conv2d_backfil, depth_conv_backfil);
auto input_names = std::vector<std::string>{"input", "filter_size", "dout"};
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), depth_conv_backfil);
auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(conv2d_backfil, kAttrStride);
if (stride.size() == 2) {
stride.insert(stride.begin(), 2, 1);
}
AnfAlgo::SetNodeAttr(kAttrStride, MakeValue(stride), depth_conv_backfil);
}
} // namespace

const BaseRef Conv2DUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr W = std::make_shared<Var>();
VectorRef pattern({prim::kPrimConv2D, X, W});
return pattern;
}

const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

auto conv2d = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv2d);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d, 0);
auto output_shape = AnfAlgo::GetOutputInferShape(conv2d, 0);
if (!NeedUpdate(conv2d, input_shape, output_shape)) {
return nullptr;
}

if (conv2d->inputs().size() != kConvInputNum) {
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
<< conv2d->inputs().size() - 1;
}
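// When group == in_channel == out_channel, the Conv2D weight has shape (out_channel, 1, kH, kW); DepthwiseConv2dNative
// carries the channel multiplier (here 1) on axis 0, so the weight's first two axes are swapped via a Transpose.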
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true);
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
SetConv2DAttrs(conv2d, depth_conv);
return depth_conv;
}

const BaseRef Conv2DBackpropInputUnifyMindIR::DefinePattern() const {
VarPtr dout = std::make_shared<Var>();
VarPtr weight = std::make_shared<Var>();
VarPtr input_size = std::make_shared<Var>();
VectorRef pattern({prim::kPrimConv2DBackpropInput, dout, weight, input_size});
return pattern;
}

const AnfNodePtr Conv2DBackpropInputUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

auto conv2d_backin = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv2d_backin);
auto input_shape = AnfAlgo::GetOutputInferShape(conv2d_backin, 0);
auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backin, 0);
if (!NeedUpdate(conv2d_backin, input_shape, output_shape)) {
return nullptr;
}

if (conv2d_backin->inputs().size() != kConv2DBackpropInputNum) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInput's input number should be " << kConv2DBackpropInputNum - 1 << ", but got "
<< conv2d_backin->inputs().size() - 1;
}
auto transpose = CreateTranspose(graph, conv2d_backin, conv2d_backin->input(2), true);
auto depth_conv_backin = CreateDepthwiseConv2DBackpropInput(graph, conv2d_backin, transpose);
SetConv2DBackpropInputAttrs(conv2d_backin, depth_conv_backin);
return depth_conv_backin;
}

const BaseRef Conv2DBackpropFilterUnifyMindIR::DefinePattern() const {
VarPtr dout = std::make_shared<Var>();
VarPtr input = std::make_shared<Var>();
VarPtr filter_size = std::make_shared<Var>();
VectorRef pattern({prim::kPrimConv2DBackpropFilter, dout, input, filter_size});
return pattern;
}

const AnfNodePtr Conv2DBackpropFilterUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

auto conv2d_backfil = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv2d_backfil);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 1);
auto output_shape = AnfAlgo::GetPrevNodeOutputInferShape(conv2d_backfil, 0);
if (!NeedUpdate(conv2d_backfil, input_shape, output_shape)) {
return nullptr;
}

auto depth_conv_backfil = CreateDepthwiseConv2DBackpropFilter(graph, conv2d_backfil);
SetConv2DBackpropFilterAttrs(conv2d_backfil, depth_conv_backfil);
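// DepthwiseConv2dNativeBackpropFilter emits the filter gradient with axes 0 and 1 swapped, so transpose it back
// to the layout the original Conv2DBackpropFilter produced.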
auto transpose = CreateTranspose(graph, conv2d_backfil, depth_conv_backfil, false);

auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(void)manager->Replace(conv2d_backfil, transpose);
return transpose;
}
} // namespace opt
} // namespace mindspore

mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.h (+51, -0)

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class Conv2DUnifyMindIR : public PatternProcessPass {
public:
explicit Conv2DUnifyMindIR(bool multigraph = true) : PatternProcessPass("conv2d_unify_mindir", multigraph) {}
~Conv2DUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class Conv2DBackpropInputUnifyMindIR : public PatternProcessPass {
public:
explicit Conv2DBackpropInputUnifyMindIR(bool multigraph = true)
: PatternProcessPass("conv2d_backprop_input_unify_mindir", multigraph) {}
~Conv2DBackpropInputUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class Conv2DBackpropFilterUnifyMindIR : public PatternProcessPass {
public:
explicit Conv2DBackpropFilterUnifyMindIR(bool multigraph = true)
: PatternProcessPass("conv2d_backprop_filter_unify_mindir", multigraph) {}
~Conv2DBackpropFilterUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_CONV2D_UNIFY_MINDIR_H_

mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.cc (+269, -0)

@@ -0,0 +1,269 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
#include <vector>
#include <memory>
#include <numeric>
#include <algorithm>
#include <functional>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/log_adapter.h"

constexpr auto kKeepProb = "keep_prob";
constexpr auto kSeed0 = "Seed0";
constexpr auto kSeed1 = "Seed1";
constexpr auto kUint8BitSize = 8;

namespace mindspore::opt {
constexpr size_t kFloat16Len = 2; // size of float16
namespace {
AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) {
MS_LOG(INFO) << "GetDropoutNodeInfo start.";
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(keep_prob);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) ||
!AnfAlgo::HasNodeAttr(kSeed1, cnode)) {
MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1.";
}
*keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb);
MS_LOG(INFO) << "keep_prob: " << *keep_prob;
// Return Dropout's input, which may be a parameter/tensor or the output of a preceding cnode.
return cnode->input(1);
}

ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) {
MS_LOG(INFO) << "CreateKeepPorbValueNode start.";
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<int64_t> keep_prob_shape = {};
ShapeVector shape = {};
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), keep_prob_shape);
MS_EXCEPTION_IF_NULL(keep_prob_tensor);
auto data_ptr = keep_prob_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
// keep_prob's data type is the same as the input data's
if (dtype->type_id() == kNumberTypeFloat16) {
float16 half_data = float16(keep_prob);
auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len);
if (ret_code != 0) {
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
}
} else {
auto *val = reinterpret_cast<float *>(data_ptr);
*val = keep_prob;
}
auto abstract = std::make_shared<abstract::AbstractTensor>(dtype, shape);
auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor);
MS_EXCEPTION_IF_NULL(keep_prob_value);
kernel_graph->AddValueNodeToGraph(keep_prob_value);
return keep_prob_value;
}

std::vector<int64_t> GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) {
MS_LOG(INFO) << "GetInputShape start.";
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(dropout_input);
std::vector<int64_t> shapes;
if (dropout_input->isa<Parameter>()) {
MS_LOG(INFO) << "Dropout input from parameter node.";
// single test case
auto dropout_input_value = dropout_input->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(dropout_input_value);
MS_EXCEPTION_IF_NULL(dropout_input_value->Shape());
auto shape = dropout_input_value->Shape()->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
return shape->shape();
} else if (dropout_input->isa<CNode>()) {
MS_LOG(INFO) << "Dropout input from cnode.";
auto dropout_input_node = dropout_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_input_node);
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong);
return shapes;
} else {
MS_LOG(ERROR) << "Dropout input is not parameter or cnode.";
return {};
}
}

ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape) {
MS_LOG(INFO) << "CreateShapeValueNode start.";
MS_EXCEPTION_IF_NULL(func_graph);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<ValuePtr> dim_values{};
abstract::AbstractBasePtrList abs{};
for (const auto &dim : shape) {
dim_values.push_back(MakeValue(dim));
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim));
}
auto shape_value_tuple = std::make_shared<ValueTuple>(dim_values);
MS_EXCEPTION_IF_NULL(shape_value_tuple);
auto abstract = std::make_shared<abstract::AbstractTuple>(abs);
MS_EXCEPTION_IF_NULL(abstract);
auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple);
MS_EXCEPTION_IF_NULL(shape_value);
kernel_graph->AddValueNodeToGraph(shape_value);
return shape_value;
}
} // namespace

const BaseRef DropoutUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
auto prim = std::make_shared<Primitive>(kDropoutOpName);
auto ref = VectorRef({prim, X});
return VectorRef({prim::kPrimTupleGetItem, ref, Y});
}

const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto tuple_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_cnode);
auto dropout_node = tuple_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
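// keep_prob is materialized as a scalar tensor whose dtype follows Dropout's output: float16 if the output is float16, otherwise float32.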
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);
// CreateDropoutGenMask
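// DropoutGenMask packs 8 mask bits into each uint8 element, so the mask length is the input element count divided by 8.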
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(node->scope());

// CreateDropoutDoMask
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
ShapeVector dropout_do_mask_output = shape;
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask_output);
dropout_do_mask->set_abstract(do_mask_abstract);
dropout_do_mask->set_scope(node->scope());

return dropout_do_mask;
}

const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
MS_EXCEPTION_IF_NULL(Y);
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
auto tuple_getitem_prim = prim::kPrimTupleGetItem;
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
MS_EXCEPTION_IF_NULL(dropout_prim);
MS_EXCEPTION_IF_NULL(dropout_grad_prim);
auto ref0 = VectorRef({dropout_prim, X});
auto ref1 = VectorRef({tuple_getitem_prim, ref0, Y});
return VectorRef({dropout_grad_prim, grad_input_, ref1});
}

const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto dropout_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_grad);
auto tuple_getitem = dropout_grad->input(2);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
auto dropout_node = tuple_getitem_cnode->input(1);
MS_EXCEPTION_IF_NULL(dropout_node);
float keep_prob = 0;
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob);
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32;
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype);
auto shape = GetInputShape(dropout_node, dropout_input);
auto shape_value = CreateShapeValueNode(func_graph, shape);
// CreateDropoutGenMask
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
output_size = output_size / kUint8BitSize;
MS_LOG(INFO) << "Output_size: " << output_size;
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
shape_value, keep_prob_value};
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
ShapeVector dropout_gen_mask_output = {output_size};
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output);
MS_EXCEPTION_IF_NULL(gen_mask_abstract);
dropout_gen_mask->set_abstract(gen_mask_abstract);
dropout_gen_mask->set_scope(dropout_node->scope());
// AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);

// CreateDropoutDoMask-forward
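// The forward DropoutDoMask created here and the backward one below consume the same DropoutGenMask node, so both apply an identical mask.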
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(dropout_node);
if (iter != node_users.end()) {
for (auto &node_index : iter->second) {
// Dropout has two outputs, so output node is tuple_getitem
auto tuple_getitem_cnode2 = node_index.first->cast<CNodePtr>();
// Check whether Dropout's first output, i.e. the one consumed by the forward pass, is in use.
auto getitem_index = GetValue<int64_t>(tuple_getitem_cnode2->input(2)->cast<ValueNodePtr>()->value());
if (getitem_index == 0) {
std::vector<AnfNodePtr> dropout_do_mask1_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
dropout_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask1);
ShapeVector dropout_do_mask1_output = shape;
auto do_mask_abstract1 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask1_output);
dropout_do_mask1->set_abstract(do_mask_abstract1);
dropout_do_mask1->set_scope(dropout_node->scope());
(void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1);
break;
}
}
}

// CreateDropoutDoMask-backward
if (equiv->find(grad_input_) == equiv->end()) {
MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern.";
}
auto grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]);
std::vector<AnfNodePtr> dropout_do_mask2_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
grad_input, dropout_gen_mask, keep_prob_value};
auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs);
MS_EXCEPTION_IF_NULL(dropout_do_mask2);
ShapeVector dropout_do_mask2_output = shape;
auto do_mask_abstract2 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask2_output);
dropout_do_mask2->set_abstract(do_mask_abstract2);
dropout_do_mask2->set_scope(node->scope());

return dropout_do_mask2;
}
} // namespace mindspore::opt

mindspore/ccsrc/backend/optimizer/ascend/mindir/dropout_unify_mindir.h (+47, -0)

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class DropoutUnifyMindIR : public PatternProcessPass {
public:
explicit DropoutUnifyMindIR(bool multigraph = true) : PatternProcessPass("dropout_unify_mindir", multigraph) {}
~DropoutUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class DropoutGradUnifyMindIR : public PatternProcessPass {
public:
explicit DropoutGradUnifyMindIR(bool multigraph = true)
: PatternProcessPass("dropout_grad_unify_mindir", multigraph) {
grad_input_ = std::make_shared<Var>();
}
~DropoutGradUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
VarPtr grad_input_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_

mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.cc (+148, -0)

@@ -0,0 +1,148 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"

#include <vector>
#include <memory>

#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kMaxPoolInputNum = 2;
constexpr size_t kMaxPoolAttrAxisNum = 4;
constexpr size_t kMaxPoolGradInputNum = 4;
constexpr size_t kMaxPoolWithArgmaxOutputNum = 2;

CNodePtr GetMaxPool(const CNodePtr &maxpool_grad) {
MS_EXCEPTION_IF_NULL(maxpool_grad);
if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) {
MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got "
<< maxpool_grad->inputs().size() - 1;
}
auto maxpool_anf = maxpool_grad->input(2);
MS_EXCEPTION_IF_NULL(maxpool_anf);
return maxpool_anf->cast<CNodePtr>();
}

CNodePtr CreateMaxPoolWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool);
if (maxpool->inputs().size() != kMaxPoolInputNum) {
MS_LOG(EXCEPTION) << "MaxPool's input number should be " << kMaxPoolInputNum - 1 << ", but got "
<< maxpool->inputs().size() - 1;
}
std::vector<AnfNodePtr> maxpool_argmax_inputs = {NewValueNode(std::make_shared<Primitive>(kMaxPoolWithArgmaxOpName)),
maxpool->input(1)};
auto maxpool_argmax = graph->NewCNode(maxpool_argmax_inputs);
MS_EXCEPTION_IF_NULL(maxpool_argmax);
maxpool_argmax->set_scope(maxpool->scope());

// MaxPoolWithArgmax's second output is argmax, whose data type is uint16 and whose shape is the same as the first output's
TypeId argmax_dtype = kNumberTypeUInt16;
auto types = {AnfAlgo::GetOutputInferDataType(maxpool, 0), argmax_dtype};
auto out_shape = AnfAlgo::GetOutputInferShape(maxpool, 0);
auto shapes = {out_shape, out_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_argmax.get());
return maxpool_argmax;
}

CNodePtr CreateMaxPoolGradWithArgmax(const FuncGraphPtr &graph, const CNodePtr &maxpool_grad,
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(maxpool_grad);
if (maxpool_grad->inputs().size() != kMaxPoolGradInputNum) {
MS_LOG(EXCEPTION) << "MaxPoolGrad's input number should be " << kMaxPoolGradInputNum - 1 << ", but got "
<< maxpool_grad->inputs().size() - 1;
}
// MaxPoolGrad's inputs are {input, output, grad_input}, MaxPoolGradWithArgmax's inputs are
// {input, grad_input, argmax_output}
std::vector<AnfNodePtr> maxpool_grad_argmax_inputs = {
NewValueNode(std::make_shared<Primitive>(kMaxPoolGradWithArgmaxOpName)), maxpool_grad->input(1),
maxpool_grad->input(3), maxpool_argmax_outputs[1]};
auto maxpool_grad_argmax = graph->NewCNode(maxpool_grad_argmax_inputs);
MS_EXCEPTION_IF_NULL(maxpool_grad_argmax);
maxpool_grad_argmax->set_scope(maxpool_grad->scope());
maxpool_grad_argmax->set_abstract(maxpool_grad->abstract());
return maxpool_grad_argmax;
}

void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
const CNodePtr &maxpool_grad_argmax) {
auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrStrides);
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool, kAttrKsize);
if (strides.size() != kMaxPoolAttrAxisNum) {
MS_LOG(EXCEPTION) << "MaxPool's attr strides has wrong axis number, should be " << kMaxPoolAttrAxisNum
<< ", but got " << strides.size();
}
if (ksize.size() != kMaxPoolAttrAxisNum) {
MS_LOG(EXCEPTION) << "MaxPool's attr ksize has wrong axis number, should be " << kMaxPoolAttrAxisNum << ", but got "
<< ksize.size();
}
// note that strides and ksize change from (1, 1, x, y) to (1, x, y, 1)
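// e.g. strides (1, 1, 2, 2) on MaxPool become (1, 2, 2, 1) on MaxPoolWithArgmax.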
for (size_t i = 1; i <= 2; ++i) {
strides[i] = strides[i + 1];
ksize[i] = ksize[i + 1];
}
strides[3] = 1;
ksize[3] = 1;

AnfAlgo::CopyNodeAttrs(maxpool, maxpool_argmax);
AnfAlgo::CopyNodeAttrs(maxpool_grad, maxpool_grad_argmax);
AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_argmax);
AnfAlgo::SetNodeAttr(kAttrStrides, MakeValue(strides), maxpool_grad_argmax);
AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_argmax);
AnfAlgo::SetNodeAttr(kAttrKsize, MakeValue(ksize), maxpool_grad_argmax);
}
} // namespace

const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VectorRef maxpool({prim::kPrimMaxPool, X});
VectorRef pattern({prim::kPrimMaxPoolGrad, X, maxpool, Y});
return pattern;
}

const AnfNodePtr MaxPool2MaxPoolWithArgmax::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

auto maxpool_grad = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maxpool_grad);
auto maxpool = GetMaxPool(maxpool_grad);
MS_EXCEPTION_IF_NULL(maxpool);

auto maxpool_argmax = CreateMaxPoolWithArgmax(graph, maxpool);
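// CreateMultipleOutputsOfAnfNode splits MaxPoolWithArgmax's two outputs (the pooled output and the argmax) into separate getitem nodes.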
std::vector<AnfNodePtr> maxpool_argmax_outputs;
CreateMultipleOutputsOfAnfNode(graph, maxpool_argmax, kMaxPoolWithArgmaxOutputNum, &maxpool_argmax_outputs);
auto maxpool_grad_argmax = CreateMaxPoolGradWithArgmax(graph, maxpool_grad, maxpool_argmax_outputs);
SetNodeAttrs(maxpool, maxpool_grad, maxpool_argmax, maxpool_grad_argmax);

auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(void)manager->Replace(maxpool, maxpool_argmax_outputs[0]);
return maxpool_grad_argmax;
}
} // namespace opt
} // namespace mindspore

mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h (+34, -0)

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class MaxPool2MaxPoolWithArgmax : public PatternProcessPass {
public:
explicit MaxPool2MaxPoolWithArgmax(bool multigraph = true)
: PatternProcessPass("maxpool_to_maxpool_with_argmax", multigraph) {}
~MaxPool2MaxPoolWithArgmax() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_TO_MAXPOOL_WITH_ARGMAX_H_

mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.cc (+115, -0)

@@ -0,0 +1,115 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
#include <memory>
#include <vector>
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4;
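// Matches only constant (ValueNode) inputs; used below for the TupleGetItem index.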
bool IsC(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
AnfNodePtr in = utils::cast<AnfNodePtr>(n);
MS_EXCEPTION_IF_NULL(in);
return in->isa<ValueNode>();
}
return false;
}

CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) {
MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax);
if (maxpool_grad_with_argmax->inputs().size() != kMaxPoolGradWithArgmaxInputNum) {
MS_LOG(EXCEPTION) << "MaxPoolGradWithArgmax has wrong input size.";
}
auto tuple_getitem0_anf = maxpool_grad_with_argmax->input(3);
MS_EXCEPTION_IF_NULL(tuple_getitem0_anf);
return tuple_getitem0_anf->cast<CNodePtr>();
}
} // namespace

const BaseRef MaxPoolWithArgmaxUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VectorRef pattern({prim::kPrimMaxPoolWithArgmax, X});
return pattern;
}

const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto maxpool_with_argmax = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maxpool_with_argmax);

TypeId argmax_dtype = kNumberTypeUInt16;
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_with_argmax, kAttrKsize);
auto output_shape = AnfAlgo::GetOutputInferShape(maxpool_with_argmax, 0);
auto argmax_shape = output_shape;
if (argmax_shape.size() != 4) {
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
}
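// Per the formula below, argmax is a uint16 bit mask: dim 2 holds one plane per kernel element (ksize_h * ksize_w),
// and dim 3 packs the out_h * out_w positions into 16-bit units, rounded up plus one extra unit.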
argmax_shape[2] = ksize[1] * ksize[2];
argmax_shape[3] = (output_shape[2] * output_shape[3] + 15) / 16 + 1;
auto types = {AnfAlgo::GetOutputInferDataType(maxpool_with_argmax, 0), argmax_dtype};
auto shapes = {output_shape, argmax_shape};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, maxpool_with_argmax.get());

auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
return maxpool_with_argmax;
}

const BaseRef MaxPoolGradWithArgmaxUnifyMindIR::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VarPtr index0 = std::make_shared<CondVar>(IsC);
VectorRef maxpool_with_argmax({prim::kPrimMaxPoolWithArgmax, X});
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, maxpool_with_argmax, index0});
VectorRef maxpool_grad_with_argmax({prim::kPrimMaxPoolGradWithArgmax, X, Y, tuple_getitem0});
return maxpool_grad_with_argmax;
}

const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto maxpool_grad_with_argmax = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax);
auto tuple_getitem0_anf = GetMaxPoolWithArgmax(maxpool_grad_with_argmax);
MS_EXCEPTION_IF_NULL(tuple_getitem0_anf);

TypeId argmax_dtype = kNumberTypeUInt16;
auto ksize = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(maxpool_grad_with_argmax, kAttrKsize);
auto argmax_shape = AnfAlgo::GetOutputInferShape(tuple_getitem0_anf, 0);
if (argmax_shape.size() != 4) {
MS_LOG(DEBUG) << "argmax's infer shape size not equal 4";
}
argmax_shape[3] = (argmax_shape[2] * argmax_shape[3] + 15) / 16 + 1;
argmax_shape[2] = ksize[1] * ksize[2];
AnfAlgo::SetOutputInferTypeAndShape({argmax_dtype}, {argmax_shape}, tuple_getitem0_anf.get());

auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
return maxpool_grad_with_argmax;
}
} // namespace opt
} // namespace mindspore

mindspore/ccsrc/backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h (+43, -0)

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_

#include <memory>
#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class MaxPoolWithArgmaxUnifyMindIR : public PatternProcessPass {
public:
explicit MaxPoolWithArgmaxUnifyMindIR(bool multigraph = true)
: PatternProcessPass("maxpool_with_argmax_unify_mindir", multigraph) {}
~MaxPoolWithArgmaxUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};

class MaxPoolGradWithArgmaxUnifyMindIR : public PatternProcessPass {
public:
explicit MaxPoolGradWithArgmaxUnifyMindIR(bool multigraph = true)
: PatternProcessPass("maxpool_grad_with_argmax_unify_mindir", multigraph) {}
~MaxPoolGradWithArgmaxUnifyMindIR() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_

mindspore/ccsrc/backend/optimizer/common/helper.h (+4, -0)

@@ -101,6 +101,10 @@ constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
constexpr size_t kSplitInputNum = 2;
constexpr size_t kGatherV2DynInputNum = 3;
constexpr size_t kUnsortedSegmentSumInputNum = 2;
constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2;
constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3;
constexpr size_t kOneHotOutputNum = 1;
constexpr size_t kOneHotInputNum = 5;

enum FusedBatchNormInput {
kX = 1,


mindspore/ccsrc/backend/session/ascend_session.cc (+36, -1)

@@ -32,6 +32,10 @@
#include "runtime/device/ascend/ascend_kernel_runtime.h" #include "runtime/device/ascend/ascend_kernel_runtime.h"
#include "backend/optimizer/ascend/ascend_backend_optimization.h" #include "backend/optimizer/ascend/ascend_backend_optimization.h"
#include "backend/optimizer/common/common_backend_optimization.h" #include "backend/optimizer/common/common_backend_optimization.h"
#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/maxpool_to_maxpool_with_argmax.h"
#include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h"
#include "runtime/device/kernel_adjust.h" #include "runtime/device/kernel_adjust.h"
#include "runtime/device/ascend/ascend_stream_assign.h" #include "runtime/device/ascend/ascend_stream_assign.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
@@ -423,6 +427,35 @@ void AscendSession::Init(uint32_t device_id) {
runtime_instance->CreateContext();
}


void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs) {
std::string file_name = "hwopt_d_before_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph);
DumpIRProto(graph, "before_unify_mindir_hwopt_" + std::to_string(graph->graph_id()));
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
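// DropoutGradUnifyMindIR is registered before DropoutUnifyMindIR: its pattern matches the original
// Dropout + TupleGetItem nodes, which DropoutUnifyMindIR replaces.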
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutGradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::DropoutUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropInputUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::Conv2DBackpropFilterUnifyMindIR>());

optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_name = "hwopt_d_after_unify_mindir_graph_" + std::to_string(graph->graph_id()) + ".ir";
DumpIR(file_name, graph);
}
}

GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
MS_LOG(INFO) << "Start";
// construct graph, if successfully, graph_sum_ + 1
@@ -438,6 +471,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
// Update Graph Dynamic Shape Attr
UpdateAllGraphDynamicShapeAttr(all_graphs);
for (const auto &graph : all_graphs) {
UnifyMindIR(graph);
}
BackendOptimization(all_graphs);
// empty graph dont entry to backend
if (root_graph->execution_order().empty()) {
@@ -1219,7 +1255,6 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
return;
}
memo->insert(graph.get());
opt::AscendBackendIRFusionOptimization(graph);
graph->SetExecOrderByDefault();




mindspore/ccsrc/backend/session/ascend_session.h (+1, -0)

@@ -51,6 +51,7 @@ class AscendSession : public SessionBasic {
void SyncStream() override;

protected:
void UnifyMindIR(const KernelGraphPtr &graph) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;


mindspore/ccsrc/backend/session/cpu_session.h (+1, -0)

@@ -32,6 +32,7 @@ class CPUSession : public SessionBasic {
void Init(uint32_t device_id) override { InitExecutor(kCPUDevice, device_id); }

protected:
void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) override;
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;


mindspore/ccsrc/backend/session/gpu_session.h (+1, -0)

@@ -35,6 +35,7 @@ class GPUSession : public SessionBasic {
void SyncStream() override;

protected:
void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,


mindspore/ccsrc/backend/session/session_basic.cc (+2, -0)

@@ -943,6 +943,7 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con


// Update Graph Dynamic Shape Attr
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
UnifyMindIR(graph);
opt::BackendCommonOptimization(graph);
graph->SetInputNodes();
auto input_nodes = graph->input_nodes();
@@ -1610,6 +1611,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
// set output
CreateOutputNode(cnode, graph);
graph->SetInputNodes();
UnifyMindIR(graph);
return graph;
}




mindspore/ccsrc/backend/session/session_basic.h (+1, -0)

@@ -147,6 +147,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {


mindspore/ccsrc/utils/utils.h (+22, -1)

@@ -163,7 +163,9 @@ constexpr auto kBatchToSpaceOpName = "BatchToSpace";
constexpr auto kPadOpName = "Pad"; constexpr auto kPadOpName = "Pad";
constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput"; constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput";
constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter"; constexpr auto kConv2DBackpropFilterOpName = "Conv2DBackpropFilter";
constexpr auto kDepthwiseConv2dNativeName = "DepthwiseConv2dNative";
constexpr auto kDepthwiseConv2dNativeOpName = "DepthwiseConv2dNative";
constexpr auto kDepthwiseConv2dNativeBackpropInputOpName = "DepthwiseConv2dNativeBackpropInput";
constexpr auto kDepthwiseConv2dNativeBackpropFilterOpName = "DepthwiseConv2dNativeBackpropFilter";
constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2";
constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2";
constexpr auto kLabelSetOpName = "LabelSet";
@@ -204,6 +206,8 @@ constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kmaxPoolGradOpName = "MaxPoolGrad";
constexpr auto kMaxPoolWithArgmaxOpName = "MaxPoolWithArgmax";
constexpr auto kMaxPoolGradWithArgmaxOpName = "MaxPoolGradWithArgmax";
constexpr auto kTensorAddOpName = "TensorAdd";
constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual";
@@ -250,6 +254,13 @@ constexpr auto kMatMulV2OpName = "MatMulV2";
constexpr auto kBroadcastToOpName = "BroadcastTo";
constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2";
constexpr auto kDropoutOpName = "Dropout";
constexpr auto kDropoutGradOpName = "DropoutGrad";
constexpr auto kDropoutGenMaskOpName = "DropoutGenMask";
constexpr auto kDropoutDoMaskOpName = "DropoutDoMask";
constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits";
constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";


// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
@@ -272,6 +283,7 @@ constexpr auto kAttrEpsilon = "epsilon";
constexpr auto kAttrFactor = "factor";
constexpr auto kAttrIsRef = "isRef";
constexpr auto kAttrDataShape = "data_shape";
constexpr auto kAttrDataFormat = "data_format";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kAttrShapeGamma = "shape_gamma";
@@ -348,6 +360,15 @@ constexpr auto kAttrPynativeNextOpName = "next_op";
constexpr auto kAttrPynativeNextIndex = "next_index";
constexpr auto kAttrCompileInfo = "compile_info";
constexpr auto kAttrFusionType = "fusion_type";
constexpr auto kAttrStride = "stride";
constexpr auto kAttrStrides = "strides";
constexpr auto kAttrKsize = "ksize";
constexpr auto kAttrKernelSize = "kernel_size";
constexpr auto kAttrDilation = "dilation";
constexpr auto kAttrPadMode = "pad_mode";
constexpr auto kAttrPad = "pad";
constexpr auto kAttrPadding = "padding";
constexpr auto kAttrIsGrad = "is_grad";


// attr value
constexpr auto kValueTargetSwitch = "target_switch";


+ 2
- 0
mindspore/core/base/core_ops.h

@@ -134,6 +134,8 @@ inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
inline const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("MaxPoolWithArgmax");
inline const PrimitivePtr kPrimMaxPoolGradWithArgmax = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");


+ 4
- 20
mindspore/nn/layer/basic.py

@@ -141,37 +141,21 @@ class Dropout(Cell):
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0
self.seed1 = seed1
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_ascend = context.get_context('device_target') in ["Ascend"]
self.dropout = P.Dropout(keep_prob)
self.keep_prob = keep_prob
self.dropout = P.Dropout(keep_prob, seed0, seed1)


def construct(self, x):
if not self.training:
return x


if not self.is_ascend:
out, _ = self.dropout(x)
return out

if self.keep_prob == 1:
return x


shape = self.get_shape(x)
dtype = P.DType()(x)
if _is_float_dtype(dtype):
keep_prob = self.cast(self.keep_prob, dtype)
else:
keep_prob = self.cast(self.keep_prob, mstype.float16)
output = self.dropout_gen_mask(shape, keep_prob)
return self.dropout_do_mask(x, output, keep_prob)
out, _ = self.dropout(x)
return out


def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
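With this hunk, nn.Dropout always builds the single fused P.Dropout primitive (seeded via _get_graph_seed) instead of choosing between DropoutGenMask/DropoutDoMask and P.Dropout per device target; splitting the op for Ascend is left to the backend's UnifyMindIR step. A minimal usage sketch of the new layer follows; the shapes and keep_prob value are illustrative, not taken from this patch.

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# keep_prob=0.9: each element is kept with probability 0.9 and scaled by 1/0.9.
dropout = nn.Dropout(keep_prob=0.9)
dropout.set_train()          # the cell is an identity op unless it is in training mode
x = Tensor(np.ones([2, 4]), mstype.float32)
y = dropout(x)               # same shape as x; roughly 10% of entries zeroed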


+ 1
- 21
mindspore/nn/layer/conv.py

@@ -19,7 +19,7 @@ from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore._extends import cell_attr_register
@@ -247,28 +247,8 @@ class Conv2d(_Conv):
dilation=self.dilation,
group=self.group,
data_format=self.format)
self._init_depthwise_conv2d()
self.bias_add = P.BiasAdd()


def _init_depthwise_conv2d(self):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
Validator.check_equal_int(self.group, self.in_channels, 'group')
Validator.check_equal_int(self.group, self.out_channels, 'group')
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, self.in_channels, *self.kernel_size]
if isinstance(self.weight_init, Tensor):
self.weight_init = Tensor(self.weight_init.asnumpy().swapaxes(0, 1), self.weight_init.dtype)
if isinstance(self.weight_init, Initializer):
self.weight_init.shape = weight_shape
self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')

def construct(self, x):
output = self.conv2d(x, self.weight)
if self.has_bias:
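After this hunk, nn.Conv2d no longer rewrites itself into P.DepthwiseConv2dNative when group > 1 on Ascend; the cell keeps P.Conv2D and the standard weight layout on every target, and the depthwise rewrite is expected to be handled by the backend's unify-mindir passes instead. A hedged sketch of the depthwise-style case that used to take the special path; the parameter values are illustrative.

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# group == in_channels == out_channels: a depthwise-style convolution.
# The Python layer now keeps P.Conv2D and the usual
# [out_channels, in_channels // group, kh, kw] weight layout on all targets.
conv = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, group=16)
x = Tensor(np.ones([1, 16, 32, 32]), mstype.float32)
y = conv(x)                  # [1, 16, 32, 32] with the default 'same' pad mode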


+ 2
- 16
mindspore/nn/layer/pooling.py

@@ -124,16 +124,9 @@ class MaxPool2d(_PoolNd):
strides=self.stride,
padding=self.pad_mode,
data_format=self.format)
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
self.is_tbe = context.get_context("device_target") == "Ascend"


def construct(self, x):
if self.is_tbe and self.training:
out = self.max_pool_with_arg_max(x)[0]
else:
out = self.max_pool(x)
out = self.max_pool(x)
return out




@@ -198,22 +191,15 @@ class MaxPool1d(_PoolNd):
self.max_pool = P.MaxPool(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
self.shape = F.shape
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.expand = P.ExpandDims()
self.squeeze = P.Squeeze(2)
self.is_tbe = context.get_context("device_target") == "Ascend"


def construct(self, x):
_shape_check(self.shape(x))
x = self.expand(x, 2)
if self.is_tbe and self.training:
output = self.max_pool_with_arg_max(x)[0]
else:
output = self.max_pool(x)
output = self.max_pool(x)
output = self.squeeze(output)
return output
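Both pooling cells now call P.MaxPool unconditionally; the training-time switch to MaxPoolWithArgmax on Ascend is removed from the Python layer. A small sketch of the resulting behaviour, with illustrative shapes:

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

# MaxPool2d now always lowers to P.MaxPool in construct(), regardless of
# device target or training mode.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = Tensor(np.ones([1, 3, 32, 32]), mstype.float32)
y = pool(x)                  # [1, 3, 16, 16] on any backend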




+ 30
- 65
mindspore/nn/layer/quant.py

@@ -433,27 +433,15 @@ class Conv2dBnFoldQuantOneConv(Cell):
(self.is_ge_backend or self.is_ascend)


# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, in_channels, *self.kernel_size]
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.channel_axis = channel_axis
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd()
@@ -651,27 +639,15 @@ class Conv2dBnFoldQuant(Cell):
self.is_gpu = context.get_context('device_target') == "GPU"


# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, in_channels, *self.kernel_size]
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd()
if Validator.check_bool(has_bias):
@@ -830,28 +806,16 @@ class Conv2dBnWithoutFoldQuant(Cell):
else:
self.bias = None
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, in_channels, *self.kernel_size]
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
@@ -963,10 +927,11 @@ class Conv2dQuant(Cell):
stride=self.stride,
dilation=self.dilation,
group=self.group)
channel_axis = 0
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
channel_axis=0,
channel_axis=channel_axis,
num_channels=out_channels,
quant_dtype=quant_dtype)




+ 6
- 24
mindspore/ops/operations/nn_ops.py

@@ -1574,32 +1574,12 @@ class MaxPoolWithArgmax(_Pool):


def infer_shape(self, x_shape):
out_shape = _Pool.infer_shape(self, x_shape)
_, _, out_h, out_w = out_shape
_, kernel_h, kernel_w, _ = self.ksize

argmax_shape = []
if self.is_tbe:
for i in range(4):
if i == 2:
dim = kernel_h * kernel_w
argmax_shape.append(dim)
elif i == 3:
dim = math.ceil(out_h * out_w / 16) + 1
argmax_shape.append(dim)
else:
argmax_shape.append(x_shape[i])
else:
argmax_shape = out_shape

return out_shape, argmax_shape
return out_shape, out_shape


def infer_dtype(self, x_dtype):
out_dtype = x_dtype
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
argmax_dtype = mstype.uint16
if self.is_gpu:
argmax_dtype = mstype.int32
return out_dtype, argmax_dtype
argmax_dtype = mstype.int32
return x_dtype, argmax_dtype




class AvgPool(_Pool):
@@ -6070,7 +6050,9 @@ class Dropout(PrimitiveWithInfer):
""" """


@prim_attr_register
def __init__(self, keep_prob=0.5):
def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0):
self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name)
self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)


def infer_shape(self, x_shape):
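MaxPoolWithArgmax now infers the argmax output with the same shape as the pooled output and int32 dtype on every target (the TBE-specific uint16 layout is gone from the frontend), and P.Dropout gains Seed0/Seed1 attributes matching the new nn.Dropout. A sketch based only on the signatures visible in this diff; the shapes and values are illustrative.

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P

# Both outputs of MaxPoolWithArgmax now share the pooled shape; argmax is int32.
max_pool = P.MaxPoolWithArgmax(ksize=2, strides=2)
out, argmax = max_pool(Tensor(np.ones([1, 3, 32, 32]), mstype.float32))
# out: [1, 3, 16, 16] float32, argmax: [1, 3, 16, 16] int32

# P.Dropout now also carries the seeds that the backend lowering will use.
dropout = P.Dropout(keep_prob=0.9, Seed0=1, Seed1=2)
y, mask = dropout(Tensor(np.ones([2, 4]), mstype.float32))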


+ 1
- 1
tests/ut/python/ops/test_ops.py

@@ -1615,7 +1615,7 @@ test_case_nn_ops = [
('MaxPoolWithArgmax', {
'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
'desc_inputs': [[128, 32, 32, 64]],
'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}),
'desc_bprop': [[128, 32, 16, 32], ([128, 32, 16, 32], {'dtype': np.int32})]}),
('SoftmaxCrossEntropyWithLogits', {
'block': P.SoftmaxCrossEntropyWithLogits(),
'desc_inputs': [[1, 10], [1, 10]],


+ 53
- 1
tests/ut/python/parallel/test_matmul_dropout.py

@@ -18,7 +18,11 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.seed import _get_graph_seed
from mindspore.common.api import _executor
from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -47,13 +51,61 @@ class GradWrap(nn.Cell):
return grad_all(self.network)(x, y, b)




@constexpr
def _is_float_dtype(dtype):
if dtype in [mstype.float32, mstype.float16]:
return True
return False

class Dropout(nn.Cell):
def __init__(self, keep_prob=0.5, dtype=mstype.float32):
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
self.keep_prob = keep_prob
seed0, seed1 = _get_graph_seed(0, "dropout")
self.seed0 = seed0
self.seed1 = seed1
self.dtype = dtype
self.get_shape = P.Shape()
self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
self.dropout_do_mask = P.DropoutDoMask()
self.cast = P.Cast()
self.is_gpu = context.get_context('device_target') in ["GPU"]
self.dropout = P.Dropout(keep_prob)

def construct(self, x):
if not self.training:
return x

if self.is_gpu:
out, _ = self.dropout(x)
return out

if self.keep_prob == 1:
return x

shape = self.get_shape(x)
dtype = P.DType()(x)
if _is_float_dtype(dtype):
keep_prob = self.cast(self.keep_prob, dtype)
else:
keep_prob = self.cast(self.keep_prob, mstype.float16)
output = self.dropout_gen_mask(shape, keep_prob)
return self.dropout_do_mask(x, output, keep_prob)

def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)

# model_parallel test
def test_two_matmul_dropout():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2, strategy3):
super().__init__()
self.matmul1 = P.MatMul().shard(strategy1)
self.dropout = nn.Dropout()
self.dropout = Dropout()
self.dropout.dropout_do_mask.shard(strategy2)
self.dropout.dropout_gen_mask.shard(strategy2)
self.matmul2 = P.MatMul().shard(strategy3)


+ 31
- 1
tests/ut/python/parallel/test_one_dev.py

@@ -19,11 +19,14 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import _selected_ops
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
@@ -66,6 +69,33 @@ class AllToAllNet(nn.Cell):
return x




class SoftmaxCrossEntropyWithLogits(_Loss):
def __init__(self,
sparse=False,
reduction='none'):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = sparse
self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32)
self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"]

if self.is_cpugpu:
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits()

def construct(self, logits, labels):
if self.is_cpugpu and self.sparse and self.reduction == 'mean':
x = self.sparse_softmax_cross_entropy(logits, labels)
return x

if self.sparse:
labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
x = self.softmax_cross_entropy(logits, labels)[0]
return self.get_loss(x)


def all_to_all_net():
return AllToAllNet()



