/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/graph/transpose_strategy.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ops/crop.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "ops/strided_slice.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kFirstInput = 1;
// Only StridedSlice cnodes with exactly this many inputs (primitive included) are rewritten.
constexpr size_t kOnnxStridedSlice = 6;
// Input index that holds the axes vector of such a StridedSlice cnode.
constexpr size_t kOnnxStridedSliceAxes = kOnnxStridedSlice - 2;
// Rank handled by this strategy; also the wrap-around offset for negative axes.
constexpr size_t kShapeDims4D = 4;
const std::vector<int> NH2NC = {0, 3, 1, 2};
const std::vector<int> NC2NH = {0, 2, 3, 1};

// Maps a negative axis (-4..-1) to its non-negative counterpart; non-negative values are returned unchanged.
int64_t NormalizeAxis4D(int64_t axis) { return axis < 0 ? axis + static_cast<int64_t>(kShapeDims4D) : axis; }

// True when an already-normalized axis addresses one of the 4 dimensions.
bool IsValidAxis4D(int64_t axis) { return axis >= 0 && axis < static_cast<int64_t>(kShapeDims4D); }

// Returns the manager of func_graph, falling back to Manage(func_graph, true) when none is set.
// The result may still be nullptr; callers must check it.
FuncGraphManagerPtr GetOrCreateManager(const FuncGraphPtr &func_graph) {
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  return manager;
}

// Appends every user node of cnode to out_nodes.
// Returns RET_ERROR when no manager is available or when cnode has no users.
STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
  auto manager = GetOrCreateManager(func_graph);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return lite::RET_ERROR;
  }
  auto node_users = manager->node_users()[cnode];
  if (node_users.empty()) {
    MS_LOG(ERROR) << "cnode is isolated.";
    return lite::RET_ERROR;
  }
  std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
                 [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
  return lite::RET_OK;
}
}  // namespace

// Prepares a transpose with `perm` either before input `index` of cnode (before == true) or after cnode.
// If the node the transpose would consume is itself a transpose whose perm is the inverse of `perm`
// (NH2NC vs NC2NH), the pair cancels out and that transpose's own input is returned instead.
// Otherwise defers to TransposeDependOnShape. Returns nullptr on failure.
AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                          const std::vector<int> &perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
  // judge pair transpose after insert.
  if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
    std::vector<int> trans_perm;
    auto input_cnode = trans_input_node->cast<CNodePtr>();
    if (input_cnode == nullptr) {
      MS_LOG(ERROR) << "input node is invalid.";
      return nullptr;
    }
    if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
      MS_LOG(ERROR) << "transpose perm get failed.";
      return nullptr;
    }
    if ((perm == NH2NC && trans_perm == NC2NH) || (perm == NC2NH && trans_perm == NH2NC)) {
      return input_cnode->input(kFirstInput);
    }
  }
  // insert depend on shape
  return TransposeDependOnShape(func_graph, cnode, perm, before, index);
}

// Builds a transpose node with `perm` before input `index` of cnode, or after cnode, when the affected
// tensor is 4-D or of unknown rank (see TransposeInsertDependOnShape).
// Returns the untouched input/cnode when no transpose is needed, and nullptr on error.
AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                     const std::vector<int> &perm, bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
  auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
  if (status == lite::RET_ERROR) {
    return nullptr;
  } else if (status == lite::RET_NO_CHANGE) {
    return before ? cnode->input(index) : cnode;
  }
  // insert transpose
  std::string trans_name = before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1)
                                  : cnode->fullname_with_scope() + "_post";
  auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
  return trans_insert_node;
}

