| @@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; } | |||
| namespace { | |||
| const int kSummaryGetItem = 2; | |||
| const size_t max_depth = 128; | |||
| bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) { | |||
| bool IsShapeDynamic(const abstract::ShapePtr &shape) { | |||
| if (shape == nullptr) { | |||
| return false; | |||
| } | |||
| return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; }); | |||
| } | |||
| bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (*check_dynamic) { | |||
| if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) { | |||
| return true; | |||
| } | |||
| } else if (AnfAlgo::IsRealKernel(node)) { | |||
| if (AnfAlgo::IsRealKernel(node)) { | |||
| return true; | |||
| } | |||
| (*idx) += 1; | |||
| @@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, | |||
| if (*idx <= max_depth) { | |||
| auto users = manager->node_users()[node]; | |||
| if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||
| return RecursiveCheck(manager, kernel.first, idx, check_dynamic); | |||
| return RecursiveCheck(manager, kernel.first, idx); | |||
| })) { | |||
| return true; | |||
| } | |||
| @@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto node_users = manager->node_users()[node]; | |||
| size_t idx = 0; | |||
| bool check_dynamic = false; | |||
| if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||
| return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); | |||
| })) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto node_users = manager->node_users()[node]; | |||
| size_t idx = 0; | |||
| bool check_dynamic = true; | |||
| if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { | |||
| return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); | |||
| return RecursiveCheck(manager, kernel.first, &idx); | |||
| })) { | |||
| return true; | |||
| } | |||
| @@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const | |||
| builder.SetOutputsFormat({format}); | |||
| d_kernel_info->set_select_kernel_build_info(builder.Build()); | |||
| AnfAlgo::SetOutputAddr(address, 0, parameter.get()); | |||
| AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get()); | |||
| auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), | |||
| parameter->Shape()->cast<abstract::BaseShapePtr>()); | |||
| parameter->set_abstract(abstract); | |||
| } | |||
| } | |||
| @@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||
| if (!IsUsedByRealKernel(manager, input_node)) { | |||
| node_ptr->set_used_by_real_kernel(); | |||
| } | |||
| if (IsUsedByDynamicKernel(manager, input_node)) { | |||
| auto shape = node_ptr->Shape(); | |||
| if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) { | |||
| node_ptr->set_used_by_dynamic_kernel(); | |||
| } | |||
| } | |||
| @@ -1043,7 +1032,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| if (!IsUsedByRealKernel(manager, input_node)) { | |||
| node_ptr->set_used_by_real_kernel(); | |||
| } | |||
| if (IsUsedByDynamicKernel(manager, input_node)) { | |||
| auto shape = node_ptr->Shape(); | |||
| if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) { | |||
| node_ptr->set_used_by_dynamic_kernel(); | |||
| } | |||
| } | |||