Browse Source

Fix the bug when bninfer is used as updatestate

tags/v1.5.0-rc1
Margaret_wangrui 4 years ago
parent
commit
fdd494fbfd
4 changed files with 8 additions and 7 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
  2. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
  3. +4
    -3
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  4. +2
    -2
      mindspore/ccsrc/backend/optimizer/common/helper.h

+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc View File

@@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
auto bn_infer = CreateBNInfer(graph, batchnorm, node);
TransferDepend(batchnorm, graph, bn_infer);
TransferDependOrUpdateState(batchnorm, graph, bn_infer);
return bn_infer;
}
} // namespace opt


+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc View File

@@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
return nullptr;
}
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad);
return bn_infer_grad;
}
} // namespace opt


+ 4
- 3
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -916,13 +916,14 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
return new_value_node;
}

void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// Find BatchNorm's output which is a Depend or UpdateState.
for (const auto &node_index : manager->node_users()[old_node]) {
auto node_users = manager->node_users()[old_node];
for (const auto &node_index : node_users) {
AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output);
@@ -930,7 +931,7 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node);
manager->SetEdge(depend, index, new_node);
}
}
}


+ 2
- 2
mindspore/ccsrc/backend/optimizer/common/helper.h View File

@@ -216,8 +216,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
// Create a new value node of func graph,not kernel graph
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);

// Transfer depend to the new node
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
// Transfer depend or updatestate to the new node
void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);

AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);



Loading…
Cancel
Save