| @@ -21,25 +21,23 @@ namespace opt { | |||
| namespace python_pass { | |||
| int Pattern::g_id_ = 0; | |||
| MatchResultPtr IsPrimTypeOf::match(const AnfNodePtr &node) { | |||
| MatchResultPtr Prim::match(const AnfNodePtr &node) { | |||
| if (!IsValueNode<Primitive>(node)) { | |||
| return nullptr; | |||
| } | |||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||
| if (IsValueNode<Primitive>(node)) { | |||
| // iterate over all primitives | |||
| for (auto &iter : primitives_) { | |||
| if (IsPrimitive(node, iter) || iter->name() == "*") { | |||
| matched_prim_ = iter; | |||
| res->add_entry(shared_from_base<IsPrimTypeOf>(), node); | |||
| return res; | |||
| } | |||
| // iterate over all primitives | |||
| for (auto &iter : primitives_) { | |||
| if (IsPrimitive(node, iter) || iter->name() == "*") { | |||
| matched_prim_ = iter; | |||
| res->add_entry(shared_from_base<Prim>(), node); | |||
| return res; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| MatchResultPtr CallWith::match(const AnfNodePtr &node) { | |||
| MatchResultPtr Call::match(const AnfNodePtr &node) { | |||
| if (!IsPrimitiveCNode(node)) { | |||
| return nullptr; | |||
| } | |||
| @@ -71,7 +69,7 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) { | |||
| } | |||
| // If inputs is not specified, add node without looking into its inputs | |||
| if (p_inputs_size == 0) { | |||
| res->add_entry(shared_from_base<CallWith>(), cnode->input(0)); | |||
| res->add_entry(shared_from_base<Call>(), cnode->input(0)); | |||
| return res; | |||
| } | |||
| bool failed = false; | |||
| @@ -86,24 +84,24 @@ MatchResultPtr CallWith::match(const AnfNodePtr &node) { | |||
| res->merge(input_match_result); | |||
| } | |||
| if (!failed) { | |||
| res->add_entry(shared_from_base<CallWith>(), cnode->input(0)); | |||
| res->add_entry(shared_from_base<Call>(), cnode->input(0)); | |||
| return res; | |||
| } | |||
| return nullptr; | |||
| } | |||
| MatchResultPtr IsIn::match(const AnfNodePtr &node) { | |||
| MatchResultPtr OneOf::match(const AnfNodePtr &node) { | |||
| for (auto &iter : patterns_) { | |||
| auto res = iter->match(node); | |||
| if (res != nullptr) { | |||
| res->add_entry(shared_from_base<IsIn>(), node); | |||
| res->add_entry(shared_from_base<OneOf>(), node); | |||
| return res; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| MatchResultPtr IsNot::match(const AnfNodePtr &node) { | |||
| MatchResultPtr NoneOf::match(const AnfNodePtr &node) { | |||
| for (auto &iter : patterns_) { | |||
| auto res = iter->match(node); | |||
| if (res != nullptr) { | |||
| @@ -111,16 +109,33 @@ MatchResultPtr IsNot::match(const AnfNodePtr &node) { | |||
| } | |||
| } | |||
| auto res = std::make_shared<MatchResult>(); | |||
| res->add_entry(shared_from_base<IsNot>(), node); | |||
| res->add_entry(shared_from_base<NoneOf>(), node); | |||
| return res; | |||
| } | |||
| MatchResultPtr AnyPattern::match(const AnfNodePtr &node) { | |||
| MatchResultPtr Any::match(const AnfNodePtr &node) { | |||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||
| res->add_entry(shared_from_base<AnyPattern>(), node); | |||
| res->add_entry(shared_from_base<Any>(), node); | |||
| return res; | |||
| } | |||
| MatchResultPtr Imm::match(const AnfNodePtr &node) { | |||
| if (!IsValueNode<Int32Imm>(node)) { | |||
| return nullptr; | |||
| } | |||
| // Check value | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto value_ptr = value_node->value()->cast<Int32ImmPtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_ptr); | |||
| if ((int32_t)value_ptr->value() == value_) { | |||
| MatchResultPtr res = std::make_shared<MatchResult>(); | |||
| res->add_entry(shared_from_base<Imm>(), node); | |||
| return res; | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr MatchResult::get_node(const PatternPtr &pattern) { | |||
| auto entry = match_result_.find(pattern); | |||
| if (entry == match_result_.end()) { | |||
| @@ -140,20 +155,20 @@ void MatchResult::merge(const MatchResultPtr &other_result) { | |||
| REGISTER_PYBIND_DEFINE( | |||
| Pattern, ([](const py::module *m) { | |||
| (void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>()); | |||
| (void)py::class_<IsIn, std::shared_ptr<IsIn>, Pattern>(*m, "IsIn_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<IsPrimTypeOf, std::shared_ptr<IsPrimTypeOf>, Pattern>(*m, "IsPrimTypeOf_", py::dynamic_attr()) | |||
| .def(py::init<vector<PrimitivePyPtr>, string, bool>()) | |||
| .def(py::init<vector<string>, string, bool>()); | |||
| (void)py::class_<CallWith, std::shared_ptr<CallWith>, Pattern>(*m, "CallWith_") | |||
| .def(py::init<PatternPtr, vector<PatternPtr>, bool>()) | |||
| .def(py::init<PrimitivePyPtr, vector<PatternPtr>, bool>()) | |||
| .def(py::init<string, vector<PatternPtr>, bool>()); | |||
| (void)py::class_<IsNot, std::shared_ptr<IsNot>, Pattern>(*m, "IsNot_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>()); | |||
| (void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr()) | |||
| .def(py::init<vector<PrimitivePyPtr>, string>()) | |||
| .def(py::init<vector<string>, string>()); | |||
| (void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_") | |||
| .def(py::init<PatternPtr, vector<PatternPtr>>()) | |||
| .def(py::init<PrimitivePyPtr, vector<PatternPtr>>()) | |||
| .def(py::init<string, vector<PatternPtr>>()); | |||
| (void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>()); | |||
| (void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>()); | |||
| (void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_") | |||
| .def(py::init<tensor::TensorPtr>()); | |||
| (void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_") | |||
| .def(py::init<string, tensor::TensorPtr, bool, bool, bool>()); | |||
| .def(py::init<string, tensor::TensorPtr, bool, bool>()); | |||
| (void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>()); | |||
| })); | |||
| } // namespace python_pass | |||
| @@ -36,10 +36,10 @@ class MatchResult; | |||
| using MatchResultPtr = std::shared_ptr<MatchResult>; | |||
| class Pattern; | |||
| using PatternPtr = std::shared_ptr<Pattern>; | |||
| class IsPrimTypeOf; | |||
| using IsPrimTypeOfPtr = std::shared_ptr<IsPrimTypeOf>; | |||
| class CallWith; | |||
| using CallWithPtr = std::shared_ptr<CallWith>; | |||
| class Prim; | |||
| using PrimPtr = std::shared_ptr<Prim>; | |||
| class Call; | |||
| using CallPtr = std::shared_ptr<Call>; | |||
| class NewTensor; | |||
| using NewTensorPtr = std::shared_ptr<NewTensor>; | |||
| class NewParameter; | |||
| @@ -58,8 +58,6 @@ class Pattern : public Base { | |||
| virtual bool operator==(const Pattern &other) const { return unique_name_ == other.unique_name_; } | |||
| string unique_name() const { return unique_name_; } | |||
| vector<PatternPtr> inputs() { return inputs_; } | |||
| bool should_replace() { return should_replace_; } | |||
| void set_should_replace(bool should_replace) { should_replace_ = should_replace; } | |||
| virtual void reset() {} | |||
| protected: | |||
| @@ -67,7 +65,6 @@ class Pattern : public Base { | |||
| // NOTE: To ensure uniqueness of the name, raise g_id_ by 1 every time a pattern got constructed | |||
| string unique_name_; | |||
| vector<PatternPtr> inputs_; | |||
| bool should_replace_ = true; | |||
| }; | |||
| struct PatternEqual { | |||
| @@ -85,70 +82,61 @@ struct PatternHasher { | |||
| } | |||
| }; | |||
| class IsPrimTypeOf : public Pattern { | |||
| class Prim : public Pattern { | |||
| public: | |||
| IsPrimTypeOf() { unique_name_ = std::to_string(g_id_++); } | |||
| ~IsPrimTypeOf() = default; | |||
| IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace) | |||
| : primitives_(prims), name_(name), matched_prim_(nullptr) { | |||
| unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; | |||
| should_replace_ = should_replace; | |||
| if (!should_replace) { | |||
| matched_prim_ = prims[0]; | |||
| } | |||
| Prim() { unique_name_ = std::to_string(g_id_++); } | |||
| ~Prim() = default; | |||
| Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "Prim_" + name; | |||
| // Default using the first prim to build target | |||
| matched_prim_ = primitives_[0]; | |||
| } | |||
| IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name; | |||
| Prim(vector<string> types, string name) : types_(types), name_(name) { | |||
| unique_name_ = std::to_string(g_id_++) + "Prim_" + name; | |||
| // Make primitives_ | |||
| for (auto &iter : types) { | |||
| primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr))); | |||
| } | |||
| should_replace_ = should_replace; | |||
| if (!should_replace) { | |||
| matched_prim_ = primitives_[0]; | |||
| } | |||
| // Default using the first prim to build target | |||
| matched_prim_ = primitives_[0]; | |||
| } | |||
| MS_DECLARE_PARENT(IsPrimTypeOf, Pattern); | |||
| MS_DECLARE_PARENT(Prim, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| PrimitivePyPtr matched_primitive() { return matched_prim_; } | |||
| void reset() override { | |||
| if (should_replace_) { | |||
| matched_prim_ = nullptr; | |||
| } | |||
| // Init before reset | |||
| MS_EXCEPTION_IF_NULL(matched_prim_); | |||
| matched_prim_ = primitives_[0]; | |||
| } | |||
| private: | |||
| vector<string> types_; | |||
| vector<PrimitivePyPtr> primitives_; | |||
| string name_; | |||
| PrimitivePyPtr matched_prim_; | |||
| PrimitivePyPtr matched_prim_{nullptr}; | |||
| }; | |||
| class CallWith : public Pattern { | |||
| class Call : public Pattern { | |||
| public: | |||
| CallWith() { unique_name_ = std::to_string(g_id_++); } | |||
| ~CallWith() = default; | |||
| CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) { | |||
| Call() { unique_name_ = std::to_string(g_id_++); } | |||
| ~Call() = default; | |||
| Call(PatternPtr prim_pattern, vector<PatternPtr> inputs) { | |||
| // NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting | |||
| prim_pattern_ = prim_pattern; | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name(); | |||
| unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name(); | |||
| inputs_ = inputs; | |||
| // NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently. | |||
| should_replace_ = prim_pattern->should_replace(); | |||
| } | |||
| CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) { | |||
| Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) { | |||
| prim_ = prim; | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString(); | |||
| unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| should_replace_ = should_replace; | |||
| } | |||
| CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) { | |||
| Call(string prim_str, vector<PatternPtr> inputs) { | |||
| prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr)); | |||
| unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString(); | |||
| unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString(); | |||
| inputs_ = inputs; | |||
| should_replace_ = should_replace; | |||
| } | |||
| MS_DECLARE_PARENT(CallWith, Pattern); | |||
| MS_DECLARE_PARENT(Call, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| PrimitivePtr prim_value() { return prim_; } | |||
| PatternPtr prim_pattern() { return prim_pattern_; } | |||
| @@ -160,45 +148,45 @@ class CallWith : public Pattern { | |||
| string name_; | |||
| }; | |||
| class IsIn : public Pattern { | |||
| class OneOf : public Pattern { | |||
| public: | |||
| IsIn() { unique_name_ = std::to_string(g_id_++); } | |||
| ~IsIn() = default; | |||
| explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++) + "IsIn"; | |||
| OneOf() { unique_name_ = std::to_string(g_id_++); } | |||
| ~OneOf() = default; | |||
| explicit OneOf(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++) + "OneOf"; | |||
| for (auto &iter : patterns) { | |||
| unique_name_ = unique_name_ + "_" + iter->unique_name(); | |||
| } | |||
| } | |||
| MS_DECLARE_PARENT(IsIn, Pattern); | |||
| MS_DECLARE_PARENT(OneOf, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| private: | |||
| vector<PatternPtr> patterns_; | |||
| }; | |||
| class IsNot : public Pattern { | |||
| class NoneOf : public Pattern { | |||
| public: | |||
| IsNot() { unique_name_ = std::to_string(g_id_++); } | |||
| ~IsNot() = default; | |||
| explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++) + "IsNot"; | |||
| NoneOf() { unique_name_ = std::to_string(g_id_++); } | |||
| ~NoneOf() = default; | |||
| explicit NoneOf(vector<PatternPtr> patterns) : patterns_(patterns) { | |||
| unique_name_ = std::to_string(g_id_++) + "NoneOf"; | |||
| for (auto &iter : patterns) { | |||
| unique_name_ = unique_name_ + "_" + iter->unique_name(); | |||
| } | |||
| } | |||
| MS_DECLARE_PARENT(IsNot, Pattern); | |||
| MS_DECLARE_PARENT(NoneOf, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| private: | |||
| vector<PatternPtr> patterns_; | |||
| }; | |||
| class AnyPattern : public Pattern { | |||
| class Any : public Pattern { | |||
| public: | |||
| AnyPattern() { unique_name_ = std::to_string(g_id_++) + "_AnyPattern"; } | |||
| ~AnyPattern() = default; | |||
| MS_DECLARE_PARENT(AnyPattern, Pattern); | |||
| Any() { unique_name_ = std::to_string(g_id_++) + "_Any"; } | |||
| ~Any() = default; | |||
| MS_DECLARE_PARENT(Any, Pattern); | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| }; | |||
| @@ -207,7 +195,6 @@ class NewTensor : public Pattern { | |||
| NewTensor() { unique_name_ = std::to_string(g_id_++); } | |||
| ~NewTensor() = default; | |||
| explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { | |||
| should_replace_ = false; | |||
| unique_name_ = std::to_string(g_id_++) + "NewTensor"; | |||
| } | |||
| MS_DECLARE_PARENT(NewTensor, Pattern); | |||
| @@ -223,10 +210,8 @@ class NewTensor : public Pattern { | |||
| class NewParameter : public Pattern { | |||
| public: | |||
| NewParameter() { unique_name_ = std::to_string(g_id_++); } | |||
| explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel, | |||
| bool should_replace) | |||
| explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel) | |||
| : para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) { | |||
| should_replace_ = should_replace; | |||
| unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name; | |||
| // clone input tensor | |||
| default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get()); | |||
| @@ -243,11 +228,14 @@ class NewParameter : public Pattern { | |||
| bool built() { return built_; } | |||
| void set_built(bool built) { built_ = built; } | |||
| void reset() override { built_ = false; } | |||
| bool should_last() { return last_across_passes_; } | |||
| void set_last(bool last) { last_across_passes_ = last; } | |||
| private: | |||
| string para_name_; | |||
| bool requires_grad_; | |||
| bool layerwise_parallel_; | |||
| bool last_across_passes_{false}; | |||
| bool built_; | |||
| tensor::TensorPtr default_tensor_; | |||
| }; | |||
| @@ -255,13 +243,9 @@ class NewParameter : public Pattern { | |||
| class Imm : public Pattern { | |||
| public: | |||
| Imm() { unique_name_ = std::to_string(g_id_++); } | |||
| explicit Imm(int value) : value_(value) { | |||
| should_replace_ = false; | |||
| unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); | |||
| } | |||
| explicit Imm(int value) : value_(value) { unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value); } | |||
| MS_DECLARE_PARENT(Imm, Pattern); | |||
| // NOTE: Doesn't support Imm in src pattern currently. | |||
| MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; } | |||
| MatchResultPtr match(const AnfNodePtr &node) override; | |||
| int value() { return value_; } | |||
| private: | |||
| @@ -80,7 +80,7 @@ bool IsTraversable(const AnfNodePtr &node) { | |||
| AnfNodePtr BuildPrimitive(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| // Build up AnfNode from primitive | |||
| auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>(); | |||
| auto prim_pattern = pattern->cast<PrimPtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||
| PrimitivePyPtr prim = prim_pattern->matched_primitive(); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| @@ -98,13 +98,13 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) | |||
| } | |||
| AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) { | |||
| auto call_with_pattern = pattern->cast<CallWithPtr>(); | |||
| MS_EXCEPTION_IF_NULL(call_with_pattern); | |||
| auto prim = call_with_pattern->prim_value(); | |||
| auto call_pattern = pattern->cast<CallPtr>(); | |||
| MS_EXCEPTION_IF_NULL(call_pattern); | |||
| auto prim = call_pattern->prim_value(); | |||
| if (prim != nullptr) { | |||
| return std::make_shared<ValueNode>(prim); | |||
| } | |||
| auto prim_pattern = call_with_pattern->prim_pattern(); | |||
| auto prim_pattern = call_pattern->prim_pattern(); | |||
| MS_EXCEPTION_IF_NULL(prim_pattern); | |||
| return ProcessSinglePattern(prim_pattern, res, fg); | |||
| } | |||
| @@ -152,45 +152,35 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) { | |||
| } | |||
| AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { | |||
| if (pattern->should_replace()) { | |||
| // Find replacement in the MatchResult | |||
| auto target_node = res->get_node(pattern); | |||
| if (target_node == nullptr) { | |||
| // If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception. | |||
| if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n"; | |||
| return nullptr; | |||
| } | |||
| // Try to build this pattern and add to MatchResult, since this pattern is defined inside target | |||
| auto new_node = BuildTarget(pattern, func_graph, res); | |||
| if (new_node == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n"; | |||
| } | |||
| return new_node; | |||
| } | |||
| if (pattern->isa<NewParameter>()) { | |||
| auto target_node = res->get_node(pattern); | |||
| if (target_node != nullptr) { | |||
| // If pattern is NewParameter, check whether it shouldn't last and is not built | |||
| auto new_para = pattern->cast<NewParameterPtr>(); | |||
| if (new_para == nullptr || new_para->should_last() || new_para->built()) { | |||
| return target_node; | |||
| } | |||
| return target_node; | |||
| } | |||
| // Build up new node from pattern | |||
| if (pattern->isa<IsPrimTypeOf>()) { | |||
| if (pattern->isa<Prim>()) { | |||
| return BuildPrimitive(pattern, res); | |||
| } else if (pattern->isa<NewTensor>()) { | |||
| return BuildNewTensor(pattern, res); | |||
| } else if (pattern->isa<CallWith>()) { | |||
| } else if (pattern->isa<Call>()) { | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| } else if (pattern->isa<NewParameter>()) { | |||
| return BuildNewParameter(pattern, res, func_graph); | |||
| } else if (pattern->isa<Imm>()) { | |||
| return BuildImmNode(pattern, res); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot find or build target node, pattern: " + pattern->unique_name() + "\n"; | |||
| return nullptr; | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res, | |||
| const FuncGraphPtr &func_graph) { | |||
| if (pattern->isa<CallWith>()) { | |||
| if (pattern->isa<Call>()) { | |||
| return BuildPrimitiveValueNode(pattern, res, func_graph); | |||
| } | |||
| return nullptr; | |||
| @@ -269,16 +259,16 @@ void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor | |||
| } | |||
| void Reset(PatternPtr pattern) { | |||
| if (pattern->isa<IsPrimTypeOf>()) { | |||
| auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>(); | |||
| if (pattern->isa<Prim>()) { | |||
| auto prim_pattern = pattern->cast<PrimPtr>(); | |||
| prim_pattern->reset(); | |||
| return; | |||
| } else if (pattern->isa<NewParameter>()) { | |||
| auto new_param_pattern = pattern->cast<NewParameterPtr>(); | |||
| new_param_pattern->reset(); | |||
| return; | |||
| } else if (pattern->isa<CallWith>()) { | |||
| auto call_with_pattern = pattern->cast<CallWithPtr>(); | |||
| } else if (pattern->isa<Call>()) { | |||
| auto call_with_pattern = pattern->cast<CallPtr>(); | |||
| for (auto sub_pattern : call_with_pattern->inputs()) { | |||
| Reset(sub_pattern); | |||
| } | |||
| @@ -49,8 +49,9 @@ PyPassManager::PyPassManager() { | |||
| } | |||
| void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| Phase phase, bool run_only_once) { | |||
| auto cur_pg = GetPassGroup(phase); | |||
| bool run_only_once) { | |||
| // NOTE: remove phase option to avoid unnecessary confusion. | |||
| auto cur_pg = GetPassGroup(Phase::OPT); | |||
| MS_EXCEPTION_IF_NULL(cur_pg); | |||
| cur_pg->SetRunOnlyOnce(run_only_once); | |||
| MS_EXCEPTION_IF_NULL(pattern); | |||
| @@ -60,8 +61,9 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt | |||
| cur_pg->AddPass(new_pass); | |||
| } | |||
| void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||
| auto cur_pm = GetPassGroup(phase); | |||
| void PyPassManager::Unregiste(const std::string &pass_name) { | |||
| // NOTE: remove phase option to avoid unnecessary confusion. | |||
| auto cur_pm = GetPassGroup(Phase::OPT); | |||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||
| if (!cur_pm->DeletePass(pass_name)) { | |||
| MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; | |||
| @@ -70,7 +72,6 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||
| void PyPassManager::GenNewParameter(const PatternPtr ¶meter) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| // Add new parameter after resolve | |||
| // NOTE: Add NewParameter at early stage will cause CSE problems | |||
| auto cur_pg = GetPassGroup(Phase::OPT); | |||
| MS_EXCEPTION_IF_NULL(cur_pg); | |||
| @@ -78,7 +79,7 @@ void PyPassManager::GenNewParameter(const PatternPtr ¶meter) { | |||
| auto new_para_pattern = parameter->cast<NewParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_para_pattern); | |||
| auto pass_name = new_para_pattern->para_name(); | |||
| parameter->set_should_replace(false); | |||
| new_para_pattern->set_last(true); | |||
| auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true); | |||
| cur_pg->AddPass(new_pass); | |||
| } | |||
| @@ -53,16 +53,17 @@ class PyPassManager { | |||
| static PyPassManagerPtr GetInstance(); | |||
| virtual ~PyPassManager() = default; | |||
| void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, | |||
| Phase phase = Phase::RESOLVE, bool run_only_once = false); | |||
| void Unregiste(const std::string &pass_name, Phase phase); | |||
| bool run_only_once = false); | |||
| void Unregiste(const std::string &pass_name); | |||
| void GenNewParameter(const PatternPtr ¶meter); | |||
| PassGroupPtr GetPassGroup(Phase phase); | |||
| void ClearRes(); | |||
| MatchResultPtr GetMatchResult() { return res_; } | |||
| void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; } | |||
| bool ShouldRenorm() { return should_renorm_; } | |||
| void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; } | |||
| pipeline::ResourcePtr GetResource() { return resource_; } | |||
| void ClearRes(); | |||
| void ClearPipelineRes() { resource_ = nullptr; } | |||
| private: | |||
| bool should_renorm_ = true; | |||
| @@ -477,6 +477,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||
| // save the run graph func to MsPipeLine | |||
| SaveCompiledGraph(phase_s); | |||
| opt::python_pass::PyPassManager::GetInstance()->ClearPipelineRes(); | |||
| resource->Clean(); | |||
| // Reclaim all resource used by optimizer; | |||
| ReclaimOptimizer(); | |||
| @@ -15,50 +15,43 @@ | |||
| """Patterns for describing graphs""" | |||
| from mindspore.ops import Primitive | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\ | |||
| NewParameter_, Imm | |||
| from mindspore._c_expression import Pattern, OneOf_, Prim_, Call_, NoneOf_, Any, NewTensor_, NewParameter_, Imm | |||
| __all__ = [ | |||
| "IsIn", | |||
| "IsPrimTypeOf", | |||
| "CallWith", | |||
| "IsNot", | |||
| "AnyPattern", | |||
| "OneOf", | |||
| "Prim", | |||
| "Call", | |||
| "NoneOf", | |||
| "Any", | |||
| "NewTensor", | |||
| "NewParameter", | |||
| "Imm" | |||
| ] | |||
| class IsIn(IsIn_): | |||
| class OneOf(OneOf_): | |||
| r""" | |||
| Express a pattern which allows a list of patterns. | |||
| """ | |||
| def __init__(self, patterns=None, should_replace=True): | |||
| def __init__(self, patterns=None): | |||
| r""" | |||
| Args: | |||
| patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`], | |||
| patterns(Union[:class:`mindspore.graph_utils.graph_pattern`, | |||
| tuple[:class:`mindspore.graph_utils.graph_pattern`], | |||
| list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns, | |||
| each element should be one of the exposed Pattern instance. | |||
| should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. | |||
| Raises: | |||
| ValueError: raise if should_replace is False | |||
| TypeError: raise type error for invalid inputs. | |||
| """ | |||
| if not should_replace: | |||
| raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \ | |||
| its sub-pattern instead.") | |||
| self.patterns = patterns | |||
| if patterns is None: | |||
| IsIn_.__init__(self, ()) | |||
| elif isinstance(patterns, Pattern): | |||
| IsIn_.__init__(self, [patterns]) | |||
| if isinstance(patterns, Pattern): | |||
| OneOf_.__init__(self, [patterns]) | |||
| elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): | |||
| IsIn_.__init__(self, patterns) | |||
| OneOf_.__init__(self, patterns) | |||
| else: | |||
| raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") | |||
| class IsPrimTypeOf(IsPrimTypeOf_): | |||
| class Prim(Prim_): | |||
| r""" | |||
| Express a pattern of certain primitive type(s). | |||
| @@ -66,7 +59,7 @@ class IsPrimTypeOf(IsPrimTypeOf_): | |||
| This pattern will match and only match the primitive value node. If matching primitive CNode is needed, | |||
| please refer to CallWith pattern. | |||
| """ | |||
| def __init__(self, types, name=None, should_replace=True): | |||
| def __init__(self, types, name=None): | |||
| r""" | |||
| Args: | |||
| types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`], | |||
| @@ -77,9 +70,6 @@ class IsPrimTypeOf(IsPrimTypeOf_): | |||
| 2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D' | |||
| It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)] | |||
| name (str): name of the pattern, optional. Default: None. | |||
| should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is | |||
| used when building the replacement target node. Use captured node if True, build from scratch otherwise. | |||
| Default: True. | |||
| Raises: | |||
| TypeError: raise type error for invalid argument. | |||
| @@ -103,13 +93,13 @@ class IsPrimTypeOf(IsPrimTypeOf_): | |||
| self.types = types | |||
| else: | |||
| raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") | |||
| IsPrimTypeOf_.__init__(self, self.types, self.name, should_replace) | |||
| Prim_.__init__(self, self.types, self.name) | |||
| class CallWith(CallWith_): | |||
| class Call(Call_): | |||
| r""" | |||
| Express a primitive CNode. | |||
| """ | |||
| def __init__(self, prim_pattern, inputs=None, should_replace=True): | |||
| def __init__(self, prim_pattern, inputs=None): | |||
| r""" | |||
| Args: | |||
| prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`, | |||
| @@ -118,9 +108,6 @@ class CallWith(CallWith_): | |||
| tuple[:class:`mindspore.graph_utils.graph_pattern`]]): | |||
| Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input | |||
| patterns should be of right order and each element should be one of the exposed Pattern instance. | |||
| should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is | |||
| used when building the replacement target node. Use captured node if True, build from scratch otherwise. | |||
| Default: True. | |||
| Raises: | |||
| TypeError: raise type error for invalid argument. | |||
| @@ -135,36 +122,31 @@ class CallWith(CallWith_): | |||
| self.inputs = inputs | |||
| else: | |||
| raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") | |||
| CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace) | |||
| Call_.__init__(self, self.prim_pattern, self.inputs) | |||
| class IsNot(IsNot_): | |||
| class NoneOf(NoneOf_): | |||
| r""" | |||
| Express a pattern which forbids a list of patterns. | |||
| NOTE: | |||
| IsNot pattern should not be the root pattern. | |||
| NoneOf pattern should not be the root pattern. | |||
| """ | |||
| def __init__(self, patterns=None, should_replace=True): | |||
| def __init__(self, patterns=None): | |||
| r""" | |||
| Args: | |||
| patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element | |||
| should be one of the exposed Pattern instance. | |||
| should_replace(bool): added this for interface consistency. Should only set this in sub-patterns. | |||
| Raises: | |||
| ValueError: raise if should_replace is False. | |||
| TypeError: raise type error for invalid argument. | |||
| """ | |||
| if not should_replace: | |||
| raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \ | |||
| its sub-pattern instead.") | |||
| self.patterns = patterns | |||
| if patterns is None: | |||
| IsNot_.__init__(self, ()) | |||
| NoneOf_.__init__(self, ()) | |||
| elif isinstance(patterns, Pattern): | |||
| IsNot_.__init__(self, [patterns]) | |||
| NoneOf_.__init__(self, [patterns]) | |||
| elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): | |||
| IsNot_.__init__(self, patterns) | |||
| NoneOf_.__init__(self, patterns) | |||
| else: | |||
| raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") | |||
| @@ -172,18 +154,14 @@ class NewTensor(NewTensor_): | |||
| r""" | |||
| New Tensor to be used in the target. | |||
| """ | |||
| def __init__(self, input_tensor, should_replace=False): | |||
| def __init__(self, input_tensor): | |||
| r""" | |||
| Args: | |||
| input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target | |||
| should_replace(bool): added this for interface consistency. NewTensor should only appear in the target. | |||
| Raises: | |||
| ValueError: raise if should_replace is True | |||
| TypeError: raise type error for invalid argument. | |||
| """ | |||
| if should_replace: | |||
| raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.") | |||
| self.input_tensor = input_tensor | |||
| if isinstance(input_tensor, Tensor): | |||
| NewTensor_.__init__(self, input_tensor) | |||
| @@ -194,15 +172,13 @@ class NewParameter(NewParameter_): | |||
| r""" | |||
| New Parameter to be used in the target. | |||
| """ | |||
| def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False): | |||
| def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False): | |||
| r""" | |||
| Args: | |||
| para_name(str): name for the new Parameter | |||
| default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter | |||
| requires_grad(bool): True if the parameter requires gradient. Default: True | |||
| layerwise_parallel(bool): switch for layerwise parallel mode. Default: False | |||
| should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new | |||
| parameter everytime a pass target got built. Default: False | |||
| Raises: | |||
| TypeError: raise type error for invalid argument. | |||
| @@ -211,12 +187,11 @@ class NewParameter(NewParameter_): | |||
| self.default_tensor = default_tensor | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| self.should_replace = should_replace | |||
| if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ | |||
| isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool): | |||
| isinstance(layerwise_parallel, bool): | |||
| NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, | |||
| self.layerwise_parallel, self.should_replace) | |||
| self.layerwise_parallel) | |||
| else: | |||
| raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ | |||
| layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \ | |||
| {requires_grad}, {layerwise_parallel}, {should_replace}") | |||
| layerwise_parallel(bool), got : {para_name}, {default_tensor}, \ | |||
| {requires_grad}, {layerwise_parallel}") | |||
| @@ -15,7 +15,7 @@ | |||
| """Python pass register""" | |||
| from inspect import isfunction | |||
| from mindspore.graph_utils.graph_pattern import Pattern, NewParameter | |||
| from mindspore._c_expression import PyPassManager_, phase | |||
| from mindspore._c_expression import PyPassManager_ | |||
| __all__ = [ | |||
| @@ -30,21 +30,16 @@ class PyPassManager(PyPassManager_): | |||
| Used to registe and unregiste python passes which can be used to alter graphs. | |||
| Args: | |||
| pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. | |||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||
| multigraph (bool): Whether or not the pattern exists across graphs. Default: True. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| """ | |||
| def __init__(self, pipeline_phase=phase.opt, run_only_once=False): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| def __init__(self, run_only_once=False): | |||
| if not isinstance(run_only_once, bool): | |||
| raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}") | |||
| PyPassManager_.__init__(self) | |||
| self.phase_ = pipeline_phase | |||
| self.run_only_once_ = run_only_once | |||
| PyPassManager_.__init__(self) | |||
| def registe(self, py_pass): | |||
| if not isfunction(py_pass): | |||
| @@ -55,16 +50,14 @@ class PyPassManager(PyPassManager_): | |||
| raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") | |||
| if not isinstance(target, Pattern): | |||
| raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") | |||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_) | |||
| super().registe(pass_name, pattern, target, self.run_only_once_) | |||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| def unregiste(self, py_pass): | |||
| if isinstance(py_pass, str): | |||
| super().unregiste(py_pass, pipeline_phase) | |||
| super().unregiste(py_pass) | |||
| return | |||
| if isfunction(py_pass): | |||
| super().unregiste(py_pass.__name__, pipeline_phase) | |||
| super().unregiste(py_pass.__name__) | |||
| return | |||
| raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") | |||
| @@ -82,13 +75,11 @@ class PyPassManager(PyPassManager_): | |||
| raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") | |||
| super().set_renorm(should_renorm) | |||
| def registe_pass(pipeline_phase=phase.opt, run_only_once=False): | |||
| def registe_pass(run_only_once=False): | |||
| """ | |||
| Registe python pass to specified pipeline phase which would be used in compilation. | |||
| Args: | |||
| pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is | |||
| registed. Support phase.resolve and phase.opt. Default: phase.opt. | |||
| run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False. | |||
| Returns: | |||
| @@ -102,19 +93,17 @@ def registe_pass(pipeline_phase=phase.opt, run_only_once=False): | |||
| >>> target = IsPrimTypeOf("ReLU6") | |||
| >>> return pattern, target | |||
| """ | |||
| return PyPassManager(pipeline_phase, run_only_once) | |||
| return PyPassManager(run_only_once) | |||
| def unregiste_pass(py_pass, pipeline_phase=phase.opt): | |||
| def unregiste_pass(py_pass): | |||
| """ | |||
| Unregiste python pass. | |||
| Args: | |||
| py_pass(Union(str, function)): target python pass to unregiste. | |||
| pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is | |||
| unregisted. Support phase.resolve and phase.opt. Default: phase.opt. | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(py_pass, pipeline_phase) | |||
| ppm.unregiste(py_pass) | |||
| def gen_new_parameter(pattern): | |||
| """ | |||
| @@ -164,7 +153,14 @@ def cancel_new_parameter(pattern): | |||
| def set_renorm(should_renorm): | |||
| """ | |||
| Set whether or not to do renorm after modified graph in python pass(es). | |||
| Set whether or not to do renormalization after modified graph in python pass(es). | |||
| Args: | |||
| should_renorm(bool): whether or not to do renormalization after modified graph in python pass(es). | |||
| NOTE: | |||
| This interface is mainly intended for testing modifying graph without worrying about its validity. Turn off | |||
| renormalization may BREAK the network. | |||
| """ | |||
| ppm = PyPassManager() | |||
| ppm.set_renorm(should_renorm) | |||
| @@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_ | |||
| cancel_new_parameter | |||
| from mindspore.common.api import _generate_pip_args | |||
| from mindspore._c_expression import generate_key, Executor_ | |||
| from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\ | |||
| NewParameter, Imm | |||
| from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -50,11 +49,9 @@ def test_softmax_relu(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| pattern = CallWith(softmax_pattern, inputs=[x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) | |||
| target = CallWith(relu_pattern, inputs=[x]) | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| target = Call(P.ReLU(), inputs=[x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| @@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| pattern = CallWith(softmax_pattern, inputs=[x]) | |||
| sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False) | |||
| call_sigmoid = CallWith(sigmoid_pattern, [x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False) | |||
| target = CallWith(relu_pattern, inputs=[call_sigmoid]) | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| pattern = Call(softmax_pattern, inputs=[x]) | |||
| sigmoid_pattern = Prim(P.Sigmoid()) | |||
| call_sigmoid = Call(sigmoid_pattern, [x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| target = Call(relu_pattern, inputs=[call_sigmoid]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3) | |||
| @@ -99,15 +96,15 @@ def test_isin_pattern_0(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| call_softmax = CallWith(softmax_pattern, inputs=[x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU()) | |||
| call_relu = CallWith(relu_pattern, inputs=[x]) | |||
| pattern = IsIn([call_softmax, call_relu]) | |||
| relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False) | |||
| target = CallWith(relu6_pattern, inputs=[x]) | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| call_relu = Call(relu_pattern, inputs=[x]) | |||
| pattern = OneOf([call_softmax, call_relu]) | |||
| relu6_pattern = Prim(P.ReLU6()) | |||
| target = Call(relu6_pattern, inputs=[x]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| unregiste_pass(softmax_relu_pass) | |||
| @@ -123,18 +120,17 @@ def test_isin_pattern_1(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_neg_pass(): | |||
| x = AnyPattern() | |||
| softmax_pattern = IsPrimTypeOf(P.Softmax()) | |||
| call_softmax = CallWith(softmax_pattern, inputs=[x]) | |||
| relu_pattern = IsPrimTypeOf(P.ReLU()) | |||
| call_relu = CallWith(relu_pattern, inputs=[x]) | |||
| pattern = IsIn([call_softmax, call_relu]) | |||
| neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False) | |||
| target = CallWith(neg_ops, inputs=[pattern]) | |||
| x = Any() | |||
| softmax_pattern = Prim(P.Softmax()) | |||
| call_softmax = Call(softmax_pattern, inputs=[x]) | |||
| relu_pattern = Prim(P.ReLU()) | |||
| call_relu = Call(relu_pattern, inputs=[x]) | |||
| pattern = OneOf([call_softmax, call_relu]) | |||
| neg_ops = Prim(P.Neg()) | |||
| target = Call(neg_ops, inputs=[pattern]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_neg_pass) | |||
| assert "Neg" in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| @@ -167,11 +163,11 @@ def test_isnot_pattern_0(): | |||
| """ | |||
| Sub a BN which does NOT take Conv as inputs to ReLU6. | |||
| """ | |||
| conv2d_prim = IsPrimTypeOf("Conv2D") | |||
| conv2d = CallWith(conv2d_prim) | |||
| pattern_0 = IsNot(conv2d) | |||
| pattern = CallWith(P.BatchNorm(), inputs=[pattern_0]) | |||
| target = CallWith(P.ReLU6(), inputs=[pattern_0]) | |||
| conv2d_prim = Prim("Conv2D") | |||
| conv2d = Call(conv2d_prim) | |||
| pattern_0 = NoneOf(conv2d) | |||
| pattern = Call(P.BatchNorm(), inputs=[pattern_0]) | |||
| target = Call(P.ReLU6(), inputs=[pattern_0]) | |||
| return pattern, target | |||
| @registe_pass(run_only_once=True) | |||
| @@ -179,10 +175,8 @@ def test_isnot_pattern_0(): | |||
| """ | |||
| Sub a BN to Softmax. | |||
| """ | |||
| bn = P.BatchNorm() | |||
| pattern = CallWith(bn) | |||
| softmax = P.Softmax() | |||
| target = CallWith(softmax, should_replace=False) | |||
| pattern = Call(P.BatchNorm()) | |||
| target = Call(P.Softmax()) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5) | |||
| @@ -205,12 +199,12 @@ def test_isnot_pattern_1(): | |||
| """ | |||
| Sub a BN which does NOT take MatMul as inputs to ReLU6. | |||
| """ | |||
| matmul = IsPrimTypeOf("MatMul") | |||
| pattern_0 = IsNot(matmul) | |||
| matmul = Prim("MatMul") | |||
| pattern_0 = NoneOf(matmul) | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[pattern_0]) | |||
| pattern = Call(softmax, inputs=[pattern_0]) | |||
| relu6 = P.ReLU6() | |||
| target = CallWith(relu6, inputs=[pattern_0], should_replace=False) | |||
| target = Call(relu6, inputs=[pattern_0]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| @@ -228,14 +222,12 @@ def test_newtensor_pattern(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| weight_tensor = Tensor(np.zeros([42]), mindspore.float16) | |||
| new_weight = NewTensor(weight_tensor) | |||
| addn_ops = P.AddN() | |||
| target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False) | |||
| target = Call(P.AddN(), inputs=[x, new_weight]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| unregiste_pass(softmax_addn_pass) | |||
| @@ -252,25 +244,23 @@ def test_newparameter_pattern(): | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para_0 = NewParameter("Merlin", default_tensor0) | |||
| new_para_1 = NewParameter("Arthur", default_tensor1) | |||
| target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False) | |||
| target = CallWith("make_tuple", inputs=[target_0], should_replace=False) | |||
| target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1]) | |||
| target = Call("make_tuple", inputs=[target_0]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_addn_pass) | |||
| assert "MatMul" in transformed_repr | |||
| assert "make_tuple" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||
| def test_imm_pattern(): | |||
| def test_imm_target(): | |||
| """ | |||
| Test NewParameter pattern in the target | |||
| """ | |||
| @@ -278,17 +268,15 @@ def test_imm_pattern(): | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_addn_pass(): | |||
| x = AnyPattern() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| def softmax_pass(): | |||
| x = Any() | |||
| pattern = Call(P.Softmax(), inputs=[x]) | |||
| imm = Imm(0) | |||
| target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False) | |||
| target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False) | |||
| target_0 = Call("make_tuple", inputs=[pattern]) | |||
| target = Call("tuple_getitem", inputs=[target_0, imm]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| unregiste_pass(softmax_addn_pass) | |||
| unregiste_pass(softmax_pass) | |||
| assert "make_tuple" in transformed_repr | |||
| assert "tuple_getitem" in transformed_repr | |||
| assert "Softmax" in transformed_repr | |||
| @@ -301,21 +289,19 @@ def test_gen_new_parameter(): | |||
| softmax_model = nn.Softmax() | |||
| default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) | |||
| new_para = NewParameter("Merlin", default_tensor, should_replace=True) | |||
| new_para = NewParameter("Merlin", default_tensor) | |||
| gen_new_parameter(new_para) | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_make_tuple_pass(): | |||
| x = AnyPattern() | |||
| x = Any() | |||
| softmax = P.Softmax() | |||
| pattern = CallWith(softmax, inputs=[x]) | |||
| pattern = Call(softmax, inputs=[x]) | |||
| target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False) | |||
| target = Call("make_tuple", inputs=[pattern, new_para]) | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| assert "Merlin" in transformed_repr | |||
| unregiste_pass(softmax_make_tuple_pass) | |||
| cancel_new_parameter(new_para) | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5) | |||
| print(transformed_repr) | |||
| assert "Merlin" not in transformed_repr | |||