Merge pull request !3453 from zichun_ye/resolve_attr_prtags/v0.7.0-beta
| @@ -168,6 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| } | |||
| ResolveIRPassLib::ResolveIRPassLib() { | |||
| resolver_resolve_attr_ = | |||
| MakeSubstitution(std::make_shared<ResolveAttr>(), "resolver_resolve_attr", prim::kPrimGetAttr); | |||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); | |||
| } | |||
| @@ -118,6 +118,7 @@ class ResolveIRPassLib { | |||
| ResolveIRPassLib(); | |||
| ~ResolveIRPassLib() = default; | |||
| SubstitutionPtr resolver_resolve_attr_; | |||
| SubstitutionPtr resolver_resolve_; | |||
| SubstitutionPtr resolver_getattr_; | |||
| }; | |||
| @@ -21,15 +21,21 @@ | |||
| #include <memory> | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/optimizer_caller.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "pipeline/jit/parse/data_converter.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| #include "pipeline/jit/parse/parse_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| const char PARSE_SUPER_NAME[] = "namespace"; | |||
| // {prim::kPrimResolve, Ns, Sym} | |||
| class ResolverResolve : public AnfVisitor { | |||
| public: | |||
| @@ -90,6 +96,34 @@ class ResolverGetattr : public AnfVisitor { | |||
| parse::NameSpacePtr ns_{nullptr}; | |||
| parse::SymbolPtr sym_{nullptr}; | |||
| }; | |||
| // {prim::kPrimGetAttr, {prim::kPrimResolve, ns_node, sym_node}, attr_node} | |||
| class ResolveAttr : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| PatternNode<AnfNodePtr> ns_node, sym_node, attr_node; | |||
| auto ResolveAttrLambda = [&node, &ns_node, &sym_node, &attr_node, &optimizer]() -> AnfNodePtr { | |||
| auto node_to_getattr = node->cast<CNodePtr>()->input(1); | |||
| std::string attr_as_string = GetValueNode<StringImmPtr>(attr_node.GetNode(node))->value(); | |||
| auto ns_ = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node)); | |||
| auto sym_ = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node)); | |||
| if (ns_->module() == parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER && sym_->symbol() != PARSE_SUPER_NAME) { | |||
| // deal with the case of getting attr from a class member | |||
| // and avoid the case of getting attr from self (the result of ParseSuper) | |||
| auto result = parse::ResolveCellwithAttr(optimizer->manager(), ns_, sym_, node_to_getattr, attr_as_string); | |||
| return result; | |||
| } | |||
| return nullptr; | |||
| }; | |||
| MATCH_REPLACE_LAMBDA_IF( | |||
| node, PPrimitive(prim::kPrimGetAttr, PPrimitive(prim::kPrimResolve, ns_node, sym_node), attr_node), | |||
| ResolveAttrLambda, attr_node.CheckFunc(IsValueNode<StringImm>, node)); | |||
| return nullptr; | |||
| } | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -228,19 +228,10 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func | |||
| return true; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, | |||
| const AnfNodePtr &node) { | |||
| if (node->func_graph() == nullptr || manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; | |||
| } | |||
| SymbolResolver symbol_resolver(name_space, symbol, node); | |||
| if (!symbol_resolver.Resolve()) { | |||
| MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | |||
| } | |||
| py::object obj = symbol_resolver.result(); | |||
| // resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager | |||
| AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj, | |||
| const AnfNodePtr &node) { | |||
| ScopeGuard scope_guard(node->scope()); | |||
| AnfNodePtr resolved_node = nullptr; | |||
| TraceManager::DebugTrace(std::make_shared<TraceResolve>(node->debug_info())); | |||
| @@ -262,10 +253,54 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr | |||
| TraceManager::EndTrace(); | |||
| return resolved_node; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, | |||
| const AnfNodePtr &node) { | |||
| if (node->func_graph() == nullptr || manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; | |||
| } | |||
| SymbolResolver symbol_resolver(name_space, symbol, node); | |||
| if (!symbol_resolver.Resolve()) { | |||
| MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | |||
| } | |||
| py::object obj = symbol_resolver.result(); | |||
| AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj, node); | |||
| return resolved_node; | |||
| } | |||
| AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, | |||
| const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr) { | |||
| if (node->func_graph() == nullptr || manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; | |||
| } | |||
| SymbolResolver symbol_resolver(name_space, symbol, node); | |||
| if (!symbol_resolver.Resolve()) { | |||
| MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); | |||
| } | |||
| py::object obj = symbol_resolver.result(); | |||
| if (!data_converter::IsCellInstance(obj)) { | |||
| return nullptr; | |||
| } | |||
| py::object obj_attr = obj.attr(attr.c_str()); | |||
| AnfNodePtr resolved_node = ResolveObjectAndAddToManager(manager, obj_attr, node); | |||
| return resolved_node; | |||
| } | |||
| namespace { | |||
| opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { | |||
| opt::OptPassGroupMap map({ | |||
| {"resolve_attr", | |||
| { | |||
| // for resolve primitive; | |||
| irpass.resolver_resolve_attr_, | |||
| }}, | |||
| {"resolve", | |||
| { | |||
| // for resolve and getattr primitive; | |||
| @@ -145,6 +145,10 @@ using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; | |||
| AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, | |||
| const AnfNodePtr &node); | |||
| // Resolve Cell with attr name. | |||
| AnfNodePtr ResolveCellwithAttr(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, | |||
| const SymbolPtr &symbol, const AnfNodePtr &node, const std::string &attr); | |||
| // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). | |||
| bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test cases for new api of normal distribution""" | |||
| import numpy as np | |||
| from scipy import stats | |||
| import mindspore.nn as nn | |||
| from mindspore import dtype | |||
| from mindspore import Tensor | |||
| import mindspore.context as context | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class Net(nn.Cell): | |||
| """ | |||
| Test class: new api of normal distribution. | |||
| """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.normal = nn.Normal(0., 1., dtype=dtype.float32) | |||
| def construct(self, x_, y_): | |||
| kl = self.normal.kl_loss('kl_loss', 'Normal', x_, y_) | |||
| prob = self.normal.prob('prob', kl) | |||
| return prob | |||
| def test_new_api(): | |||
| """ | |||
| Test new api of normal distribution. | |||
| """ | |||
| prob = Net() | |||
| mean_a = np.array([0.0]).astype(np.float32) | |||
| sd_a = np.array([1.0]).astype(np.float32) | |||
| mean_b = np.array([1.0]).astype(np.float32) | |||
| sd_b = np.array([1.0]).astype(np.float32) | |||
| ans = prob(Tensor(mean_b), Tensor(sd_b)) | |||
| diff_log_scale = np.log(sd_a) - np.log(sd_b) | |||
| squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) | |||
| expect_kl_loss = 0.5 * squared_diff + 0.5 * \ | |||
| np.expm1(2 * diff_log_scale) - diff_log_scale | |||
| norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0])) | |||
| expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32) | |||
| tol = 1e-6 | |||
| assert (np.abs(ans.asnumpy() - expect_prob) < tol).all() | |||