| @@ -64,6 +64,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_trans_op.h" | |||
| #include "backend/optimizer/ascend/format_type/add_attr_for_3d_graph.h" | |||
| #include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" | |||
| @@ -224,6 +225,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) | |||
| auto data_layout_pm = std::make_shared<PassManager>("transop_pm"); | |||
| data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>()); | |||
| data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>()); | |||
| data_layout_pm->AddPass(std::make_shared<AddIoFormatAttrFor3DGraph>()); | |||
| data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | |||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/add_attr_for_3d_graph.h" | |||
| #include <memory> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| void AddAttrForAllCNode(const std::vector<AnfNodePtr> &node_list) { | |||
| for (auto node : node_list) { | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||
| continue; | |||
| } | |||
| AnfAlgo::SetNodeAttr("io_format", MakeValue(kOpFormat_NCDHW), node); | |||
| } | |||
| } | |||
| bool NodeHasAttrIoFormat(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (AnfAlgo::HasNodeAttr("io_format", cnode)) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format"); | |||
| return attr == kOpFormat_NCDHW; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| bool AddIoFormatAttrFor3DGraph::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| bool changed = false; | |||
| if (std::any_of(node_list.begin(), node_list.end(), | |||
| [](const AnfNodePtr &node) { return NodeHasAttrIoFormat(node); })) { | |||
| AddAttrForAllCNode(node_list); | |||
| changed = true; | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AddIoFormatAttrFor3DGraph : public Pass { | |||
| public: | |||
| explicit AddIoFormatAttrFor3DGraph(size_t groups = 1) : Pass("add_attr_for_3d_graph"), groups_(groups) {} | |||
| ~AddIoFormatAttrFor3DGraph() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| size_t groups_ = 1; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||