| @@ -79,7 +79,6 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| : AnfAlgo::GetOutputInferShape(input_node, insert_index); | : AnfAlgo::GetOutputInferShape(input_node, insert_index); | ||||
| bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) | bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) | ||||
| : trans::IsNeedPadding(input_format, input_node_out_shape.size()); | : trans::IsNeedPadding(input_format, input_node_out_shape.size()); | ||||
| if (!need_padding) { | if (!need_padding) { | ||||
| // don't need padding insert transdata only | // don't need padding insert transdata only | ||||
| trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); | trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); | ||||
| @@ -121,7 +121,9 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf | |||||
| if (!NeedFusion(graph, node, &batchnorm)) { | if (!NeedFusion(graph, node, &batchnorm)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return CreateBNInfer(graph, batchnorm, node); | |||||
| auto bn_infer = CreateBNInfer(graph, batchnorm, node); | |||||
| TransferDepend(batchnorm, graph, bn_infer); | |||||
| return bn_infer; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { | |||||
| bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto tuple_getitem = node->cast<CNodePtr>(); | auto tuple_getitem = node->cast<CNodePtr>(); | ||||
| @@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat | |||||
| return false; | return false; | ||||
| } | } | ||||
| AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad_anf); | |||||
| MS_EXCEPTION_IF_NULL(batchnormgrad); | |||||
| *batchnormgrad = batchnormgrad_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(*batchnormgrad); | |||||
| return CheckBatchNormGrad(graph, *batchnormgrad); | |||||
| AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); | |||||
| MS_EXCEPTION_IF_NULL(batchnorm_grad_anf); | |||||
| MS_EXCEPTION_IF_NULL(batchnorm_grad); | |||||
| *batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(*batchnorm_grad); | |||||
| return CheckBatchNormGrad(graph, *batchnorm_grad); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| CNodePtr batchnormgrad = nullptr; | |||||
| if (!NeedFusion(graph, node, &batchnormgrad)) { | |||||
| CNodePtr batchnorm_grad = nullptr; | |||||
| if (!NeedFusion(graph, node, &batchnorm_grad)) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return CreateBNInferGrad(graph, batchnormgrad, node); | |||||
| auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node); | |||||
| TransferDepend(batchnorm_grad, graph, bn_infer_grad); | |||||
| return bn_infer_grad; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -872,5 +872,26 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { | |||||
| return new_value_node; | return new_value_node; | ||||
| } | } | ||||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) { | |||||
| MS_EXCEPTION_IF_NULL(old_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| // find BatchNorm's output which is a Depend or ControlDepend | |||||
| for (const auto &node_index : manager->node_users()[old_node]) { | |||||
| AnfNodePtr output = node_index.first; | |||||
| size_t index = IntToSize(node_index.second); | |||||
| MS_EXCEPTION_IF_NULL(output); | |||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||||
| auto control_depend = output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(control_depend); | |||||
| control_depend->set_input(index, new_node); | |||||
| } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) { | |||||
| auto depend = output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(depend); | |||||
| depend->set_input(index, new_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -203,6 +203,9 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor | |||||
| // Create a new value node of func graph,not kernel graph | // Create a new value node of func graph,not kernel graph | ||||
| ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); | ||||
| // Transfer depend or control_depend to the new node | |||||
| void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ | #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_ | ||||
| @@ -27,7 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| static std::vector<size_t> g_output_idx; | static std::vector<size_t> g_output_idx; | ||||
| bool HasAtomic(const AnfNodePtr &input) { | bool HasAtomic(const AnfNodePtr &input) { | ||||
| @@ -98,7 +98,7 @@ void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) { | |||||
| } | } | ||||
| } | } | ||||
| bool CheckSegments(size_t segments, size_t communication_op_node_size, std::vector<size_t> *segment_index) { | |||||
| bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) { | |||||
| MS_EXCEPTION_IF_NULL(segment_index); | MS_EXCEPTION_IF_NULL(segment_index); | ||||
| if (segments >= communication_op_node_size) { | if (segments >= communication_op_node_size) { | ||||
| MS_LOG(INFO) << "fusion not changed: segment_num=" << segments | MS_LOG(INFO) << "fusion not changed: segment_num=" << segments | ||||
| @@ -24,7 +24,7 @@ namespace opt { | |||||
| class ConstToAttrStridedSliceGradPass : public PatternProcessPass { | class ConstToAttrStridedSliceGradPass : public PatternProcessPass { | ||||
| public: | public: | ||||
| explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) | explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) | ||||
| : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} | |||||
| : PatternProcessPass("const_to_attr_strided_slice_grad", multigraph) {} | |||||
| ~ConstToAttrStridedSliceGradPass() override = default; | ~ConstToAttrStridedSliceGradPass() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||