diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index a9fccf618c..365567ef68 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -202,10 +202,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
-  resolver_resolve_attr_ =
-    MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr);
+  resolver_resolve_and_getattr_ =
+    MakeSubstitution(std::make_shared<ResolverResolveAndGetAttr>(), "resolver_resolve_and_getattr",
+                     {prim::kPrimGetAttr, prim::kPrimResolve});
   resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
-  resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
+  resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr);
 }
 
 InferenceOptPrepareLib::InferenceOptPrepareLib() {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index 752f844057..3aa3d5305c 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -141,7 +141,7 @@ class ResolveIRPassLib {
   ResolveIRPassLib();
   ~ResolveIRPassLib() = default;
 
-  SubstitutionPtr resolver_resolve_attr_;
+  SubstitutionPtr resolver_resolve_and_getattr_;
   SubstitutionPtr resolver_resolve_;
   SubstitutionPtr resolver_getattr_;
 };
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h
index 16ce1088e9..cc5972fa94 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h
@@ -19,6 +19,7 @@
 
 #include <string>
 #include <memory>
+#include <vector>
 
 #include "frontend/optimizer/optimizer.h"
 #include "frontend/optimizer/optimizer_caller.h"
@@ -66,7 +67,7 @@ class ResolverResolve : public AnfVisitor {
 };
 
 // {prim::kPrimGetAttr, Ns, Str}
-class ResolverGetattr : public AnfVisitor {
+class ResolverGetAttr : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
     Reset();
@@ -97,7 +98,7 @@ class ResolverGetattr : public AnfVisitor {
 };
 
 // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node}
-class ResolveAttr : public OptimizerCaller {
+class ResolverGetAttrResolve : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
     PatternNode<AnfNodePtr> ns_node, sym_node, attr_node;
@@ -122,6 +123,29 @@ class ResolveAttr : public OptimizerCaller {
     return nullptr;
   }
 };
+
+class ResolverResolveAndGetAttr : public OptimizerCaller {
+ public:
+  ResolverResolveAndGetAttr() {
+    resolver_optimizers_ = {std::make_shared<ResolverGetAttrResolve>(), std::make_shared<ResolverResolve>(),
+                            std::make_shared<ResolverGetAttr>()};
+  }
+  ~ResolverResolveAndGetAttr() = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    AnfNodePtr new_node;
+    for (const auto &resolver_opt : resolver_optimizers_) {
+      new_node = (*resolver_opt)(optimizer, node);
+      if (new_node != nullptr) {
+        return new_node;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  std::vector<OptimizerCallerPtr> resolver_optimizers_{};
+};
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index 780636c601..254b4be403 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -276,8 +276,15 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
   if (!data_converter::IsCellInstance(obj)) {
     return nullptr;
   }
-  py::object obj_attr = obj.attr(attr.c_str());
-  AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node);
+
+  const std::string fn = PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL;
+  const std::string module = "mindspore._extends.parse.parser";
+  py::object namespace_obj = parse::python_adapter::GetPyFn(module, fn)(obj);
+  auto new_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_obj);
+  auto new_symbol = std::make_shared<Symbol>(attr);
+
+  AnfNodePtrList inputs = {NewValueNode(prim::kPrimResolve), NewValueNode(new_namespace), NewValueNode(new_symbol)};
+  AnfNodePtr resolved_node = node->func_graph()->NewCNode(inputs);
   TraceManager::ClearParseOrResolveDebugInfo();
   return resolved_node;
 }
@@ -285,16 +292,10 @@ AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpa
 namespace {
 opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
   opt::OptPassGroupMap map({
-    {"resolve_attr",
-     {
-       // for resolve primitive;
-       irpass.resolver_resolve_attr_,
-     }},
     {"resolve",
      {
        // for resolve and getattr primitive;
-       irpass.resolver_resolve_,
-       irpass.resolver_getattr_,
+       irpass.resolver_resolve_and_getattr_,
      }},
   });
   return map;
diff --git a/tests/ut/python/pipeline/parse/test_call_innetr_net_attr.py b/tests/ut/python/pipeline/parse/test_call_innetr_net_attr.py
new file mode 100644
index 0000000000..8d490a3266
--- /dev/null
+++ b/tests/ut/python/pipeline/parse/test_call_innetr_net_attr.py
@@ -0,0 +1,105 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test call inner net attr"""
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore import context
+from mindspore.ops import composite as C
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+
+
+class InnerInNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(InnerInNet, self).__init__()
+        self.weight = Parameter(init_data, name="weight_s")
+        self.t = init_data
+        self.const = const
+
+    def construct(self, input_x):
+        if self.const:
+            return input_x * self.t
+        return input_x * self.weight
+
+
+class InnerNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(InnerNet, self).__init__()
+        self.inner_in_net = InnerInNet(init_data, const)
+        self.t = init_data
+        self.const = const
+
+    def construct(self, input_x):
+        if self.const:
+            return self.inner_in_net.t / self.inner_in_net(input_x)
+        return self.inner_in_net.weight / self.inner_in_net(input_x)
+
+
+class Net(nn.Cell):
+    def __init__(self, init_data, const):
+        super(Net, self).__init__()
+        self.inner_net = InnerNet(init_data, const)
+        self.x = Tensor(np.ones((2, 3)) * 5)
+        self.y = Tensor(np.ones((2, 3)) * 6)
+        self.const = const
+        self.weight = Parameter(init_data, name="weight_s")
+
+    def construct(self, input_x, input_y):
+        if self.const:
+            return self.inner_net.t + self.inner_net(self.x) - self.y
+        return self.inner_net.t + self.inner_net(input_x) - input_y
+
+
+class OuterMostNet(nn.Cell):
+    def __init__(self, init_data, const):
+        super(OuterMostNet, self).__init__()
+        self.net = Net(init_data, const)
+
+    def construct(self, input_x, input_y):
+        return self.net.inner_net.inner_in_net.t
+
+
+class GradNet(nn.Cell):
+    def __init__(self, net):
+        super(GradNet, self).__init__()
+        self.forward_net = net
+        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
+        self.grad_all = C.GradOperation(get_all=True)
+
+    def construct(self, input_x, input_y):
+        return self.grad_all(self.forward_net)(input_x, input_y)
+
+
+def test_inner_net_attr():
+    input_x = Tensor(np.ones((2, 3)) * 2)
+    input_y = Tensor(np.ones((2, 3)) * 3)
+    init_data = Tensor(np.ones((2, 3)) * 4)
+
+    test_var_net = Net(init_data, False)
+    test_var_net(input_x, input_y)
+
+    grad_net = GradNet(test_var_net)
+    grad_net(input_x, input_y)
+
+    test_const_net = Net(init_data, True)
+    ret = test_const_net(input_x, input_y)
+    expect = -1.8 * np.ones((2, 3))
+    assert np.allclose(ret.asnumpy(), expect)
+
+    test_outer_net = OuterMostNet(init_data, True)
+    ret = test_outer_net(input_x, input_y)
+    assert np.allclose(ret.asnumpy(), init_data.asnumpy())