// Decides whether inserting transposes around cnode lets the existing neighbouring transposes fuse away.
// Transposes among cnode's CNode inputs fill trans_info->pre_; transposes among its users fill
// trans_info->post_. Both fields are read as the starting type by IsInOutCanFuison, so the caller is
// expected to initialize them (presumably to kNONE; confirm at call sites).
// Insertion requires that:
//   - all transposes on each side are of one kind;
//   - the two sides differ;
//   - transposes make up more than half of the neighbours. This threshold is relaxed to at least half for
//     LeakyReLU activations, Split and QuantDTypeCast.
// On success, trans_insert_info receives the transpose kinds to insert.
bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                          TransTypePair *trans_info, TransTypePair *trans_insert_info) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  MS_ASSERT(trans_info != nullptr && trans_insert_info != nullptr);
  size_t trans_count = 0;
  std::vector<AnfNodePtr> in_nodes;
  for (size_t i = 1; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      in_nodes.push_back(cnode->input(i));
    }
  }
  if (!IsInOutCanFuison(func_graph, in_nodes, &trans_count, &trans_info->pre_)) {
    return false;
  }
  std::vector<AnfNodePtr> out_nodes;
  if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
    return false;
  }
  if (!IsInOutCanFuison(func_graph, out_nodes, &trans_count, &trans_info->post_)) {
    return false;
  }
  if (trans_info->pre_ == trans_info->post_) {
    return false;
  }
  auto total_node_count = in_nodes.size() + out_nodes.size();
  bool can_insert = trans_count > total_node_count / 2;
  if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
    auto prim_act = GetValueNode<std::shared_ptr<ops::Activation>>(cnode->input(0));
    // Previously only asserted; a null primitive would be dereferenced in release builds.
    if (prim_act == nullptr) {
      MS_LOG(ERROR) << "activation primitive is nullptr.";
      return false;
    }
    if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
      can_insert = trans_count >= total_node_count / 2;
    }
  }
  if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    can_insert = trans_count >= total_node_count / 2;
  }
  if (!can_insert) {
    return false;
  }
  DecidePreAndPostTransType(trans_info, trans_insert_info);
  return true;
}

// Checks whether ChangeOpAxis can rewrite cnode's axis-related attributes/inputs.
// Requirements:
//   - input 1 must be 4-D; failing that, input 2 (if present) must be 4-D or of unknown (empty) shape;
//   - Concat/Split must carry an axis attribute;
//   - SliceFusion/StridedSlice must take all inputs from index 2 onward as non-CNode values;
//   - StridedSlice must have exactly kOnnxStridedSlice inputs.
bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto shape = node_infer_shape_.GetInputShape(cnode, 1);
  if (shape.size() != kShapeDims4D) {
    if (cnode->size() > 2) {
      shape = node_infer_shape_.GetInputShape(cnode, 2);
      if (shape.size() != kShapeDims4D && !shape.empty()) {
        return false;
      }
    } else {
      return false;
    }
  }
  if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr || prim->GetAttr(ops::kAxis) == nullptr) {
      return false;
    }
  }
  if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
    for (size_t i = 2; i < cnode->size(); ++i) {
      if (utils::isa<CNodePtr>(cnode->input(i))) {
        return false;
      }
    }
    if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
      return false;
    }
  }
  return true;
}

// Remaps cnode's axis-dependent data from NCHW to NHWC order using GetNC2NHAxisMap():
//   - Concat/Split: the axis attribute;
//   - Crop: axis plus reordered/extended offsets;
//   - SliceFusion / StridedSlice: delegated to ChangeOpSlice / ChangeOpStrideSlice.
// Returns:
//   - RET_NOT_SUPPORT for unsupported shapes, axes or offsets;
//   - RET_NULL_PTR when the primitive cannot be read.
STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto shape = node_infer_shape_.GetInputShape(cnode, 1);
  if (shape.size() != kShapeDims4D) {
    if (cnode->size() > 2) {
      shape = node_infer_shape_.GetInputShape(cnode, 2);
      if (shape.size() != kShapeDims4D && !shape.empty()) {
        return lite::RET_NOT_SUPPORT;
      }
    } else {
      return lite::RET_NOT_SUPPORT;
    }
  }
  auto axis_map = GetNC2NHAxisMap();
  if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      return lite::RET_NULL_PTR;
    }
    if (prim->GetAttr(ops::kAxis) == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axis = NormalizeAxis4D(GetValue<int64_t>(prim->GetAttr(ops::kAxis)));
    // An out-of-range axis would otherwise be used unchecked as an axis_map index.
    if (!IsValidAxis4D(axis)) {
      MS_LOG(ERROR) << "axis is out of range.";
      return lite::RET_NOT_SUPPORT;
    }
    auto new_axis = axis_map[axis];
    prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
  }
  if (CheckPrimitiveType(cnode, prim::kPrimCrop)) {
    auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
    if (crop_prim == nullptr) {
      return lite::RET_NULL_PTR;
    }
    auto axis = NormalizeAxis4D(crop_prim->get_axis());
    if (!IsValidAxis4D(axis)) {
      MS_LOG(ERROR) << "crop axis is out of range.";
      return lite::RET_NOT_SUPPORT;
    }
    auto offsets = crop_prim->get_offsets();
    auto new_axis = axis_map[axis];
    // The permutations below read offsets[0..3] (new_axis 0) or offsets[0..2] (new_axis 3);
    // reject shorter lists instead of indexing past the end of the vector.
    if ((new_axis == 0 && offsets.size() < kShapeDims4D) || (new_axis == 3 && offsets.size() < kShapeDims4D - 1)) {
      MS_LOG(ERROR) << "crop offsets size is invalid.";
      return lite::RET_NOT_SUPPORT;
    }
    if (new_axis == 0) {
      offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
    } else if (new_axis == 3) {
      offsets = {offsets[1], offsets[2], offsets[0]};
    } else {
      offsets.push_back(0);
    }
    crop_prim->set_axis(new_axis);
    crop_prim->set_offsets(offsets);
  }
  if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
    return ChangeOpSlice(func_graph, cnode);
  }
  if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
    return ChangeOpStrideSlice(func_graph, cnode);
  }
  return lite::RET_OK;
}

