
fix run error when there is a Depend or ControlDepend on BatchNorm

huanghui committed 5 years ago · commit b8e737f66a · tags/v1.0.0
8 changed files with 41 additions and 15 deletions
  1. mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc (+0, -1)
  2. mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc (+3, -1)
  3. mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc (+12, -10)
  4. mindspore/ccsrc/backend/optimizer/common/helper.cc (+21, -0)
  5. mindspore/ccsrc/backend/optimizer/common/helper.h (+3, -0)
  6. mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc (+0, -1)
  7. mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc (+1, -1)
  8. mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h (+1, -1)
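In outline, the BatchNorm2BNInfer and BatchNormGrad2BNInferGrad passes replace a tuple_getitem over BatchNorm or BatchNormGrad with a fused BNInfer or BNInferGrad node; the new TransferDepend helper in common/helper.cc then re-points any Depend or ControlDepend edge that still references the original node at the fused node, so the dependency survives the fusion. A minimal sketch of that edge rewrite follows the helper.cc diff below.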

mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc (+0, -1)

@@ -79,7 +79,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
                                                               : AnfAlgo::GetOutputInferShape(input_node, insert_index);
   bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
                                       : trans::IsNeedPadding(input_format, input_node_out_shape.size());
-
   if (!need_padding) {
     // don't need padding insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());


mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc (+3, -1)

@@ -121,7 +121,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
   if (!NeedFusion(graph, node, &batchnorm)) {
     return nullptr;
   }
-  return CreateBNInfer(graph, batchnorm, node);
+  auto bn_infer = CreateBNInfer(graph, batchnorm, node);
+  TransferDepend(batchnorm, graph, bn_infer);
+  return bn_infer;
 }
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc (+12, -10)

@@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad
   return true;
 }
 
-bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) {
+bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   auto tuple_getitem = node->cast<CNodePtr>();
@@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
     return false;
   }
 
-  AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
-  MS_EXCEPTION_IF_NULL(batchnormgrad_anf);
-  MS_EXCEPTION_IF_NULL(batchnormgrad);
-  *batchnormgrad = batchnormgrad_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(*batchnormgrad);
-  return CheckBatchNormGrad(graph, *batchnormgrad);
+  AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad_anf);
+  MS_EXCEPTION_IF_NULL(batchnorm_grad);
+  *batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(*batchnorm_grad);
+  return CheckBatchNormGrad(graph, *batchnorm_grad);
 }
 }  // namespace
 
@@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
 
-  CNodePtr batchnormgrad = nullptr;
-  if (!NeedFusion(graph, node, &batchnormgrad)) {
+  CNodePtr batchnorm_grad = nullptr;
+  if (!NeedFusion(graph, node, &batchnorm_grad)) {
     return nullptr;
   }
-  return CreateBNInferGrad(graph, batchnormgrad, node);
+  auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
+  TransferDepend(batchnorm_grad, graph, bn_infer_grad);
+  return bn_infer_grad;
 }
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/backend/optimizer/common/helper.cc (+21, -0)

@@ -872,5 +872,26 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
   return new_value_node;
 }
 
+void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
+  MS_EXCEPTION_IF_NULL(old_node);
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  // find BatchNorm's output which is a Depend or ControlDepend
+  for (const auto &node_index : manager->node_users()[old_node]) {
+    AnfNodePtr output = node_index.first;
+    size_t index = IntToSize(node_index.second);
+    MS_EXCEPTION_IF_NULL(output);
+    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
+      auto control_depend = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(control_depend);
+      control_depend->set_input(index, new_node);
+    } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
+      auto depend = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(depend);
+      depend->set_input(index, new_node);
+    }
+  }
+}
 }  // namespace opt
 }  // namespace mindspore
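For context, here is a minimal, self-contained C++ sketch of the edge rewrite TransferDepend performs: every user of the old node that is a Depend or ControlDepend gets the corresponding input re-pointed at the new node. The Node struct, the Users() helper and TransferDependSketch() below are illustrative stand-ins for AnfNode, FuncGraphManager::node_users() and CNode::set_input(); they are not MindSpore APIs.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string kind;                           // e.g. "BatchNorm", "Depend", "ControlDepend", "BNInfer"
  std::vector<std::shared_ptr<Node>> inputs;  // graph edges, analogous to CNode inputs
};
using NodePtr = std::shared_ptr<Node>;

// Collect (user, input-index) pairs for every edge pointing at `target`,
// playing the role of manager->node_users()[old_node] in the real pass.
std::vector<std::pair<NodePtr, size_t>> Users(const std::vector<NodePtr> &all_nodes, const NodePtr &target) {
  std::vector<std::pair<NodePtr, size_t>> users;
  for (const auto &node : all_nodes) {
    for (size_t i = 0; i < node->inputs.size(); ++i) {
      if (node->inputs[i] == target) {
        users.emplace_back(node, i);
      }
    }
  }
  return users;
}

// Re-point Depend/ControlDepend users of `old_node` at `new_node`,
// the same kind of edge rewrite TransferDepend does via set_input().
void TransferDependSketch(const std::vector<NodePtr> &all_nodes, const NodePtr &old_node, const NodePtr &new_node) {
  for (const auto &[user, index] : Users(all_nodes, old_node)) {
    if (user->kind == "Depend" || user->kind == "ControlDepend") {
      user->inputs[index] = new_node;
    }
  }
}

int main() {
  auto batchnorm = std::make_shared<Node>(Node{"BatchNorm", {}});
  auto bn_infer = std::make_shared<Node>(Node{"BNInfer", {}});
  // A Depend whose second input keeps the original BatchNorm in the execution order.
  auto depend = std::make_shared<Node>(Node{"Depend", {nullptr, batchnorm}});
  std::vector<NodePtr> graph = {batchnorm, bn_infer, depend};

  TransferDependSketch(graph, batchnorm, bn_infer);
  std::cout << "Depend now points at: " << depend->inputs[1]->kind << std::endl;  // BNInfer
  return 0;
}

Running the sketch prints "Depend now points at: BNInfer": after the rewrite, the Depend edge refers to the fused node rather than the replaced BatchNorm, which mirrors what the helper above does inside the optimizer.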

mindspore/ccsrc/backend/optimizer/common/helper.h (+3, -0)

@@ -203,6 +203,9 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
 
 // Create a new value node of func graph,not kernel graph
 ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
+
+// Transfer depend or control_depend to the new node
+void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_

mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc (+0, -1)

@@ -27,7 +27,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-
 static std::vector<size_t> g_output_idx;
 
 bool HasAtomic(const AnfNodePtr &input) {


mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc (+1, -1)

@@ -98,7 +98,7 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
   }
 }
-bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) {
+bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
   MS_EXCEPTION_IF_NULL(segment_index);
   if (segments >= communication_op_node_size) {
     MS_LOG(INFO) << "fusion not changed: segment_num=" << segments


mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h (+1, -1)

@@ -24,7 +24,7 @@ namespace opt {
 class ConstToAttrStridedSliceGradPass : public PatternProcessPass {
  public:
   explicit ConstToAttrStridedSliceGradPass(bool multigraph = true)
-      : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {}
+      : PatternProcessPass("const_to_attr_strided_slice_grad", multigraph) {}
   ~ConstToAttrStridedSliceGradPass() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

