| @@ -51,26 +51,55 @@ static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras; | |||||
| void ClearPythonParasMap() { python_paras = nullptr; } | void ClearPythonParasMap() { python_paras = nullptr; } | ||||
| namespace { | namespace { | ||||
| const int kSummaryGetItem = 2; | const int kSummaryGetItem = 2; | ||||
| bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||||
| const size_t max_depth = 128; | |||||
| bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto node_users = manager->node_users()[node]; | |||||
| for (auto item : node_users) { | |||||
| if (AnfAlgo::IsRealKernel(item.first)) { | |||||
| if (*check_dynamic) { | |||||
| if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) { | |||||
| return true; | |||||
| } | |||||
| } else if (AnfAlgo::IsRealKernel(node)) { | |||||
| return true; | |||||
| } | |||||
| (*idx) += 1; | |||||
| // max recursion depth | |||||
| if (*idx <= max_depth) { | |||||
| auto users = manager->node_users()[node]; | |||||
| if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||||
| return RecursiveCheck(manager, kernel.first, idx, check_dynamic); | |||||
| })) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto node_users = manager->node_users()[node]; | |||||
| size_t idx = 0; | |||||
| bool check_dynamic = false; | |||||
| if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||||
| return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); | |||||
| })) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto node_users = manager->node_users()[node]; | auto node_users = manager->node_users()[node]; | ||||
| for (auto item : node_users) { | |||||
| if (item.first->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(item.first->cast<CNodePtr>())) { | |||||
| return true; | |||||
| } | |||||
| size_t idx = 0; | |||||
| bool check_dynamic = true; | |||||
| if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||||
| return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); | |||||
| })) { | |||||
| return true; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h" | #include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| #include "runtime/kernel.h" | #include "runtime/kernel.h" | ||||
| @@ -27,6 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
// Names of ops that AiCpuDynamicKernel::Initialize marks as UnknowShapeOpType::DEPEND_COMPUTE;
// any op not listed here keeps DEPEND_IN_SHAPE. The set is only ever looked up, so it is const
// (which also keeps it local to this translation unit).
const std::set<std::string> kComputeDepend = {"Unique"};
| AiCpuDynamicKernel::~AiCpuDynamicKernel() { | AiCpuDynamicKernel::~AiCpuDynamicKernel() { | ||||
| // free dev ptr | // free dev ptr | ||||
| if (ext_info_addr_dev_ == nullptr) { | if (ext_info_addr_dev_ == nullptr) { | ||||
| @@ -67,9 +69,11 @@ void AiCpuDynamicKernel::Initialize() { | |||||
| output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); | output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); | ||||
| UnknowShapeOpType shape_type = UnknowShapeOpType::DEPEND_IN_SHAPE; | UnknowShapeOpType shape_type = UnknowShapeOpType::DEPEND_IN_SHAPE; | ||||
| if (AnfAlgo::GetCNodeName(cnode_ptr_) == "Unique") { | |||||
| auto op_name = AnfAlgo::GetCNodeName(cnode_ptr_); | |||||
| if (kComputeDepend.find(op_name) != kComputeDepend.end()) { | |||||
| shape_type = UnknowShapeOpType::DEPEND_COMPUTE; | shape_type = UnknowShapeOpType::DEPEND_COMPUTE; | ||||
| } | } | ||||
| unknow_type_ = shape_type; | |||||
| // Parse aicpu ext info | // Parse aicpu ext info | ||||
| if (is_dynamic_shape_) { | if (is_dynamic_shape_) { | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr_); | MS_EXCEPTION_IF_NULL(cnode_ptr_); | ||||
| @@ -141,7 +145,7 @@ bool AiCpuDynamicKernel::UpdateExtInfo() { | |||||
| ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_)); | ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_)); | ||||
| } | } | ||||
| if (unknow_type_ != DEPEND_COMPUTE) { | |||||
| if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ != DEPEND_COMPUTE) { | |||||
| for (size_t i = 0; i < output_num_; ++i) { | for (size_t i = 0; i < output_num_; ++i) { | ||||
| ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_)); | ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_)); | ||||
| } | } | ||||
| @@ -198,6 +202,9 @@ bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() { | |||||
| void AiCpuDynamicKernel::PostExecute() { | void AiCpuDynamicKernel::PostExecute() { | ||||
| MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute"; | MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute"; | ||||
| if (unknow_type_ != DEPEND_COMPUTE) { | |||||
| return; | |||||
| } | |||||
| if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { | if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { | ||||
| MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; | ||||
| return; | return; | ||||
| @@ -40,7 +40,7 @@ class AiCpuDynamicKernel : public DynamicKernel { | |||||
| ext_info_size_(0), | ext_info_size_(0), | ||||
| input_num_(0), | input_num_(0), | ||||
| output_num_(0), | output_num_(0), | ||||
| unknow_type_(DEPEND_COMPUTE) {} | |||||
| unknow_type_(DEPEND_IN_SHAPE) {} | |||||
| ~AiCpuDynamicKernel() override; | ~AiCpuDynamicKernel() override; | ||||