Browse Source

!1911 add a function to charge the node input and output is a scalar

Merge pull request !1911 from lianliguang/add-a-function-to-charge-the-node-input-or-output-if-is-a-scalar
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
beb714d2d0
9 changed files with 30 additions and 11 deletions
  1. +2
    -2
      mindspore/ccsrc/pre_activate/ascend/ascend_helper.h
  2. +4
    -4
      mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc
  3. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc
  4. +1
    -1
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
  5. +16
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.cc
  6. +2
    -0
      mindspore/ccsrc/session/anf_runtime_algorithm.h
  7. +2
    -0
      mindspore/ccsrc/utils/utils.h
  8. +1
    -2
      tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc
  9. +1
    -1
      tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc

+ 2
- 2
mindspore/ccsrc/pre_activate/ascend/ascend_helper.h View File

@@ -37,11 +37,11 @@ class SupportedChecker {
public:
SupportedChecker() = default;
virtual ~SupportedChecker() = default;
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
}
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
}


+ 4
- 4
mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc View File

@@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
return node;
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) {
return nullptr;
} else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
@@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]";
}
return node;
return nullptr;
}
} // namespace opt
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc View File

@@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto indices_const = CreateValueNode(new_cnode);
new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
MS_LOG(INFO) << "split topk failed, check to aicpu.";
return nullptr;
}


+ 1
- 1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc View File

@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());

auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);


+ 16
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.cc View File

@@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
}
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}

bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}

bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
if (shape.empty()) {
return true;
}
return shape.size() == kShape1dDims && shape[0] == 1;
}
} // namespace session
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/session/anf_runtime_algorithm.h View File

@@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm {
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other";

// some size
const size_t kShape4dDims = 4;
const size_t kShape2dDims = 2;
const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kMemAlignSize = 512;
const int kParameterDataTensorMask = 0;


+ 1
- 2
tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc View File

@@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
}; // namespace opt


+ 1
- 1
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc View File

@@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
};


Loading…
Cancel
Save