From: @zhangbuxue (tags/v1.2.0-rc1)
@@ -202,10 +202,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
-  resolver_resolve_attr_ =
-    MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
+  resolver_resolve_and_getattr_ =
+    MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
+                     {prim::kPrimGetAttr, prim::kPrimResolve});
   resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
-  resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
+  resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
 }
 
 InferenceOptPrepareLib::InferenceOptPrepareLib() {
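Note: the three separate resolver substitutions are consolidated into a single resolver_resolve_and_getattr_ entry registered for both prim::kPrimGetAttr and prim::kPrimResolve, so one substitution can match either root primitive; resolver_resolve_ and resolver_getattr_ remain available as individual substitutions.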
@@ -141,7 +141,7 @@ class ResolveIRPassLib {
   ResolveIRPassLib();
   ~ResolveIRPassLib() = default;
-  SubstitutionPtr resolver_resolve_attr_;
+  SubstitutionPtr resolver_resolve_and_getattr_;
   SubstitutionPtr resolver_resolve_;
   SubstitutionPtr resolver_getattr_;
 };
@@ -19,6 +19,7 @@
 #include <string>
 #include <memory>
+#include <vector>
 
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/optimizer_caller.h"
@@ -66,7 +67,7 @@ class ResolverResolve : public AnfVisitor {
 };
 
 // {prim::kPrimGetAttr, Ns, Str}
-class ResolverGetattr : public AnfVisitor {
+class ResolverGetAttr : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
     Reset();
@@ -97,7 +98,7 @@ class ResolverGetattr : public AnfVisitor {
 };
 
 // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
-class ResolveAttr : public OptimizerCaller {
+class ResolverGetAttrResolve : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
     PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
@@ -122,6 +123,29 @@ class ResolveAttr : public OptimizerCaller {
     return nullptr;
   }
 };
+
+class ResolverResolveAndGetAttr : public OptimizerCaller {
+ public:
+  ResolverResolveAndGetAttr() {
+    resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
+                            std::make_shared<ResolverGetAttr>()};
+  }
+  ~ResolverResolveAndGetAttr() = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    AnfNodePtr new_node;
+    for (const auto &resolver_opt : resolver_optimizers_) {
+      new_node = (*resolver_opt)(optimizer, node);
+      if (new_node != nullptr) {
+        return new_node;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  std::vector<OptimizerCallerPtr> resolver_optimizers_{};
+};
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
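The new ResolverResolveAndGetAttr is a first-match-wins dispatcher over the three existing resolvers. A distilled standalone sketch of that dispatch pattern (Node, Caller, and CombinedCaller are illustrative stand-ins, not the MindSpore API):

#include <memory>
#include <vector>

struct Node { int tag = 0; };
using NodePtr = std::shared_ptr<Node>;

// A caller returns a rewritten node, or nullptr when its pattern does not match.
struct Caller {
  virtual ~Caller() = default;
  virtual NodePtr operator()(const NodePtr &node) = 0;
};
using CallerPtr = std::shared_ptr<Caller>;

// Tries each sub-caller in registration order; the first non-null rewrite wins.
struct CombinedCaller : Caller {
  explicit CombinedCaller(std::vector<CallerPtr> callers) : callers_(std::move(callers)) {}
  NodePtr operator()(const NodePtr &node) override {
    for (const auto &caller : callers_) {
      if (NodePtr replaced = (*caller)(node)) {
        return replaced;
      }
    }
    return nullptr;  // no sub-caller matched; leave the node unchanged
  }
  std::vector<CallerPtr> callers_;
};

Registration order matters here: ResolverGetAttrResolve is tried first, so the nested {prim::kPrimGetAttr, {prim::kPrimResolve, ...}} form is folded in one step before the narrower ResolverResolve and ResolverGetAttr rewrites see the node.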
@@ -276,8 +276,15 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
   if (!data_converter::IsCellInstance(obj)) {
     return nullptr;
   }
-  py::object obj_attr = obj.attr(attr.c_str());
-  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
+
+  const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
+  const std::string module = "mindspore._extends.parse.parser";
+  py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
+  auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
+  auto new_symbol = std::make_shared<Symbol>(attr);
+
+  AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
+  AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
   TraceManager::ClearParseOrResolveDebugInfo();
   return resolved_node;
 }
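Note: instead of eagerly fetching the attribute through pybind11 (obj.attr(...)) and resolving the Python object on the spot, ResolveCellwithAttr now builds a deferred {prim::kPrimResolve, namespace, symbol} CNode over the cell's class-member namespace; the surrounding getattr is then folded later by the new ResolverGetAttrResolve pass.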
@@ -285,16 +292,10 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
 namespace {
 opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
   opt::OptPassGroupMap map({
-    {"resolve_attr",
-     {
-       // for resolve primitive;
-       irpass.resolver_resolve_attr_,
-     }},
     {"resolve",
      {
        // for resolve and getattr primitive;
-       irpass.resolver_resolve_,
-       irpass.resolver_getattr_,
+       irpass.resolver_resolve_and_getattr_,
     }},
   });
   return map;
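With the combined substitution in place, the separate "resolve_attr" pass group becomes redundant: the single "resolve" group now handles both the resolve and getattr primitives in one pass over the graph.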
@@ -0,0 +1,105 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
| """ test call inner net attr""" | |||
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore import context
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+
+
+class InnerInNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(InnerInNet, self).__init__()
+        self.weight = Parameter(init_data, name="weight_s")
+        self.t = init_data
+        self.const = const
+
+    def construct(self, input_x):
+        if self.const:
+            return input_x * self.t
+        return input_x * self.weight
+
+
+class InnerNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(InnerNet, self).__init__()
+        self.inner_in_net = InnerInNet(init_data, const)
+        self.t = init_data
+        self.const = const
+
+    def construct(self, input_x):
+        if self.const:
+            return self.inner_in_net.t / self.inner_in_net(input_x)
+        return self.inner_in_net.weight / self.inner_in_net(input_x)
+
+
+class Net(nn.Cell):
+    def __init__(self, init_data, const):
+        super(Net, self).__init__()
+        self.inner_net = InnerNet(init_data, const)
+        self.x = Tensor(np.ones((2, 3)) * 5)
+        self.y = Tensor(np.ones((2, 3)) * 6)
+        self.const = const
+        self.weight = Parameter(init_data, name="weight_s")
+
+    def construct(self, input_x, input_y):
+        if self.const:
+            return self.inner_net.t + self.inner_net(self.x) - self.y
+        return self.inner_net.t + self.inner_net(input_x) - input_y
+
+
+class OuterMostNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(OuterMostNet, self).__init__()
+        self.net = Net(init_data, const)
+
+    def construct(self, input_x, input_y):
+        return self.net.inner_net.inner_in_net.t
+
+
+class GradNet(nn.Cell):
+    def __init__(self, net):
+        super(GradNet, self).__init__()
+        self.forward_net = net
+        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+        self.grad_all = C.GradOperation(get_all=True)
+
+    def construct(self, input_x, input_y):
+        return self.grad_all(self.forward_net)(input_x, input_y)
+
+
+def test_inner_net_attr():
+    input_x = Tensor(np.ones((2, 3)) * 2)
+    input_y = Tensor(np.ones((2, 3)) * 3)
+    init_data = Tensor(np.ones((2, 3)) * 4)
+
+    test_var_net = Net(init_data, False)
+    test_var_net(input_x, input_y)
+    grad_net = GradNet(test_var_net)
+    grad_net(input_x, input_y)
+
+    test_const_net = Net(init_data, True)
+    ret = test_const_net(input_x, input_y)
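+    # const branch: t + t / (x * t) - y = 4 + 4 / (5 * 4) - 6 = -1.8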
+    expect = -1.8 * np.ones((2, 3))
+    assert np.allclose(ret.asnumpy(), expect)
+
+    test_outer_net = OuterMostNet(init_data, True)
+    ret = test_outer_net(input_x, input_y)
+    assert np.allclose(ret.asnumpy(), init_data.asnumpy())
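The test exercises attribute access on nested cell instances (net.inner_net.inner_in_net.t, inner_in_net.weight) through both the constant and variable branches, covering the forward net, its gradient graph, and a net that returns the nested attribute directly.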