Browse Source

stop converting input of dynamic-shape Reshape to attr

feature/build-system-rewrite
lingyunli63 4 years ago
parent
commit
bae077f524
4 changed files with 24 additions and 5 deletions
  1. +18
    -0
      mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
  2. +2
    -1
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  3. +1
    -1
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  4. +3
    -3
      mindspore/ccsrc/utils/utils.h

+ 18
- 0
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc View File

@@ -66,6 +66,21 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
<< "input size : [" << input_tensor_num << "] can not match any kernelInfo !";
}
}

bool SelectAicpuReshapeInTaskSink(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
if (AnfAlgo::GetCNodeName(kernel_node) != "Reshape") {
return false;
}
const size_t AicpuReshapeSize = 2;
if (kernel_node->size() != AicpuReshapeSize) {
return false;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
return is_task_sink;
}
} // namespace

void CheckKernelInfoListEmpty(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
@@ -89,6 +104,9 @@ void KernelQueryAll(const CNodePtr &kernel_node,
HcclMetadataInfo(kernel_node, kernel_info_list);
CheckKernelInfoListEmpty(kernel_info_list, "HCCL_Kernel");
}
if (SelectAicpuReshapeInTaskSink(kernel_node)) {
return;
}
if (kernel_info_list->empty()) {
HostMetadataInfo(kernel_node, kernel_info_list);
CheckKernelInfoListEmpty(kernel_info_list, "HOST_Kernel");


+ 2
- 1
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -330,7 +330,8 @@ bool IsNopNode(const AnfNodePtr &node) {
if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
return false;
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimReshape->name() && AnfAlgo::IsNodeDynamicShape(cnode)) {
const size_t kNopNodeInputSize = 2;
if (cnode->size() != kNopNodeInputSize) {
return false;
}
return true;


+ 1
- 1
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1022,7 +1022,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
if (opt::IsNopNode(node) && (skip_nop_node || IsNeedSkipNopOpAddr(node))) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() == kNopNodeInputSize || AnfAlgo::GetCNodeName(cnode) == "Reshape") {
if (cnode->size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node." << trace::DumpSourceLines(node);


+ 3
- 3
mindspore/ccsrc/utils/utils.h View File

@@ -739,9 +739,9 @@ const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0,
const std::set<std::string> kNoPaddingFormatSet = {kOpFormat_ChannelLast, kOpFormat_FRAC_NZ, kOpFormat_FRACTAL_ZN_RNN,
kOpFormat_ND_RNN_BIAS};

const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kReduceMinOpName,
kReduceMeanOpName, kReduceMaxOpName, kReduceAllOpName, kReduceAnyOpName, kConcatOpName};
const std::set<std::string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName, kEmbeddingLookupOpName,
kReduceMinOpName, kReduceMeanOpName, kReduceMaxOpName,
kReduceAllOpName, kReduceAnyOpName, kConcatOpName};

const std::set<std::string> DynamicShapeConstInputToAttrCPU = {
kCastOpName, kExpandDimsOpName, kEmbeddingLookupOpName, kReduceMinOpName, kReduceMeanOpName,


Loading…
Cancel
Save