|
|
|
@@ -238,17 +238,29 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// parameter->transdata->assign<-5d node, ref parameter would get from transdata input |
|
|
|
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr { |
|
|
|
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto first_input = cnode->input(kFirstDataInputIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(first_input); |
|
|
|
return first_input; |
|
|
|
} |
|
|
|
return node; |
|
|
|
}; |
|
|
|
// prepare referance count |
|
|
|
for (auto &node : search_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// if assign node |
|
|
|
std::set<AnfNodePtr> refed_parameters; |
|
|
|
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { |
|
|
|
refed_parameters.insert(std::get<1>(iter->second)); |
|
|
|
refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second))); |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &in : node->inputs()) { |
|
|
|
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; |
|
|
|
visit_node = validate_ref_parameter(visit_node); |
|
|
|
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|