// Inspects the tensor a new transpose would act on:
//   - before == true: input `index` of cnode;
//   - before == false: the input slot of cnode's first user that cnode feeds.
// Returns:
//   - RET_OK when that shape is 4-D or unknown (empty);
//   - RET_NO_CHANGE when it is known and not 4-D;
//   - RET_ERROR when no manager is available, cnode has no users, or the first user is not a CNode.
STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                       bool before, size_t index) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  auto manager = GetOrCreateManager(func_graph);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return lite::RET_ERROR;
  }
  auto node_users = manager->node_users()[cnode];
  if (node_users.empty()) {
    MS_LOG(ERROR) << "cnode is isolated.";
    return lite::RET_ERROR;
  }
  if (!utils::isa<CNodePtr>(node_users.front().first)) {
    return lite::RET_ERROR;
  }
  CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
  size_t input_index = before ? index : node_users.front().second;
  auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
  if (!shape.empty() && shape.size() != NH2NC.size()) {
    return lite::RET_NO_CHANGE;
  }
  return lite::RET_OK;
}

// Scans `nodes` for transposes. Each one must have perm NH2NC (kNHWC2NCHW) or NC2NH (kNCHW2NHWC), and all
// must agree with each other and with the incoming *trans_type unless it is kNONE.
// On success *trans_type holds the common kind (unchanged if no transpose was found), and *trans_count is
// incremented once per transpose. Returns false on any unreadable or mismatching transpose.
bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
                                         size_t *trans_count, FormatTransNodeType *trans_type) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
  for (const auto &node : nodes) {
    if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      return false;
    }
    std::vector<int> perm;
    if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
      return false;
    }
    FormatTransNodeType cur_type = kNONE;
    if (perm == NH2NC) {
      cur_type = kNHWC2NCHW;
    } else if (perm == NC2NH) {
      cur_type = kNCHW2NHWC;
    } else {
      return false;
    }
    if (*trans_type == kNONE) {
      *trans_type = cur_type;
    } else if (*trans_type != cur_type) {
      return false;
    }
    *trans_count += 1;
  }
  return true;
}

// Derives the transpose kinds to insert before/after a node from the kinds found around it.
// Cases (nothing is written when pre_ == post_):
//   - both sides known: insert the inverse of each side;
//   - only post_ known: insert post_'s kind before and its inverse after;
//   - only pre_ known: insert pre_'s inverse before and its kind after.
void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) {
  MS_ASSERT(trans_info != nullptr && trans_insert_info != nullptr);
  if (trans_info->pre_ == trans_info->post_) {
    return;
  }
  if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
    trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
    trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
  } else if (trans_info->pre_ == kNONE) {
    trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
    trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
  } else {
    trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
    trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
  }
}

// Rewrites a SliceFusion node for NHWC.
// Steps:
//   - every constant input from index 2 onward is reordered by TransformAttrByAxes;
//   - the axes attribute is remapped and sorted.
// When the axes attribute is absent or empty, axes 0..N-1 are assumed, where N is the first dimension of
// input 2's shape.
// Returns RET_NOT_SUPPORT for non-constant inputs or an unknown input-2 shape, and RET_NULL_PTR when the
// SliceFusion primitive cannot be read.
STATUS TransposeStrategy::ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  for (size_t i = 2; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  auto shape = node_infer_shape_.GetInputShape(cnode, 2);
  if (shape.empty()) {
    return lite::RET_NOT_SUPPORT;
  }
  int element_num = shape.front();
  auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
  // Check before any input is rewritten so a failure leaves the node untouched.
  if (prim == nullptr) {
    MS_LOG(ERROR) << "slice primitive is nullptr.";
    return lite::RET_NULL_PTR;
  }
  std::vector<int> axes;
  if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
    for (int index = 0; index < element_num; ++index) {
      axes.push_back(index);
    }
  } else {
    auto origin_axes = prim->get_axes();
    std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
                   [](int64_t v) { return static_cast<int>(v); });
  }
  for (size_t i = 2; i < cnode->size(); ++i) {
    TransformAttrByAxes(func_graph, cnode, i, axes);
  }
  auto tmp_axes = TransformOpAxesAttr(axes);
  std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
  prim->set_axes(new_axes);
  return lite::RET_OK;
}

