/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ #include #include #include #include "frontend/optimizer/irpass.h" #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "abstract/dshape.h" namespace mindspore { namespace opt { namespace irpass { using abstract::Shape; using abstract::ShapePtr; // {ReduceLike, X, axis} class ReduceOneEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); PrimitivePtr prim; if (IsPrimitiveCNode(node, prim::kPrimReduceMean) || IsPrimitiveCNode(node, prim::kPrimReduceAll) || IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimReduceMax) || IsPrimitiveCNode(node, prim::kPrimReduceMin)) { prim = GetValueNode(node->cast()->input(0)); AnfVisitor::Match(prim, {IsNode, IsVNode})(node); if (!is_axis_one_) { return nullptr; } // consider keep_dims auto keep_dims = prim->GetAttr("keep_dims"); auto is_keep_dims = GetValue(keep_dims); // {_Reduce, X, axis} -> X if (is_keep_dims) { return x_; } // {_Reduce, Tensor} if (is_tensor_) { return nullptr; } // {_Reduce, X, axis} -> {Reshape, X, new_shape} std::vector elements; for (size_t i = 0; i < x_shape_.size(); i++) { auto iter = find(axis_.begin(), axis_.end(), i); if (iter == axis_.end()) { ValuePtr s = MakeValue(x_shape_[i]); elements.push_back(s); } } auto new_shape = std::make_shared(elements); auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast(); auto node_abstract = node->abstract(); // handle auto_parallel get nullptr abstract if (node_abstract != nullptr) { auto new_base_shape = std::make_shared(GetValue>(new_shape)); node_abstract->set_shape(new_base_shape); auto new_node = node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); new_node->set_abstract(node_abstract); return new_node; } return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); } return nullptr; } void Visit(const AnfNodePtr &node) override { if (!IsVNode(node) && x_ == nullptr) { if (IsValueNode(node)) { is_tensor_ = true; } // get X's shape auto x_shape_abs = node->abstract(); if (x_shape_abs != nullptr) { auto x_track = x_shape_abs->GetShapeTrack()->cast(); if (x_track == nullptr) { return; } auto x_shape = x_track->shape(); (void)std::copy(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_)); x_ = node; } return; } // check axis AnfVisitor::Visit(node); } void Visit(const ValueNodePtr &vnode) override { if (x_shape_.empty()) { return; } // axis : int if (IsValueNode(vnode)) { auto idx = GetValue(vnode->value()); // axis could be negative if (idx < 0) { idx += SizeToInt(x_shape_.size()); } if (SizeToInt(x_shape_.size()) > idx && x_shape_[IntToSize(idx)] == 1) { is_axis_one_ = true; axis_.push_back(idx); } return; } // axis : tuple(int), default () if (IsValueNode(vnode)) { auto axis = GetValue>(vnode->value()); if (axis.empty()) { return; } auto cmp = std::all_of(axis.cbegin(), axis.cend(), [this](int idx) { // axis could be negative if (idx < 0) { idx += SizeToInt(x_shape_.size()); } return SizeToInt(this->x_shape_.size()) > idx && this->x_shape_[IntToSize(idx)] == 1; }); if (cmp) { is_axis_one_ = true; (void)std::copy(axis.begin(), axis.end(), std::back_inserter(axis_)); } } } void Reset() { axis_.clear(); x_shape_.clear(); x_ = nullptr; is_axis_one_ = false; is_tensor_ = false; } private: bool is_axis_one_{false}, is_tensor_{false}; std::vector axis_{}, x_shape_{}; AnfNodePtr x_{nullptr}; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_