| @@ -41,7 +41,6 @@ | |||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "runtime/device/ascend/profiling/profiling_utils.h" | #include "runtime/device/ascend/profiling/profiling_utils.h" | ||||
| #include "backend/kernel_compiler/tbe/tbe_utils.h" | #include "backend/kernel_compiler/tbe/tbe_utils.h" | ||||
| #include "backend/optimizer/common/helper.h" | |||||
| #include "runtime/device/ascend/ascend_memory_manager.h" | #include "runtime/device/ascend/ascend_memory_manager.h" | ||||
| #include "debug/tensor_load.h" | #include "debug/tensor_load.h" | ||||
| #include "debug/data_dump/dump_json_parser.h" | #include "debug/data_dump/dump_json_parser.h" | ||||
| @@ -114,34 +113,6 @@ std::string GetRankId() { | |||||
| } | } | ||||
| return rank_id_str; | return rank_id_str; | ||||
| } | } | ||||
| void InferShapeForNopNode(AnfNodePtr *input_node) { | |||||
| MS_EXCEPTION_IF_NULL(*input_node); | |||||
| if (!opt::IsNopNode(*input_node)) { | |||||
| MS_LOG(INFO) << "Input node is not a nop node, no need infer."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Infer shape for nop node."; | |||||
| std::stack<AnfNodePtr> nop_road; | |||||
| nop_road.push(*input_node); | |||||
| while (true) { | |||||
| auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0); | |||||
| auto in_node = input_node_with_idx.first; | |||||
| MS_EXCEPTION_IF_NULL(in_node); | |||||
| if (opt::IsNopNode(in_node)) { | |||||
| nop_road.push(in_node); | |||||
| *input_node = in_node; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| while (!nop_road.empty()) { | |||||
| auto nop_node = nop_road.top(); | |||||
| AnfAlgo::InferShape(nop_node->cast<CNodePtr>()); | |||||
| nop_road.pop(); | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_; | std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_; | ||||
| @@ -665,15 +636,6 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap | |||||
| } | } | ||||
| if (dynamic_kernel->is_dynamic_shape()) { | if (dynamic_kernel->is_dynamic_shape()) { | ||||
| auto kernel_node = dynamic_kernel->kernel_node(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| auto input_size = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| for (size_t i = 0; i < input_size; i++) { | |||||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i); | |||||
| auto input_node = input_node_with_index.first; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| InferShapeForNopNode(&input_node); | |||||
| } | |||||
| dynamic_kernel->InferShape(); | dynamic_kernel->InferShape(); | ||||
| dynamic_kernel->UpdateArgs(); | dynamic_kernel->UpdateArgs(); | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "backend/optimizer/common/helper.h" | |||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| #include "pipeline/jit/static_analysis/static_analysis.h" | #include "pipeline/jit/static_analysis/static_analysis.h" | ||||
| #include "abstract/dshape.h" | #include "abstract/dshape.h" | ||||
| @@ -73,6 +74,7 @@ void DynamicKernel::InferShape() { | |||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr_); | MS_EXCEPTION_IF_NULL(cnode_ptr_); | ||||
| MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope(); | MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope(); | ||||
| InferShapeRecursive(); | |||||
| auto inputs = cnode_ptr_->inputs(); | auto inputs = cnode_ptr_->inputs(); | ||||
| if (inputs.empty()) { | if (inputs.empty()) { | ||||
| @@ -124,5 +126,43 @@ void DynamicKernel::InferShape() { | |||||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | ||||
| cnode_ptr_->set_abstract(eval_result); | cnode_ptr_->set_abstract(eval_result); | ||||
| } | } | ||||
| void DynamicKernel::InferShapeRecursive() { | |||||
| auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_); | |||||
| for (size_t i = 0; i < input_size; i++) { | |||||
| auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i); | |||||
| auto input_node = input_node_with_index.first; | |||||
| MS_EXCEPTION_IF_NULL(input_node); | |||||
| InferShapeForNopNode(&input_node); | |||||
| } | |||||
| } | |||||
| void DynamicKernel::InferShapeForNopNode(AnfNodePtr *input_node) { | |||||
| MS_EXCEPTION_IF_NULL(*input_node); | |||||
| if (!opt::IsNopNode(*input_node) || !AnfAlgo::IsDynamicShape(*input_node)) { | |||||
| MS_LOG(INFO) << "Input node is not a nop node, no need infer."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Infer shape for nop node."; | |||||
| std::stack<AnfNodePtr> nop_road; | |||||
| nop_road.push(*input_node); | |||||
| while (true) { | |||||
| auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0); | |||||
| auto in_node = input_node_with_idx.first; | |||||
| MS_EXCEPTION_IF_NULL(in_node); | |||||
| if (opt::IsNopNode(in_node)) { | |||||
| nop_road.push(in_node); | |||||
| *input_node = in_node; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| while (!nop_road.empty()) { | |||||
| auto nop_node = nop_road.top(); | |||||
| AnfAlgo::InferShape(nop_node->cast<CNodePtr>()); | |||||
| nop_road.pop(); | |||||
| } | |||||
| } | |||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -52,6 +52,8 @@ class DynamicKernel { | |||||
| protected: | protected: | ||||
| void RebuildDependTensor(); | void RebuildDependTensor(); | ||||
| void InferShapeRecursive(); | |||||
| void InferShapeForNopNode(AnfNodePtr *input_node); | |||||
| void *stream_; | void *stream_; | ||||
| const CNodePtr cnode_ptr_; | const CNodePtr cnode_ptr_; | ||||