// Rewrites a StridedSlice node with exactly kOnnxStridedSlice constant-valued inputs for NHWC.
// Steps:
//   - every input from index 2 onward except the axes input (kOnnxStridedSliceAxes) is reordered by
//     TransformAttrByAxes;
//   - the axes input is replaced with a new parameter holding the remapped, sorted axes.
// Returns RET_NOT_SUPPORT for other forms, and RET_ERROR when the axes cannot be read, no manager is
// available, or the new parameter cannot be built.
STATUS TransposeStrategy::ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
  if (cnode->size() != kOnnxStridedSlice) {
    return lite::RET_NOT_SUPPORT;
  }
  for (size_t i = 2; i < cnode->size(); ++i) {
    if (utils::isa<CNodePtr>(cnode->input(i))) {
      return lite::RET_NOT_SUPPORT;
    }
  }
  std::vector<int> axes = node_infer_shape_.GetIntVecInput(cnode, kOnnxStridedSliceAxes);
  if (axes.empty()) {
    MS_LOG(ERROR) << "strided slice input invalid.";
    return lite::RET_ERROR;
  }
  // Acquire the manager before rewriting any input; it was previously dereferenced without a check.
  auto manager = GetOrCreateManager(func_graph);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return lite::RET_ERROR;
  }
  for (size_t index = 2; index < cnode->size(); ++index) {
    if (index == kOnnxStridedSliceAxes) {
      continue;
    }
    TransformAttrByAxes(func_graph, cnode, index, axes);
  }
  auto cur_axes = TransformOpAxesAttr(axes);
  auto axes_node = cnode->input(kOnnxStridedSliceAxes);
  auto param_node = BuildIntVecParameterNode(func_graph, cur_axes, axes_node->fullname_with_scope());
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "build strided slice axes parameter failed.";
    return lite::RET_ERROR;
  }
  manager->Replace(axes_node, param_node);
  return lite::RET_OK;
}

// Replaces input `input_index` of cnode with a new parameter whose values are the original values
// regrouped by NHWC dimension. Entry i belongs to axes[i] and is placed according to GetNC2NHAxisMap()
// of that axis.
// Does nothing if an argument is invalid, the input's length differs from axes.size(), no manager is
// available, or the new parameter cannot be built.
void TransposeStrategy::TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
                                            const std::vector<int> &axes) {
  if (func_graph == nullptr || cnode == nullptr || input_index >= cnode->size() || axes.empty()) {
    return;
  }
  auto axis_map = GetNC2NHAxisMap();
  auto origin_input = node_infer_shape_.GetIntVecInput(cnode, input_index);
  if (origin_input.size() != axes.size()) {
    return;
  }
  std::vector<int> cur_input;
  for (int dim = 0; dim < static_cast<int>(kShapeDims4D); ++dim) {
    for (size_t index = 0; index < axes.size(); ++index) {
      int nhwc_dim = axis_map[NormalizeAxis4D(axes[index])];
      if (nhwc_dim == dim) {
        cur_input.push_back(origin_input[index]);
      }
    }
  }
  auto manager = GetOrCreateManager(func_graph);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return;
  }
  auto origin_node = cnode->input(input_index);
  auto param_node = BuildIntVecParameterNode(func_graph, cur_input, origin_node->fullname_with_scope());
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "build parameter node failed.";
    return;
  }
  manager->Replace(origin_node, param_node);
}

// Maps each NCHW axis (negative values wrap by 4) to its NHWC position via GetNC2NHAxisMap() and
// returns the results sorted ascending.
std::vector<int> TransposeStrategy::TransformOpAxesAttr(const std::vector<int> &origin_axes) {
  auto axis_map = GetNC2NHAxisMap();
  std::vector<int> cur_axis;
  cur_axis.reserve(origin_axes.size());
  for (const int axis : origin_axes) {
    cur_axis.push_back(axis_map[NormalizeAxis4D(axis)]);
  }
  std::sort(cur_axis.begin(), cur_axis.end());
  return cur_axis;
}
}  // namespace opt
}  // namespace mindspore