From: @zhangbuxue Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -202,10 +202,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| resolver_resolve_attr_ = | |||||
| MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr); | |||||
| resolver_resolve_and_getattr_ = | |||||
| MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr", | |||||
| {prim::kPrimGetAttr, prim::kPrimResolve}); | |||||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | ||||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); | |||||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr); | |||||
| } | } | ||||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | InferenceOptPrepareLib::InferenceOptPrepareLib() { | ||||
| @@ -141,7 +141,7 @@ class ResolveIRPassLib { | |||||
| ResolveIRPassLib(); | ResolveIRPassLib(); | ||||
| ~ResolveIRPassLib() = default; | ~ResolveIRPassLib() = default; | ||||
| SubstitutionPtr resolver_resolve_attr_; | |||||
| SubstitutionPtr resolver_resolve_and_getattr_; | |||||
| SubstitutionPtr resolver_resolve_; | SubstitutionPtr resolver_resolve_; | ||||
| SubstitutionPtr resolver_getattr_; | SubstitutionPtr resolver_getattr_; | ||||
| }; | }; | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/optimizer/optimizer_caller.h" | #include "frontend/optimizer/optimizer_caller.h" | ||||
| @@ -66,7 +67,7 @@ class ResolverResolve : public AnfVisitor { | |||||
| }; | }; | ||||
| // {prim::kPrimGetAttr, Ns, Str} | // {prim::kPrimGetAttr, Ns, Str} | ||||
| class ResolverGetattr : public AnfVisitor { | |||||
| class ResolverGetAttr : public AnfVisitor { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| @@ -97,7 +98,7 @@ class ResolverGetattr : public AnfVisitor { | |||||
| }; | }; | ||||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} | // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} | ||||
| class ResolveAttr : public OptimizerCaller { | |||||
| class ResolverGetAttrResolve : public OptimizerCaller { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | ||||
| PatternNode<AnfNodePtr> ns_node, sym_node, attr_node; | PatternNode<AnfNodePtr> ns_node, sym_node, attr_node; | ||||
| @@ -122,6 +123,29 @@ class ResolveAttr : public OptimizerCaller { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| }; | }; | ||||
| class ResolverResolveAndGetAttr : public OptimizerCaller { | |||||
| public: | |||||
| ResolverResolveAndGetAttr() { | |||||
| resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(), | |||||
| std::make_shared<ResolverGetAttr>()}; | |||||
| } | |||||
| ~ResolverResolveAndGetAttr() = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||||
| AnfNodePtr new_node; | |||||
| for (const auto &resolver_opt : resolver_optimizers_) { | |||||
| new_node = (*resolver_opt)(optimizer, node); | |||||
| if (new_node != nullptr) { | |||||
| return new_node; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| private: | |||||
| std::vector<OptimizerCallerPtr> resolver_optimizers_{}; | |||||
| }; | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -276,8 +276,15 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa | |||||
| if (!data_converter::IsCellInstance(obj)) { | if (!data_converter::IsCellInstance(obj)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| py::object obj_attr = obj.attr(attr.c_str()); | |||||
| AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node); | |||||
| const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL; | |||||
| const std::string module = "mindspore._extends.parse.parser"; | |||||
| py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj); | |||||
| auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj); | |||||
| auto new_symbol = std::make_shared<Symbol>(attr); | |||||
| AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)}; | |||||
| AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs); | |||||
| TraceManager::ClearParseOrResolveDebugInfo(); | TraceManager::ClearParseOrResolveDebugInfo(); | ||||
| return resolved_node; | return resolved_node; | ||||
| } | } | ||||
| @@ -285,16 +292,10 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa | |||||
| namespace { | namespace { | ||||
| opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { | opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { | ||||
| opt::OptPassGroupMap map({ | opt::OptPassGroupMap map({ | ||||
| {"resolve_attr", | |||||
| { | |||||
| // for resolve primitive; | |||||
| irpass.resolver_resolve_attr_, | |||||
| }}, | |||||
| {"resolve", | {"resolve", | ||||
| { | { | ||||
| // for resolve and getattr primitive; | // for resolve and getattr primitive; | ||||
| irpass.resolver_resolve_, | |||||
| irpass.resolver_getattr_, | |||||
| irpass.resolver_resolve_and_getattr_, | |||||
| }}, | }}, | ||||
| }); | }); | ||||
| return map; | return map; | ||||
| @@ -0,0 +1,105 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test call inner net attr""" | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore import context | |||||
| from mindspore.ops import composite as C | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| class InnerInNet(nn.Cell): | |||||
| def __init__(self, init_data, const): | |||||
| super(InnerInNet, self).__init__() | |||||
| self.weight = Parameter(init_data, name="weight_s") | |||||
| self.t = init_data | |||||
| self.const = const | |||||
| def construct(self, input_x): | |||||
| if self.const: | |||||
| return input_x * self.t | |||||
| return input_x * self.weight | |||||
| class InnerNet(nn.Cell): | |||||
| def __init__(self, init_data, const): | |||||
| super(InnerNet, self).__init__() | |||||
| self.inner_in_net = InnerInNet(init_data, const) | |||||
| self.t = init_data | |||||
| self.const = const | |||||
| def construct(self, input_x): | |||||
| if self.const: | |||||
| return self.inner_in_net.t / self.inner_in_net(input_x) | |||||
| return self.inner_in_net.weight / self.inner_in_net(input_x) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, init_data, const): | |||||
| super(Net, self).__init__() | |||||
| self.inner_net = InnerNet(init_data, const) | |||||
| self.x = Tensor(np.ones((2, 3)) * 5) | |||||
| self.y = Tensor(np.ones((2, 3)) * 6) | |||||
| self.const = const | |||||
| self.weight = Parameter(init_data, name="weight_s") | |||||
| def construct(self, input_x, input_y): | |||||
| if self.const: | |||||
| return self.inner_net.t + self.inner_net(self.x) - self.y | |||||
| return self.inner_net.t + self.inner_net(input_x) - input_y | |||||
| class OuterMostNet(nn.Cell): | |||||
| def __init__(self, init_data, const): | |||||
| super(OuterMostNet, self).__init__() | |||||
| self.net = Net(init_data, const) | |||||
| def construct(self, input_x, input_y): | |||||
| return self.net.inner_net.inner_in_net.t | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.forward_net = net | |||||
| self.sens = Tensor(np.ones((2, 2), np.float32) * 5) | |||||
| self.grad_all = C.GradOperation(get_all=True) | |||||
| def construct(self, input_x, input_y): | |||||
| return self.grad_all(self.forward_net)(input_x, input_y) | |||||
| def test_inner_net_attr(): | |||||
| input_x = Tensor(np.ones((2, 3)) * 2) | |||||
| input_y = Tensor(np.ones((2, 3)) * 3) | |||||
| init_data = Tensor(np.ones((2, 3)) * 4) | |||||
| test_var_net = Net(init_data, False) | |||||
| test_var_net(input_x, input_y) | |||||
| grad_net = GradNet(test_var_net) | |||||
| grad_net(input_x, input_y) | |||||
| test_const_net = Net(init_data, True) | |||||
| ret = test_const_net(input_x, input_y) | |||||
| expect = -1.8 * np.ones((2, 3)) | |||||
| assert np.allclose(ret.asnumpy(), expect) | |||||
| test_outer_net = OuterMostNet(init_data, True) | |||||
| ret = test_outer_net(input_x, input_y) | |||||
| assert np.allclose(ret.asnumpy(), init_data.asnumpy()) | |||||