Merge pull request !3715 from BowenK/new_patterntags/v0.7.0-beta
| @@ -0,0 +1,158 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/optimizer/pattern.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| #include "pybind_api/export_flags.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| int Pattern::g_id_ = 0; | |||||
| MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) { | |||||
| if (!IsValueNode<Primitive>(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||||
| if (IsValueNode<Primitive>(node)) { | |||||
| // iterate over all primitives | |||||
| for (auto &iter : primitives_) { | |||||
| if (IsPrimitive(node, iter) || iter->name() == "*") { | |||||
| matched_prim_ = iter; | |||||
| res->add_entry(shared_from_base<IsPrimTypeOf>(), node); | |||||
| return res; | |||||
| } | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| MatchResultPtr CallWith::match(const AnfNodePtr &node) { | |||||
| if (!IsPrimitiveCNode(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||||
| // IsPrimitiveCNode | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| // Check Primitive ValueNode | |||||
| if (prim_pattern_ != nullptr) { | |||||
| // Passed in prim_pattern | |||||
| auto prim_value_res = prim_pattern_->match(cnode->input(0)); | |||||
| if (prim_value_res == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| res->merge(prim_value_res); | |||||
| } else if (prim_ != nullptr) { | |||||
| // Passed in primitive/primitive str | |||||
| if (!IsPrimitive(cnode->input(0), prim_)) { | |||||
| return nullptr; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Uninitialized CallWith pattern."; | |||||
| } | |||||
| // Check inputs | |||||
| auto p_inputs_size = inputs_.size(); | |||||
| auto node_inputs_size = cnode->size() - 1; | |||||
| if (p_inputs_size != 0 && p_inputs_size != node_inputs_size) { | |||||
| return nullptr; | |||||
| } | |||||
| // If inputs is not specified, add node without looking into its inputs | |||||
| if (p_inputs_size == 0) { | |||||
| res->add_entry(shared_from_base<CallWith>(), cnode->input(0)); | |||||
| return res; | |||||
| } | |||||
| bool failed = false; | |||||
| for (std::size_t i = 0; i < node_inputs_size; i++) { | |||||
| auto pattern = inputs_[i]; | |||||
| auto input = cnode->input(i + 1); | |||||
| auto input_match_result = pattern->match(input); | |||||
| if (input_match_result == nullptr) { | |||||
| failed = true; | |||||
| break; | |||||
| } | |||||
| res->merge(input_match_result); | |||||
| } | |||||
| if (!failed) { | |||||
| res->add_entry(shared_from_base<CallWith>(), cnode->input(0)); | |||||
| return res; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| MatchResultPtr IsIn::match(const AnfNodePtr &node) { | |||||
| for (auto &iter : patterns_) { | |||||
| auto res = iter->match(node); | |||||
| if (res != nullptr) { | |||||
| return res; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| MatchResultPtr IsNot::match(const AnfNodePtr &node) { | |||||
| for (auto &iter : patterns_) { | |||||
| auto res = iter->match(node); | |||||
| if (res != nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| auto res = std::make_shared<MatchResult>(); | |||||
| res->add_entry(shared_from_base<IsNot>(), node); | |||||
| return res; | |||||
| } | |||||
| MatchResultPtr AnyPattern::match(const AnfNodePtr &node) { | |||||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||||
| res->add_entry(shared_from_base<AnyPattern>(), node); | |||||
| return res; | |||||
| } | |||||
| AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) { | |||||
| auto entry = match_result_.find(pattern); | |||||
| if (entry == match_result_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return entry->second; | |||||
| } | |||||
| void MatchResult::merge(const MatchResultPtr &other_result) { | |||||
| auto other_result_map = other_result->_result(); | |||||
| // add/update entries in other_result | |||||
| for (auto &iter : other_result_map) { | |||||
| match_result_[iter.first] = iter.second; | |||||
| } | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE( | |||||
| Pattern, ([](const py::module *m) { | |||||
| (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>()); | |||||
| (void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>()); | |||||
| (void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr()) | |||||
| .def(py::init<vector<PrimitivePyPtr>, string, bool>()) | |||||
| .def(py::init<vector<string>, string, bool>()); | |||||
| (void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_") | |||||
| .def(py::init<PatternPtr, vector<PatternPtr>, bool>()) | |||||
| .def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>()) | |||||
| .def(py::init<string, vector<PatternPtr>, bool>()); | |||||
| (void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>()); | |||||
| (void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>()); | |||||
| (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_") | |||||
| .def(py::init<tensor::TensorPtr>()); | |||||
| })); | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,228 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include "base/base.h" | |||||
| #include "ir/anf.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "utils/primitive_py.h" | |||||
| #include "utils/tensor_py.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| using std::string; | |||||
| using std::vector; | |||||
| class MatchResult; | |||||
| using MatchResultPtr = std::shared_ptr<MatchResult>; | |||||
| class Pattern; | |||||
| using PatternPtr = std::shared_ptr<Pattern>; | |||||
| class IsPrimTypeOf; | |||||
| using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>; | |||||
| class CallWith; | |||||
| using CallWithPtr = std::shared_ptr<CallWith>; | |||||
| class NewTensor; | |||||
| using NewTensorPtr = std::shared_ptr<NewTensor>; | |||||
| struct PatternHasher; | |||||
| struct PatternEqual; | |||||
| using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>; | |||||
| class Pattern : public Base { | |||||
| public: | |||||
| Pattern() : unique_name_(std::to_string(g_id_++)) {} | |||||
| virtual MatchResultPtr match(const AnfNodePtr &node) { return nullptr; } | |||||
| virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } | |||||
| string unique_name() const { return unique_name_; } | |||||
| vector<PatternPtr> inputs() { return inputs_; } | |||||
| bool should_replace() { return should_replace_; } | |||||
| virtual void reset() {} | |||||
| protected: | |||||
| static int g_id_; | |||||
| // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed | |||||
| string unique_name_; | |||||
| vector<PatternPtr> inputs_; | |||||
| bool should_replace_ = true; | |||||
| }; | |||||
| struct PatternEqual { | |||||
| bool operator()(PatternPtr const &p1, PatternPtr const &p2) const { | |||||
| MS_EXCEPTION_IF_NULL(p1); | |||||
| MS_EXCEPTION_IF_NULL(p2); | |||||
| return p1->unique_name() == p2->unique_name(); | |||||
| } | |||||
| }; | |||||
| struct PatternHasher { | |||||
| std::size_t operator()(PatternPtr const &p) const { | |||||
| MS_EXCEPTION_IF_NULL(p); | |||||
| return std::hash<string>()(p->unique_name()); | |||||
| } | |||||
| }; | |||||
| class IsPrimTypeOf : public Pattern { | |||||
| public: | |||||
| IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); } | |||||
| IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace) | |||||
| : primitives_(prims), name_(name), matched_prim_(nullptr) { | |||||
| unique_name_ = std::to_string(g_id_++) + "_" + name; | |||||
| should_replace_ = should_replace; | |||||
| if (!should_replace) { | |||||
| matched_prim_ = prims[0]; | |||||
| } | |||||
| } | |||||
| IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) { | |||||
| unique_name_ = std::to_string(g_id_++) + "_" + name; | |||||
| // Make primitives_ | |||||
| for (auto &iter : types) { | |||||
| primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr))); | |||||
| } | |||||
| should_replace_ = should_replace; | |||||
| if (!should_replace) { | |||||
| matched_prim_ = primitives_[0]; | |||||
| } | |||||
| } | |||||
| MS_DECLARE_PARENT(IsPrimTypeOf, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||||
| PrimitivePyPtr matched_primitive() { return matched_prim_; } | |||||
| void reset() override { | |||||
| if (should_replace_) { | |||||
| matched_prim_ = nullptr; | |||||
| } | |||||
| } | |||||
| private: | |||||
| vector<string> types_; | |||||
| vector<PrimitivePyPtr> primitives_; | |||||
| string name_; | |||||
| PrimitivePyPtr matched_prim_; | |||||
| }; | |||||
| class CallWith : public Pattern { | |||||
| public: | |||||
| CallWith() { unique_name_ = std::to_string(g_id_++); } | |||||
| CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) { | |||||
| // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting | |||||
| prim_pattern_ = prim_pattern; | |||||
| unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name(); | |||||
| inputs_ = inputs; | |||||
| should_replace_ = should_replace; | |||||
| } | |||||
| CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) { | |||||
| prim_ = prim; | |||||
| unique_name_ = std::to_string(g_id_++) + prim_->ToString(); | |||||
| inputs_ = inputs; | |||||
| should_replace_ = should_replace; | |||||
| } | |||||
| CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) { | |||||
| prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr)); | |||||
| unique_name_ = std::to_string(g_id_++) + prim_->ToString(); | |||||
| inputs_ = inputs; | |||||
| should_replace_ = should_replace; | |||||
| } | |||||
| MS_DECLARE_PARENT(CallWith, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||||
| PrimitivePtr prim_value() { return prim_; } | |||||
| PatternPtr prim_pattern() { return prim_pattern_; } | |||||
| private: | |||||
| PatternPtr prim_pattern_ = nullptr; | |||||
| PrimitivePtr prim_ = nullptr; | |||||
| vector<string> types_; | |||||
| string name_; | |||||
| }; | |||||
| class IsIn : public Pattern { | |||||
| public: | |||||
| IsIn() { unique_name_ = std::to_string(g_id_++); } | |||||
| explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) { | |||||
| unique_name_ = std::to_string(g_id_++); | |||||
| for (auto &iter : patterns) { | |||||
| unique_name_ = unique_name_ + "_" + iter->unique_name(); | |||||
| } | |||||
| } | |||||
| MS_DECLARE_PARENT(IsIn, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||||
| private: | |||||
| vector<PatternPtr> patterns_; | |||||
| }; | |||||
| class IsNot : public Pattern { | |||||
| public: | |||||
| IsNot() { unique_name_ = std::to_string(g_id_++); } | |||||
| explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) { | |||||
| unique_name_ = std::to_string(g_id_++); | |||||
| for (auto &iter : patterns) { | |||||
| unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name(); | |||||
| } | |||||
| } | |||||
| MS_DECLARE_PARENT(IsNot, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||||
| private: | |||||
| vector<PatternPtr> patterns_; | |||||
| }; | |||||
| class AnyPattern : public Pattern { | |||||
| public: | |||||
| AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; } | |||||
| MS_DECLARE_PARENT(AnyPattern, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||||
| }; | |||||
| class NewTensor : public Pattern { | |||||
| public: | |||||
| NewTensor() { unique_name_ = std::to_string(g_id_++); } | |||||
| explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; } | |||||
| MS_DECLARE_PARENT(NewTensor, Pattern); | |||||
| MatchResultPtr match(const AnfNodePtr &node) override { | |||||
| MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n"; | |||||
| } | |||||
| tensor::TensorPtr input_tensor() { return input_tensor_; } | |||||
| private: | |||||
| tensor::TensorPtr input_tensor_; | |||||
| }; | |||||
| class MatchResult { | |||||
| public: | |||||
| MatchResult() {} | |||||
| void add_entry(PatternPtr pattern, AnfNodePtr node) { match_result_[pattern] = node; } | |||||
| PatternNodeMap _result() { return match_result_; } | |||||
| AnfNodePtr get_node(const PatternPtr &pattern); | |||||
| void merge(const MatchResultPtr &other_result); | |||||
| void clear() { match_result_.clear(); } | |||||
| void dump() { | |||||
| MS_LOG(DEBUG) << "match_result_.size: " + std::to_string(match_result_.size()) + "\n"; | |||||
| for (auto &iter : match_result_) { | |||||
| MS_LOG(DEBUG) << "Pattern : " + iter.first->unique_name() + " , node : " + iter.second->ToString() + "\n"; | |||||
| } | |||||
| } | |||||
| private: | |||||
| PatternNodeMap match_result_; | |||||
| }; | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PATTERN_H_ | |||||
| @@ -22,6 +22,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "utils/primitive_py.h" | |||||
| #include "pipeline/jit/parse/parse_base.h" | #include "pipeline/jit/parse/parse_base.h" | ||||
| #include "pipeline/jit/resource.h" | #include "pipeline/jit/resource.h" | ||||
| @@ -29,6 +30,8 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace python_pass { | namespace python_pass { | ||||
| namespace internal { | namespace internal { | ||||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res); | |||||
| std::string GetNodeRepr(AnfNodePtr node) { | std::string GetNodeRepr(AnfNodePtr node) { | ||||
| if (node != nullptr) { | if (node != nullptr) { | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| @@ -50,126 +53,104 @@ std::string GetNodeRepr(AnfNodePtr node) { | |||||
| return ""; | return ""; | ||||
| } | } | ||||
| void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||||
| auto manager = Manage(fg, false); | |||||
| auto use_sig = parse::python_adapter::UseSignatureInResolve(); | |||||
| parse::python_adapter::set_use_signature_in_resolve(false); | |||||
| parse::ResolveAll(manager); | |||||
| parse::python_adapter::set_use_signature_in_resolve(use_sig); | |||||
| } | |||||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | |||||
| bool IsTraversable(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(pattern); | |||||
| if (pattern->isa<ValueNode>()) { | |||||
| if (!node->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| if (GetNodeRepr(pattern) == GetNodeRepr(node)) { | |||||
| // add to equiv_ptr | |||||
| equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } else if (pattern->isa<Parameter>()) { | |||||
| MS_LOG(DEBUG) << pattern->ToString() + "\n"; | |||||
| // add to equiv_ptr | |||||
| equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); | |||||
| if (node->isa<CNode>() || node->isa<Parameter>()) { | |||||
| return true; | return true; | ||||
| } else if (pattern->isa<CNode>()) { | |||||
| // match every single sub ANode | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto pattern_inputs = pattern->cast<CNodePtr>()->inputs(); | |||||
| auto node_inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (pattern_inputs.size() != node_inputs.size()) { | |||||
| return false; | |||||
| } | |||||
| for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); | |||||
| p_item++, node_item++) { | |||||
| auto res = Match(*p_item, *node_item, equiv_ptr); | |||||
| if (!res) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; | |||||
| return false; | |||||
| } | } | ||||
| AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, | |||||
| const NodeEquivPtr &equiv_ptr) { | |||||
| if (cur_raw_dst_node_->isa<Parameter>()) { | |||||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); | |||||
| if (sub_pair != equiv_ptr->end()) { | |||||
| return sub_pair->second; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; | |||||
| } else if (cur_raw_dst_node_->isa<ValueNode>()) { | |||||
| // check primitive ValueNode | |||||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString()); | |||||
| if (sub_pair != equiv_ptr->end()) { | |||||
| return sub_pair->second; | |||||
| } | |||||
| return cur_raw_dst_node_; | |||||
| } else if (cur_raw_dst_node_->isa<CNode>()) { | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs(); | |||||
| for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { | |||||
| auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); | |||||
| new_inputs.push_back(subed); | |||||
| } | |||||
| return func_graph->NewCNode(new_inputs); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); | |||||
| AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) { | |||||
| // Build up AnfNode from primitive | |||||
| auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||||
| PrimitivePyPtr prim = prim_pattern->matched_primitive(); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| // Make value node out of primitives | |||||
| return std::make_shared<ValueNode>(prim); | |||||
| } | } | ||||
| bool isTraversable(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (node->isa<CNode>() || node->isa<Parameter>()) { | |||||
| return true; | |||||
| } | |||||
| if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) { | |||||
| return true; | |||||
| AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) { | |||||
| // Build a ValueNode from TensorPtr | |||||
| auto new_tensor_pattern = pattern->cast<NewTensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(new_tensor_pattern); | |||||
| auto input_tensor = new_tensor_pattern->input_tensor(); | |||||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||||
| return std::make_shared<ValueNode>(input_tensor); | |||||
| } | |||||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) { | |||||
| auto call_with_pattern = pattern->cast<CallWithPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(call_with_pattern); | |||||
| auto prim = call_with_pattern->prim_value(); | |||||
| if (prim != nullptr) { | |||||
| return std::make_shared<ValueNode>(prim); | |||||
| } | } | ||||
| return false; | |||||
| auto prim_pattern = call_with_pattern->prim_pattern(); | |||||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||||
| return ProcessSinglePattern(prim_pattern, res); | |||||
| } | } | ||||
| } // namespace internal | |||||
| void PythonPass::Build(const py::function &src, const py::function &dst) { | |||||
| // 1. get FuncGraph from py::function | |||||
| auto src_fg_ = parse::ParsePythonCode(src); | |||||
| auto dst_fg_ = parse::ParsePythonCode(dst); | |||||
| if (src_fg_ == nullptr || dst_fg_ == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; | |||||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) { | |||||
| if (pattern->should_replace()) { | |||||
| // Find replacement in the MatchResult | |||||
| auto target_node = res->get_node(pattern); | |||||
| if (target_node == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n"; | |||||
| } | |||||
| return target_node; | |||||
| } | } | ||||
| // 2. Resolve | |||||
| internal::ResolveFuncGraph_(src_fg_); | |||||
| internal::ResolveFuncGraph_(dst_fg_); | |||||
| // 3. from FuncGraphPtr to ValueNode | |||||
| src_node_ = src_fg_->output(); | |||||
| dst_node_ = dst_fg_->output(); | |||||
| // Build up new node from pattern | |||||
| if (pattern->isa<IsPrimTypeOf>()) { | |||||
| return BuildPrimitive(pattern, res); | |||||
| } else if (pattern->isa<NewTensor>()) { | |||||
| return BuildNewTensor(pattern, res); | |||||
| } else if (pattern->isa<CallWith>()) { | |||||
| return BuildPrimitiveValueNode(pattern, res); | |||||
| } | |||||
| return nullptr; | |||||
| } | } | ||||
| PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, | |||||
| bool multigraph) | |||||
| : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { | |||||
| Build(src, dst); | |||||
| AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { | |||||
| auto target_inputs = pattern->inputs(); | |||||
| if (target_inputs.size() == 0) { | |||||
| return ProcessSinglePattern(pattern, res); | |||||
| } | |||||
| // Build up the AnfNode in a recursive manner | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| auto prim_value_node = ProcessSinglePattern(pattern, res); | |||||
| MS_EXCEPTION_IF_NULL(prim_value_node); | |||||
| new_inputs.push_back(prim_value_node); | |||||
| for (auto &iter : target_inputs) { | |||||
| if (iter == pattern) { | |||||
| MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() + | |||||
| "\n"; | |||||
| } | |||||
| new_inputs.push_back(BuildTarget(iter, func_graph, res)); | |||||
| } | |||||
| return func_graph->NewCNode(new_inputs); | |||||
| } | } | ||||
| } // namespace internal | |||||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | ||||
| auto equiv_ptr = std::make_shared<NodeEquiv>(); | |||||
| bool is_a_match = internal::Match(src_node_, node, equiv_ptr); | |||||
| if (is_a_match) { | |||||
| auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); | |||||
| MS_EXCEPTION_IF_NULL(src_pattern_); | |||||
| MS_EXCEPTION_IF_NULL(dst_pattern_); | |||||
| auto res = src_pattern_->match(node); | |||||
| if (res != nullptr) { | |||||
| res->dump(); | |||||
| MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name(); | |||||
| auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); | |||||
| dst_pattern_->reset(); | |||||
| MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | ||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| src_pattern_->reset(); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -188,14 +169,12 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { | |||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| AnfNodePtr node = todo.front(); | AnfNodePtr node = todo.front(); | ||||
| todo.pop_front(); | todo.pop_front(); | ||||
| // check whether this node has been matched. | |||||
| if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { | |||||
| // Check whether this node has been matched. | |||||
| if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| node->seen_ = seen; | node->seen_ = seen; | ||||
| // select nodes that this transform can be applied. | |||||
| // Select nodes that this transform can be applied. | |||||
| AnfNodePtr new_node = Run(func_graph, node); | AnfNodePtr new_node = Run(func_graph, node); | ||||
| bool change = (new_node != nullptr); | bool change = (new_node != nullptr); | ||||
| if (new_node != nullptr && new_node != node) { | if (new_node != nullptr && new_node != node) { | ||||
| @@ -206,17 +185,14 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph) { | |||||
| if (run_only_once_) { | if (run_only_once_) { | ||||
| return change; | return change; | ||||
| } | } | ||||
| // find success, and add them to todo list | |||||
| // Find success, and add them to todo list | |||||
| if (IsValueNode<FuncGraph>(node)) { | if (IsValueNode<FuncGraph>(node)) { | ||||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | ||||
| } | } | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | auto &inputs = node->cast<CNodePtr>()->inputs(); | ||||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | ||||
| } | } | ||||
| auto &node_users = manager->node_users(); | auto &node_users = manager->node_users(); | ||||
| if (change && node_users.find(node) != node_users.end()) { | if (change && node_users.find(node) != node_users.end()) { | ||||
| for (auto &use : node_users[node]) { | for (auto &use : node_users[node]) { | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "frontend/optimizer/pattern.h" | |||||
| #include "pybind_api/api_register.h" | #include "pybind_api/api_register.h" | ||||
| #include "pybind_api/export_flags.h" | #include "pybind_api/export_flags.h" | ||||
| @@ -33,17 +34,17 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>; | |||||
| class PythonPass { | class PythonPass { | ||||
| public: | public: | ||||
| explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, | |||||
| bool run_only_once = false, bool multigraph = true); | |||||
| explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false, | |||||
| bool multigraph = true) | |||||
| : src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {} | |||||
| ~PythonPass() = default; | ~PythonPass() = default; | ||||
| bool Run(const FuncGraphPtr &func_graph); | bool Run(const FuncGraphPtr &func_graph); | ||||
| std::string name() const { return name_; } | std::string name() const { return name_; } | ||||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | ||||
| private: | private: | ||||
| void Build(const py::function &src, const py::function &dst); | |||||
| AnfNodePtr src_node_ = nullptr; | |||||
| AnfNodePtr dst_node_ = nullptr; | |||||
| PatternPtr src_pattern_; | |||||
| PatternPtr dst_pattern_; | |||||
| const std::string name_; | const std::string name_; | ||||
| bool run_only_once_; | bool run_only_once_; | ||||
| bool multigraph_ = true; | bool multigraph_ = true; | ||||
| @@ -49,7 +49,7 @@ PyPassManager::PyPassManager() { | |||||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | ||||
| } | } | ||||
| void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||||
| void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||||
| Phase phase, bool run_only_once, bool multigraph) { | Phase phase, bool run_only_once, bool multigraph) { | ||||
| auto cur_pm = GetPassGroup(phase); | auto cur_pm = GetPassGroup(phase); | ||||
| MS_EXCEPTION_IF_NULL(cur_pm); | MS_EXCEPTION_IF_NULL(cur_pm); | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "pipeline/jit/parse/resolve.h" | #include "pipeline/jit/parse/resolve.h" | ||||
| #include "frontend/optimizer/pattern.h" | |||||
| #include "frontend/optimizer/py_pass.h" | #include "frontend/optimizer/py_pass.h" | ||||
| #include "frontend/optimizer/pass_group.h" | #include "frontend/optimizer/pass_group.h" | ||||
| @@ -51,7 +52,7 @@ class PyPassManager { | |||||
| // Access the only global instance | // Access the only global instance | ||||
| static PyPassManagerPtr GetInstance(); | static PyPassManagerPtr GetInstance(); | ||||
| virtual ~PyPassManager() = default; | virtual ~PyPassManager() = default; | ||||
| void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||||
| void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||||
| Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); | Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); | ||||
| void Unregiste(const std::string &pass_name, Phase phase); | void Unregiste(const std::string &pass_name, Phase phase); | ||||
| PassGroupPtr GetPassGroup(Phase phase); | PassGroupPtr GetPassGroup(Phase phase); | ||||
| @@ -0,0 +1,154 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Patterns for describing graphs""" | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_ | |||||
| __all__ = [ | |||||
| "IsIn", | |||||
| "IsPrimTypeOf", | |||||
| "CallWith", | |||||
| "IsNot", | |||||
| "AnyPattern", | |||||
| "NewTensor", | |||||
| ] | |||||
| class IsIn(IsIn_): | |||||
| """ | |||||
| Express a pattern which allows a list of patterns. | |||||
| """ | |||||
| def __init__(self, patterns=None, should_replace=True): | |||||
| r""" | |||||
| Args: | |||||
| patterns(list/tuple): list of allowed patterns | |||||
| should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. | |||||
| """ | |||||
| if not should_replace: | |||||
| raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ | |||||
| its sub-pattern instead.") | |||||
| self.patterns = patterns | |||||
| if patterns is None: | |||||
| IsIn_.__init__(self, ()) | |||||
| elif isinstance(patterns, Pattern): | |||||
| IsIn_.__init__(self, [patterns]) | |||||
| elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): | |||||
| IsIn_.__init__(self, patterns) | |||||
| else: | |||||
| raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") | |||||
| class IsPrimTypeOf(IsPrimTypeOf_): | |||||
| r""" | |||||
| Express a pattern of certain primitive type(s). | |||||
| NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed, | |||||
| please refer to CallWith pattern. | |||||
| """ | |||||
| def __init__(self, types, name=None, should_replace=True): | |||||
| r""" | |||||
| Args: | |||||
| types (str/(list/tuple of Primitives)): Specify allowed types. | |||||
| If it is a string, the form could be | |||||
| 1) a single primitive type, e.g. 'Conv2D' | |||||
| 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' | |||||
| It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)] | |||||
| name (str): name of the pattern, optional | |||||
| should_replace | |||||
| """ | |||||
| if name is not None and not isinstance(name, str): | |||||
| raise TypeError(f"Expect string, got : {name}") | |||||
| self.name = name | |||||
| if isinstance(types, str): | |||||
| if self.name is None: | |||||
| self.name = types | |||||
| self.types = types.split('|') | |||||
| elif isinstance(types, Primitive): | |||||
| if self.name is None: | |||||
| self.name = types.name | |||||
| self.types = [types] | |||||
| elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): | |||||
| if self.name is None: | |||||
| self.name = "" | |||||
| for prim in types: | |||||
| self.name += prim.name | |||||
| self.types = types | |||||
| else: | |||||
| raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") | |||||
| IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace) | |||||
| class CallWith(CallWith_): | |||||
| r""" | |||||
| Express a primitive CNode. | |||||
| """ | |||||
| def __init__(self, prim_pattern, inputs=None, should_replace=False): | |||||
| r""" | |||||
| Args: | |||||
| prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode. | |||||
| inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; | |||||
| if specified, input patterns should be of right order. | |||||
| """ | |||||
| if not isinstance(prim_pattern, (Pattern, str, Primitive)): | |||||
| raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") | |||||
| self.prim_pattern = prim_pattern | |||||
| self.inputs = [] | |||||
| if inputs is None: | |||||
| pass | |||||
| elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): | |||||
| self.inputs = inputs | |||||
| else: | |||||
| raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") | |||||
| CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) | |||||
| class IsNot(IsNot_): | |||||
| r""" | |||||
| Express a pattern which forbids a list of patterns. | |||||
| NOTE: IsNot pattern should not be the root pattern. | |||||
| """ | |||||
| def __init__(self, patterns=None, should_replace=True): | |||||
| r""" | |||||
| Args: | |||||
| patterns(list/tuple): list of forbiden patterns | |||||
| should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. | |||||
| """ | |||||
| if not should_replace: | |||||
| raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ | |||||
| its sub-pattern instead.") | |||||
| self.patterns = patterns | |||||
| if patterns is None: | |||||
| IsNot_.__init__(self, ()) | |||||
| elif isinstance(patterns, Pattern): | |||||
| IsNot_.__init__(self, [patterns]) | |||||
| elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): | |||||
| IsNot_.__init__(self, patterns) | |||||
| else: | |||||
| raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") | |||||
| class NewTensor(NewTensor_): | |||||
| r""" | |||||
| New Tensor to be used in the target. | |||||
| """ | |||||
| def __init__(self, input_tensor, should_replace=False): | |||||
| r""" | |||||
| Args: | |||||
| input_tensor(Tensor): new tensor to be used in the target | |||||
| should_replace(bool): added this for interface consistency. NewTensor should only appear in the target. | |||||
| """ | |||||
| if should_replace: | |||||
| raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.") | |||||
| self.input_tensor = input_tensor | |||||
| if isinstance(input_tensor, Tensor): | |||||
| NewTensor_.__init__(self, input_tensor) | |||||
| else: | |||||
| raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Python pass register""" | """Python pass register""" | ||||
| from inspect import isfunction | from inspect import isfunction | ||||
| from mindspore.common.graph_pattern import Pattern | |||||
| from mindspore._c_expression import PyPassManager_ | from mindspore._c_expression import PyPassManager_ | ||||
| from mindspore._c_expression import phase | from mindspore._c_expression import phase | ||||
| @@ -46,10 +47,10 @@ class PyPassManager(PyPassManager_): | |||||
| raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") | raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") | ||||
| pattern, target = py_pass() | pattern, target = py_pass() | ||||
| pass_name = py_pass.__name__ | pass_name = py_pass.__name__ | ||||
| if not isfunction(pattern): | |||||
| raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") | |||||
| if not isfunction(target): | |||||
| raise TypeError(f"Expecting function target, got : ({type(target)}){target}") | |||||
| if not isinstance(pattern, Pattern): | |||||
| raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}") | |||||
| if not isinstance(target, Pattern): | |||||
| raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}") | |||||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) | super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) | ||||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | def unregiste(self, py_pass, pipeline_phase=phase.opt): | ||||
| @@ -22,10 +22,11 @@ from mindspore.ops import operations as P | |||||
| from mindspore.common.python_pass_register import registe_pass, PyPassManager | from mindspore.common.python_pass_register import registe_pass, PyPassManager | ||||
| from mindspore.common.api import _generate_pip_args | from mindspore.common.api import _generate_pip_args | ||||
| from mindspore._c_expression import generate_key, Executor_ | from mindspore._c_expression import generate_key, Executor_ | ||||
| from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor | |||||
| context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||
| def get_func_graph(obj, *args, phase="predict"): | |||||
| def get_func_graph(obj, *args, phase="validate"): | |||||
| args_names, args_list = _generate_pip_args(obj, *args) | args_names, args_list = _generate_pip_args(obj, *args) | ||||
| dic = dict(zip(args_names, args_list)) | dic = dict(zip(args_names, args_list)) | ||||
| key = generate_key(phase, dic) | key = generate_key(phase, dic) | ||||
| @@ -47,14 +48,11 @@ def test_softmax_relu(): | |||||
| @registe_pass(run_only_once=True) | @registe_pass(run_only_once=True) | ||||
| def softmax_relu_pass(): | def softmax_relu_pass(): | ||||
| softmax = P.Softmax() | |||||
| relu = P.ReLU() | |||||
| def pattern(x): | |||||
| x = softmax(x) | |||||
| return x | |||||
| def target(x): | |||||
| x = relu(x) | |||||
| return x | |||||
| x = AnyPattern() | |||||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||||
| pattern = CallWith(softmax_pattern, inputs=[x]) | |||||
| relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) | |||||
| target = CallWith(relu_pattern, inputs=[x]) | |||||
| return pattern, target | return pattern, target | ||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | ||||
| @@ -62,3 +60,128 @@ def test_softmax_relu(): | |||||
| ppm.unregiste(softmax_relu_pass) | ppm.unregiste(softmax_relu_pass) | ||||
| assert "ReLU" in transformed_repr | assert "ReLU" in transformed_repr | ||||
| assert "Softmax" not in transformed_repr | assert "Softmax" not in transformed_repr | ||||
| def test_isin_pattern(): | |||||
| """ | |||||
| Test IsIn pattern which expresses the IsIn/OneOf semantics. | |||||
| """ | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||||
| softmax_model = nn.Softmax() | |||||
| @registe_pass(run_only_once=True) | |||||
| def softmax_relu_pass(): | |||||
| x = AnyPattern() | |||||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||||
| call_softmax = CallWith(softmax_pattern, inputs=[x]) | |||||
| relu_pattern = IsPrimTypeOf(P.ReLU()) | |||||
| call_relu = CallWith(relu_pattern, inputs=[x]) | |||||
| pattern = IsIn([call_softmax, call_relu]) | |||||
| relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False) | |||||
| target = CallWith(relu6_pattern, inputs=[x]) | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||||
| ppm = PyPassManager() | |||||
| ppm.unregiste(softmax_relu_pass) | |||||
| assert "ReLU6" in transformed_repr | |||||
| assert "Softmax" not in transformed_repr | |||||
| def test_isnot_pattern_0(): | |||||
| """ | |||||
| Test IsNot pattern which expresses the IsNot semantics. | |||||
| Case: IsNot pass failed to match | |||||
| """ | |||||
| class ConvBN(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ConvBN, self).__init__() | |||||
| self.conv = P.Conv2D(32, 3) | |||||
| self.conv_weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32) | |||||
| self.scale = Tensor(np.ones([32]), mindspore.float32) | |||||
| self.bias = Tensor(np.ones([32]), mindspore.float32) | |||||
| self.mean = Tensor(np.ones([32]), mindspore.float32) | |||||
| self.variance = Tensor(np.ones([32]), mindspore.float32) | |||||
| self.bn = P.BatchNorm() | |||||
| def construct(self, x): | |||||
| x = self.conv(x, self.conv_weight) | |||||
| x = self.bn(x, self.scale, self.bias, self.mean, self.variance) | |||||
| return x | |||||
| inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) | |||||
| conv_bn_model = ConvBN() | |||||
| @registe_pass(run_only_once=True) | |||||
| def single_bn_pass(): | |||||
| """ | |||||
| Sub a BN which does NOT take Conv as inputs to ReLU6. | |||||
| """ | |||||
| conv2d_prim = IsPrimTypeOf("Conv2D") | |||||
| conv2d = CallWith(conv2d_prim) | |||||
| pattern_0 = IsNot(conv2d) | |||||
| pattern = CallWith(P.BatchNorm(), inputs=[pattern_0]) | |||||
| target = CallWith(P.ReLU6(), inputs=[pattern_0]) | |||||
| return pattern, target | |||||
| @registe_pass(run_only_once=True) | |||||
| def bn_pass(): | |||||
| """ | |||||
| Sub a BN to Softmax. | |||||
| """ | |||||
| bn = P.BatchNorm() | |||||
| pattern = CallWith(bn) | |||||
| softmax = P.Softmax() | |||||
| target = CallWith(softmax, should_replace=False) | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) | |||||
| ppm = PyPassManager() | |||||
| ppm.unregiste(single_bn_pass) | |||||
| ppm.unregiste(bn_pass) | |||||
| assert "ReLU6" not in transformed_repr | |||||
| assert "Softmax" in transformed_repr | |||||
| def test_isnot_pattern_1(): | |||||
| """ | |||||
| Test IsNot pattern which expresses the IsNot semantics. | |||||
| Case: IsNot pattern matches with the graph | |||||
| """ | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||||
| softmax_model = nn.Softmax() | |||||
| @registe_pass(run_only_once=True) | |||||
| def single_bn_pass(): | |||||
| """ | |||||
| Sub a BN which does NOT take MatMul as inputs to ReLU6. | |||||
| """ | |||||
| matmul = IsPrimTypeOf("MatMul") | |||||
| pattern_0 = IsNot(matmul) | |||||
| softmax = P.Softmax() | |||||
| pattern = CallWith(softmax, inputs=[pattern_0]) | |||||
| relu6 = P.ReLU6() | |||||
| target = CallWith(relu6, inputs=[pattern_0], should_replace=False) | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||||
| ppm = PyPassManager() | |||||
| ppm.unregiste(single_bn_pass) | |||||
| assert "ReLU6" in transformed_repr | |||||
| assert "Softmax" not in transformed_repr | |||||
| def test_newtensor_pattern(): | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||||
| softmax_model = nn.Softmax() | |||||
| @registe_pass(run_only_once=True) | |||||
| def softmax_addn_pass(): | |||||
| x = AnyPattern() | |||||
| softmax = P.Softmax() | |||||
| pattern = CallWith(softmax, inputs=[x]) | |||||
| weight_tensor = Tensor(np.zeros([42]), mindspore.float16) | |||||
| new_weight = NewTensor(weight_tensor) | |||||
| addn_ops = P.AddN() | |||||
| target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||||
| ppm = PyPassManager() | |||||
| ppm.unregiste(softmax_addn_pass) | |||||
| assert "AddN" in transformed_repr | |||||
| assert "Softmax" not in transformed_repr | |||||