|
|
|
@@ -68,7 +68,6 @@ void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, s |
|
|
|
MS_EXCEPTION_IF_NULL(n); |
|
|
|
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || |
|
|
|
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { |
|
|
|
int index = 0; |
|
|
|
auto cnode = n->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->inputs().size() <= kSummaryGetItem) { |
|
|
|
@@ -76,24 +75,11 @@ void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, s |
|
|
|
} |
|
|
|
auto node = cnode->input(kSummaryGetItem); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { |
|
|
|
auto c = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(c); |
|
|
|
if (c->inputs().size() != kTupleGetItemInputSize) { |
|
|
|
MS_LOG(EXCEPTION) << "the node tuple_get_item must have 2 inputs!"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(c->input(kInputNodeOutputIndexInTupleGetItem)); |
|
|
|
auto value_node = c->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>(); |
|
|
|
auto value = value_node->value(); |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
Int32ImmPtr int_imm_ptr = value->cast<Int32ImmPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(int_imm_ptr); |
|
|
|
index = int_imm_ptr->value(); |
|
|
|
node = c->input(kRealInputNodeIndexInTupleGetItem); |
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); |
|
|
|
if (!AnfAlgo::IsRealKernel(item_with_index.first)) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); |
|
|
|
} |
|
|
|
std::pair<AnfNodePtr, int> output_pair(node, index); |
|
|
|
// get full name with scope will add scalar or tensor or image summary tag. |
|
|
|
(*summary)[n->fullname_with_scope()] = output_pair; |
|
|
|
(*summary)[n->fullname_with_scope()] = item_with_index; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size(); |
|
|
|
|