|
|
|
@@ -54,14 +54,16 @@ void ClearPythonParasMap() { python_paras = nullptr; } |
|
|
|
namespace { |
|
|
|
const int kSummaryGetItem = 2; |
|
|
|
const size_t max_depth = 128; |
|
|
|
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx, bool *check_dynamic) { |
|
|
|
bool IsShapeDynamic(const abstract::ShapePtr &shape) { |
|
|
|
if (shape == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; }); |
|
|
|
} |
|
|
|
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, size_t *idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (*check_dynamic) { |
|
|
|
if (node->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(node->cast<CNodePtr>())) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} else if (AnfAlgo::IsRealKernel(node)) { |
|
|
|
if (AnfAlgo::IsRealKernel(node)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
(*idx) += 1; |
|
|
|
@@ -69,7 +71,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, |
|
|
|
if (*idx <= max_depth) { |
|
|
|
auto users = manager->node_users()[node]; |
|
|
|
if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { |
|
|
|
return RecursiveCheck(manager, kernel.first, idx, check_dynamic); |
|
|
|
return RecursiveCheck(manager, kernel.first, idx); |
|
|
|
})) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -82,24 +84,8 @@ bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &no |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto node_users = manager->node_users()[node]; |
|
|
|
size_t idx = 0; |
|
|
|
bool check_dynamic = false; |
|
|
|
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { |
|
|
|
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); |
|
|
|
})) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsUsedByDynamicKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto node_users = manager->node_users()[node]; |
|
|
|
size_t idx = 0; |
|
|
|
bool check_dynamic = true; |
|
|
|
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) { |
|
|
|
return RecursiveCheck(manager, kernel.first, &idx, &check_dynamic); |
|
|
|
return RecursiveCheck(manager, kernel.first, &idx); |
|
|
|
})) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -481,7 +467,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const |
|
|
|
builder.SetOutputsFormat({format}); |
|
|
|
d_kernel_info->set_select_kernel_build_info(builder.Build()); |
|
|
|
AnfAlgo::SetOutputAddr(address, 0, parameter.get()); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get()); |
|
|
|
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), |
|
|
|
parameter->Shape()->cast<abstract::BaseShapePtr>()); |
|
|
|
parameter->set_abstract(abstract); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -954,7 +942,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
if (!IsUsedByRealKernel(manager, input_node)) { |
|
|
|
node_ptr->set_used_by_real_kernel(); |
|
|
|
} |
|
|
|
if (IsUsedByDynamicKernel(manager, input_node)) { |
|
|
|
auto shape = node_ptr->Shape(); |
|
|
|
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) { |
|
|
|
node_ptr->set_used_by_dynamic_kernel(); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1043,7 +1032,8 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
if (!IsUsedByRealKernel(manager, input_node)) { |
|
|
|
node_ptr->set_used_by_real_kernel(); |
|
|
|
} |
|
|
|
if (IsUsedByDynamicKernel(manager, input_node)) { |
|
|
|
auto shape = node_ptr->Shape(); |
|
|
|
if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) { |
|
|
|
node_ptr->set_used_by_dynamic_kernel(); |
|
|
|
} |
|
|
|
} |
|
|
|
|