Browse Source

!9715 fix if grad with 5d node precision

From: @youui
Reviewed-by: @zhoufeng54,@stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a1e204969c
1 changed files with 13 additions and 1 deletions
  1. +13
    -1
      mindspore/ccsrc/backend/session/ascend_control_parser.cc

+ 13
- 1
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -238,17 +238,29 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
}
}
}
// parameter->transdata->assign<-5d node, ref parameter would get from transdata input
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::KPrimTransData)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto first_input = cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(first_input);
return first_input;
}
return node;
};
// prepare referance count
for (auto &node : search_list) {
MS_EXCEPTION_IF_NULL(node);
// if assign node
std::set<AnfNodePtr> refed_parameters;
for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
refed_parameters.insert(std::get<1>(iter->second));
refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
}

for (auto &in : node->inputs()) {
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
visit_node = validate_ref_parameter(visit_node);
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
continue;
}


Loading…
Cancel
Save