Merge pull request !2948 from amongo/FixControlFlow
@@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra
    }
    oss << "SymInst(%para" << idx << ")";
  } else {
    MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
    MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
    oss << "SymInst(cnode_" << sym_node->ToString() << ")";
  }
  return oss.str();
@@ -191,6 +191,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // For a free variable, which may be handled in MapValueObject, just return its cached adjoint.
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();
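The added lines above make MapMorphism memoization-aware: if an adjoint for the node already exists in anfnode_to_adjoin_ (for example, a free variable recorded by MapValueObject), it is reused instead of being rebuilt. A minimal sketch of this lookup-before-compute pattern, with hypothetical Node/Adjoint types standing in for the real AnfNodePtr/AdjointPtr:

#include <memory>
#include <unordered_map>

struct Node {};
struct Adjoint {};
using NodePtr = std::shared_ptr<Node>;
using AdjointPtr = std::shared_ptr<Adjoint>;

// Return the cached adjoint when one exists; otherwise build, cache, and
// return a fresh one. The early-return branch mirrors the added lines.
AdjointPtr MapMemoized(const NodePtr &node, std::unordered_map<NodePtr, AdjointPtr> *cache) {
  auto found = cache->find(node);
  if (found != cache->end()) {
    return found->second;  // already mapped elsewhere, e.g. as a free variable
  }
  auto adjoint = std::make_shared<Adjoint>();  // stand-in for the real morphism mapping
  (*cache)[node] = adjoint;
  return adjoint;
}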
@@ -504,7 +509,7 @@ void DFunctor::MapFvObject() {
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
      if (is_top_ || node->isa<Parameter>()) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
@@ -88,10 +88,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  env_get_item_eliminate_ =
    MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
  new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
  incorporate_env_getitem_ =
    MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
  incorporate_env_getitem_bypass_recursive_ =
    MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
  incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
                                                     "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
  incorporate_env_getitem_ =
    MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
  // Ref eliminate
  make_ref_eliminate_ =
@@ -123,6 +125,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // inline
  inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
  inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
  replace_applicator_ =
    MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
  specialize_transform_ =
@@ -55,6 +55,7 @@ class OptimizeIRPassLib {
  SubstitutionPtr env_get_item_eliminate_;
  SubstitutionPtr new_env_get_item_;
  SubstitutionPtr incorporate_env_getitem_;
  SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
  SubstitutionPtr incorporate_env_getitem_switch_;
  // Ref eliminate
@@ -80,6 +81,7 @@ class OptimizeIRPassLib {
  // inline
  SubstitutionPtr inline_;
  SubstitutionPtr inline_without_move_;
  SubstitutionPtr replace_applicator_;
  SubstitutionPtr specialize_transform_;
@@ -196,6 +198,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
  auto inp0 = node->cast<CNodePtr>()->input(0);
  return (inp0 != nullptr) && inp0->isa<CNode>();
}
// check if the cnode is a switch cnode
inline bool IsCNodeSwitch(const AnfNodePtr &node) {
  if (node != nullptr) {
    if (node->isa<CNode>()) {
      return IsPrimitiveCNode(node, prim::kPrimSwitch);
    }
  }
  return false;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
@@ -29,6 +29,7 @@
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/symbolic.h"
@@ -59,8 +60,13 @@ class EnvGetitemTransform {
    while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
      // {prim::kPrimEnvSetItem, env, symbolickey, value}
      auto &inputs = env->cast<CNodePtr>()->inputs();
      if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
        MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
      if (inputs.size() != 4) {
        MS_LOG(WARNING) << "Input size should be 4";
        return nullptr;
      }
      if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
        MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
        return nullptr;
      }
      env = inputs[1];
@@ -91,33 +97,12 @@ class EnvGetitemTransform {
class NewEnvGetItem : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    auto gety = [this](const AnfNodePtr &node) -> bool {
      this->y_ = node;
      return true;
    };
    AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
    if (env_ != nullptr && env_->Len() == 0) {
      return y_;
    }
    PatternNode c1, c2, y;
    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
                     (IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) &&
                      (GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0));
    return nullptr;
  }
  void Visit(const ValueNodePtr &vnode) override {
    if (env_ == nullptr) {
      env_ = GetValueNode<EnvInstancePtr>(vnode);
    }
  }
  void Reset() {
    y_ = nullptr;
    env_ = nullptr;
  }
 private:
  AnfNodePtr y_{nullptr};
  EnvInstancePtr env_{nullptr};
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
@@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor {
    while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
      // {prim::kPrimEnvSetItem, env, symbolickey, value}
      auto &inputs = env->cast<CNodePtr>()->inputs();
      if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
        MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
      if (inputs.size() != 4) {
        MS_LOG(WARNING) << "Input size should be 4";
        return nullptr;
      }
      if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
        MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
        return nullptr;
      }
      env = inputs[1];
@@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller {
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
 public:
  IncorporateEnvGetitem() : env_get_item_transform_() {}
  explicit IncorporateEnvGetitem(bool bypass_recursive = false)
      : env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
  ~IncorporateEnvGetitem() override = default;
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
    auto inputs = inp1->inputs();
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    auto new_fg = env_get_item_transform_(fg, key, default_v);
    if (fg->recursive() && bypass_recursive_) {
      MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
      return nullptr;
    }
    if (new_fg == nullptr) {
      return nullptr;
    }
    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(new_fg));
    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
@@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor {
 private:
  bool is_match_{false};
  internal::EnvGetitemTransform env_get_item_transform_;
  bool bypass_recursive_;
};
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
@@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
    auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
    auto new_g1 = env_get_item_transform_(g1, key, default_v);
    auto new_g2 = env_get_item_transform_(g2, key, default_v);
    if (new_g1 == nullptr || new_g2 == nullptr) {
      return nullptr;
    }
    auto fg = node->func_graph();
    auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});
@@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
  bool unique_use = IsUniqueUse(fg, nullptr);
  bool is_recursive = fg->recursive();
  if (fg->parent() != nullptr && is_recursive) {
    if (fg->parent() == node->func_graph() && unique_use) {
      return true;
    }
  }
  return false;
}
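IsDirectParentCall admits a recursive graph for inlining only when it is uniquely used and the call site sits directly in its parent graph. InlinerBase (below) consults an ordered list of (criterion, flag) pairs; the diff does not show how the bool flag is consumed, so the sketch treats it as an opaque per-criterion option. A minimal sketch of criterion-list dispatch with simplified stand-in types:

#include <functional>
#include <utility>
#include <vector>

struct FuncGraph;
struct AnfNode;
using Criterion = std::function<bool(FuncGraph *, AnfNode *)>;

// Walk the ordered criterion list; the first criterion that accepts the call
// site decides. The bool mirrors the second member of the criterions_ pairs
// and is deliberately left uninterpreted in this sketch.
bool ShouldInline(const std::vector<std::pair<Criterion, bool>> &criterions, FuncGraph *fg,
                  AnfNode *call_site) {
  for (const auto &item : criterions) {
    if (item.first(fg, call_site)) {
      return true;  // e.g. IsDirectParentCall matched a unique, direct parent call
    }
  }
  return false;
}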
// {G, Xs}
class InlinerBase : public AnfVisitor {
 public:
  explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
  explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
      : use_move_(use_move), criterions_(criterions) {}
  ~InlinerBase() override = default;
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>()) {
@@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor {
    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
      return nullptr;
    }
    // Do not inline GraphKernel to Cell.
    if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
      // If the GraphKernel only contains a return node, we make it inlined.
@@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor {
    std::vector<AnfNodePtr> params;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
    if (IsUniqueUse(fg, nullptr)) {
    // Compare sizes to avoid the case where the function has default values after grad;
    // after renormalize, those default values become inputs.
    if (fg->parameters().size() != params.size()) {
      return nullptr;
    }
    if (use_move_ && IsUniqueUse(fg, nullptr)) {
      auto mng = fg->manager();
      MS_EXCEPTION_IF_NULL(mng);
      ReplaceParams(mng, params, fg);
@@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
 private:
  bool is_checked_{false}, is_recursive_{false};
  bool use_move_;
  std::vector<std::pair<CriterionFuncType, bool>> criterions_;
};
class Inliner : public InlinerBase {
 public:
  Inliner()
      : InlinerBase({
          {IsUniqueUse, true},
          {IsTrivial, false},
          {IsInside, false},
          {IsCore, false},
          {NoCriterion, true},
        }) {}
  explicit Inliner(bool use_move = true)
      : InlinerBase(
          {
            {IsUniqueUse, true},
            {IsTrivial, false},
            {IsInside, false},
            {IsCore, false},
            {IsDirectParentCall, false},
            {NoCriterion, true},
          },
          use_move) {}
  ~Inliner() override = default;
};
class DirectInliner : public InlinerBase {
 public:
  explicit DirectInliner(bool use_move = true)
      : InlinerBase(
          {
            {IsDirectParentCall, false},
          },
          use_move) {}
  ~DirectInliner() override = default;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
@@ -26,6 +26,30 @@
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
class GetRefValueTransform {
 public:
  GetRefValueTransform() {}
  ~GetRefValueTransform() = default;
  AnfNodePtr operator()(const AnfNodePtr &node) {
    CNodePtr cnode = node->cast<CNodePtr>();
    auto inputs = cnode->inputs();
    auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
    if (fg->recursive()) {
| MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString(); | |||
      return node;
    }
    auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
    auto output = new_fg->output();
    new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
    inputs[0] = NewValueNode(new_fg);
    auto ret_node = cnode->func_graph()->NewCNode(inputs);
    return ret_node;
  }
};
} // namespace internal
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public OptimizerCaller {
 public:
@@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
class GetMakeRefEliminater : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x, y, z;
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
    internal::GetRefValueTransform trans;
    auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
      auto rep = trans(x.GetNode(node));
      if (rep != nullptr) {
        return rep;
      }
      return nullptr;
    };
    MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
    return nullptr;
  }
};
@@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
  auto context_ptr = MsContext::GetInstance();
  std::string backend = MsContext::GetInstance()->backend_policy();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (CompileGraphs::ContainMixedTarget(func_graph)) {
    bc_ptr->set_is_multi_graph_sink(false);
@@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
    context_ptr->set_loop_sink_flag(false);
  } else if (context_ptr->execution_mode() != kPynativeMode) {
    std::string device_target = context_ptr->device_target();
    if (device_target == kAscendDevice) {
    if (device_target == kAscendDevice && backend != kMsVm) {
      bc_ptr->set_is_multi_graph_sink(true);
      context_ptr->set_is_multi_graph_sink(true);
    }
  }
  if (IsCtrlSink()) {
  if (IsCtrlSink() && backend == kMsConvert) {
    res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
    return true;
  }
@@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) {
  if (res->results().count(kOutput) == 0) {
    MS_LOG(EXCEPTION) << "Execute args error";
  }
  if (IsCtrlSink()) {
  std::string backend = MsContext::GetInstance()->backend_policy();
  if (IsCtrlSink() && backend == kMsConvert) {
    if (!res->results()[kOutput].is<GraphId>()) {
      MS_LOG(EXCEPTION) << "Execute args error";
    }
@@ -30,6 +30,7 @@
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/validator.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse.h"
#include "frontend/optimizer/graph_kernel_reuse.h"
@@ -128,11 +129,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.incorporate_getitem_set_,
    irpass.incorporate_call_,
    irpass.incorporate_call_switch_,
    irpass.incorporate_env_getitem_,
    irpass.incorporate_env_getitem_bypass_recursive_,
    irpass.incorporate_env_getitem_switch_,
    irpass.new_env_get_item_,
    irpass.depend_value_elim_,
  });
  opt::OptPassConfig a_after_grad = opt::OptPassConfig({
    irpass.inline_without_move_,
  });
  opt::OptPassConfig a_3 = opt::OptPassConfig({
    irpass.arithmetic_simplify2_,
    irpass.same_eliminate_,
@@ -155,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    {"virtual_dataset", virtual_dataset},
    {"grad", grad},
    {"resolve", resolve_pass},
    {"a_after_grad", a_after_grad},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"cse", opt::OptPassConfig(opt::CSE(false))},
    {"a_3", a_3}});
@@ -162,11 +167,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  return map_a;
}
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig c_1 = opt::OptPassConfig({
    // Safe inlining
    irpass.inline_,
    irpass.partial_eliminate_,
  });
  OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
  return map_a;
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig b_1 =
    opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
                        irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
                        irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
  opt::OptPassConfig b_1 = opt::OptPassConfig(
    {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
     irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
     irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
     irpass.value_based_eliminate_});
  opt::OptPassConfig b_2 = opt::OptPassConfig({
    irpass.replace_refkey_by_param_,
    irpass.make_ref_eliminate_,
@@ -245,6 +263,8 @@ void InitOpt(const ResourcePtr &res) {
    opt::irpass::OptimizeIRPassLib irpass;
    g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
    g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
    g_pass_opts["opt_after_cconv"] =
      Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
    g_pass_opts["opt_graph_kernel_a"] =
      Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
    g_pass_opts["opt_graph_kernel_b"] =
@@ -289,6 +309,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
@@ -312,6 +333,33 @@ bool AddControlDependPass(const ResourcePtr &res) {
  return true;
}
bool MergeDupGraphPass(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(res->manager());
  if (res->manager()->func_graphs().size() <= 1) {
    return true;
  }
  return MergeDuplicateGraphs(res->manager());
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "Remove value node duplications error.";
  }
  auto manager = res->manager();
  HashCache hash_cache;
  HashValue hashes;
  // Remove duplicated value nodes across all graphs in manager
  for (auto &fg : manager->func_graphs()) {
    auto value_nodes = fg->value_nodes();
    for (const auto &value_pair : value_nodes) {
      TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
    }
  }
  return true;
}
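TryToDoReplace (defined in remove_value_node_dup.cc, partly shown further down) buckets nodes by hash and replaces a node with an existing equivalent from its bucket; the pass above simply feeds it every value node of every graph. A minimal, self-contained sketch of that hash-bucket dedup scheme, in which real equality is confirmed before anything is treated as a duplicate, since equal hashes alone are only candidates:

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// HashCache in the diff plays the role of `buckets`; TryToDoReplace plays
// the role of this helper. Returns true when an equal value already exists
// in the bucket; otherwise records the value and returns false.
bool IsDuplicate(const std::string &value,
                 std::unordered_map<std::size_t, std::vector<std::string>> *buckets) {
  auto &bucket = (*buckets)[std::hash<std::string>{}(value)];
  for (const auto &existing : bucket) {
    if (existing == value) {
      return true;  // the caller would Replace() the node with the existing one
    }
  }
  bucket.push_back(value);  // first occurrence: "append node to bucket"
  return false;
}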
bool CconvPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res->func_graph());
  FuncGraphPtr func_graph = res->func_graph();
@@ -341,6 +389,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
                                   {"clean_after_opta", CleanAfterOptAPass},
                                   {"opt_b", OptPassBGroup},
                                   {"cconv", CconvPass},
                                   {"opt_after_cconv", OptPassAfterCconvGroup},
                                   {"remove_dup_value", RemoveValueNodeDuplicationsPass},
                                   {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                   {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                   {"add_control_depend", AddControlDependPass}};
@@ -13,6 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include "pipeline/jit/remove_value_node_dup.h"
#include "ir/anf.h"
@@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
  // Meet for the first time, append node to bucket.
  bucket.emplace_back(node);
}
size_t HashOfGraph(const FuncGraphPtr &fg) {
  std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
| MS_LOG(DEBUG) << "TopSort for:" << fg->ToString(); | |||
  std::unordered_map<AnfNodePtr, std::size_t> hashes;
  auto &params = fg->parameters();
  for (size_t i = 0; i < params.size(); i++) {
    hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
  }
  for (auto node : toposet) {
    MS_EXCEPTION_IF_NULL(node);
    if (hashes.find(node) != hashes.end()) {
      continue;
    }
    std::size_t h = 0;
    if (node->isa<ValueNode>()) {
      ValueNodePtr value_node = node->cast<ValueNodePtr>();
      auto value = value_node->value();
      MS_EXCEPTION_IF_NULL(value);
      if (IsValueNode<FuncGraph>(value_node)) {
        auto v_fg = value->cast<FuncGraphPtr>();
        h = value->hash();
      } else if (IsValueNode<tensor::Tensor>(value_node)) {
        // Tensors with the same value have already been merged by the duplicate-value pass,
        // so the value pointer can serve as an identifier here.
        h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
      } else {
        h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
      }
    } else if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      auto &inputs = cnode->inputs();
      size_t init = 0;
      h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
        return hash_combine(hash, hashes[node_in]);
      });
    } else if (node->isa<Parameter>()) {
      h = node->hash();
    } else {
| MS_LOG(ERROR) << "Unknow node type"; | |||
    }
    hashes[node] = h;
  }
  return hashes[fg->get_return()];
}
bool IsCNodeGraph(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto inp0 = node->cast<CNodePtr>()->input(0);
  return IsValueNode<FuncGraph>(inp0);
}
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
  std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
  std::unordered_map<FuncGraphPtr, size_t> graph_hash;
  for (auto fg : manager->func_graphs()) {
    size_t h = HashOfGraph(fg);
    graph_hash[fg] = h;
    if (hash_graphs.find(h) == hash_graphs.end()) {
      hash_graphs[h] = {fg};
    } else {
      hash_graphs[h].push_back(fg);
    }
  }
  FuncGraphPairMapEquiv equiv_graph;
  NodeMapEquiv equiv_node;
  for (auto &fg : manager->func_graphs()) {
    MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
    for (auto &item : fg->nodes()) {
      if (!item->isa<CNode>()) {
        continue;
      }
      auto &inputs = item->cast<CNodePtr>()->inputs();
      for (size_t i = 0; i < inputs.size(); i++) {
        if (!inputs[i]->isa<ValueNode>()) {
          continue;
        }
        auto value_ptr = GetValueNode(inputs[i]);
        auto v_fg = value_ptr->cast<FuncGraphPtr>();
        if (v_fg == nullptr) {
          continue;
        }
        auto &fg_vec = hash_graphs[graph_hash[v_fg]];
        if (fg_vec.size() > 1) {
          if (v_fg != fg_vec[0]) {
            bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
            if (is_morphic) {
              auto new_node = NewValueNode(fg_vec[0]);
              MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
              manager->Replace(inputs[i], new_node);
            }
          }
        }
      }
    }
  }
  return true;
}
} // namespace pipeline
} // namespace mindspore
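For orientation: HashOfGraph fingerprints a graph structurally (parameters hash by position, value nodes by value, tensors by pointer identity, and each CNode by folding its inputs' hashes in topological order), and MergeDuplicateGraphs trusts equal fingerprints only after the Isomorphic() check. A minimal, self-contained sketch of that toposort-and-fold hashing on a generic DAG, assuming a boost-style mixer in place of the real hash_combine from utils/hashing.h:

#include <cstddef>
#include <vector>

struct Dag {
  // Nodes are 0..n-1 in topological order; inputs[i] lists node i's operands.
  std::vector<std::vector<int>> inputs;
  // Precomputed hashes for leaf nodes (parameters / value nodes).
  std::vector<std::size_t> leaf_hash;
};

// Assumed boost-style mixer; stands in for hash_combine in utils/hashing.h.
inline std::size_t HashCombine(std::size_t seed, std::size_t h) {
  return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Fold operand hashes in topological order so structurally identical DAGs
// get identical fingerprints; the last node plays the role of get_return().
std::size_t HashOfDag(const Dag &g) {
  std::vector<std::size_t> h(g.inputs.size(), 0);
  for (std::size_t i = 0; i < g.inputs.size(); ++i) {
    if (g.inputs[i].empty()) {
      h[i] = g.leaf_hash[i];
      continue;
    }
    std::size_t acc = 0;
    for (int in : g.inputs[i]) {
      acc = HashCombine(acc, h[in]);
    }
    h[i] = acc;
  }
  return h.empty() ? 0 : h.back();
}

Equal fingerprints are only merge candidates; as in MergeDuplicateGraphs, a structural equality check should confirm the match before any replacement.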
@@ -28,6 +28,10 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
size_t HashOfGraph(const FuncGraphPtr &fg);
bool IsCNodeGraph(const AnfNodePtr &node);
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
} // namespace pipeline
} // namespace mindspore
@@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
  }
  const AnfNodePtr &func_node = fg->get_return();
  MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
  MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
                << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
  AbstractBasePtr ret_base = nullptr;
  std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
  for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
    const auto &node = *it;
    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
                  << ", node_conf: " << node_conf->ToString();
    ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
    MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
                  << ", abstract: " << ret_base->ToString();
    MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
                  << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
  }
  MS_EXCEPTION_IF_NULL(ret_base);
@@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
                         [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                           MS_EXCEPTION_IF_NULL(arg);
                           return arg->Broaden();
                           if (arg->GetValueTrack() != kAnyValue) {
                             return arg->Broaden();
                           }
                           return arg;
                         });
    if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
      MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
                               << " does not equal to number of original buffer arguments "
                               << func_graph_->joined_shapes_.size();
    }
    for (size_t i = 0; i < broaded_list.size(); ++i) {
      broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
    if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
      for (size_t i = 0; i < broaded_list.size(); ++i) {
        broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
      }
    }
    MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                  << ", broaded: " << mindspore::ToString(broaded_list);
    return broaded_list;
@@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
      func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
      func_graph_->joined_shapes_.clear();
      std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
                     std::back_inserter(func_graph_->joined_shapes_),
                     [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
                     std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
                       if (arg_spec->isa<AbstractRef>()) {
                         return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
                       }
                       return arg_spec->GetShapeTrack();
                     });
      joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
      MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
    }
    return joined_args_spec_list;
@@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
      func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
      func_graph_->joined_shapes_.clear();
      std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
                     std::back_inserter(func_graph_->joined_shapes_),
                     [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
                     std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
                       if (arg_spec->isa<AbstractRef>()) {
                         return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
                       }
                       return arg_spec->GetShapeTrack();
                     });
      joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
      MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
    }
    MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
@@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
    trace::TraceEvalCNodeLeave();
  } else {
    MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
                      << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
                      << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }
@@ -301,6 +302,8 @@ void AnalysisEngine::Clear() {
  anfnode_config_map_.clear();
  eval_trace_.clear();
  constructors_.clear();
  constructors_app_.clear();
  continued_evals_.clear();
}
namespace {
@@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();
  EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  auto part_pair = std::make_pair(func_orig, func->args());
  auto itr = constructors_app_.find(part_pair);
  if (itr != constructors_app_.end()) {
    return itr->second;
  }
  std::shared_ptr<PartialAppEvaluator> partial_evaluator =
    std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
  constructors_app_[part_pair] = partial_evaluator;
  return partial_evaluator;
}
@@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
  if (fg_eval == nullptr) {
    return;
  }
  auto fg = fg_eval->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto undetermined_fgs = fg->recursive_graphs();
  auto undetermined_fgs = fg->recursive();
  if (undetermined_fgs) {
    auto fg_parent = fg->parent();
    MS_EXCEPTION_IF_NULL(fg_parent);
@@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
  MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
  for (auto u_eval : undetermined_evals) {
    MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
    if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
      MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
    MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined.";
    auto &alternate_evaluator = multi_poss_[u_eval.first];
    auto &eval_cache = alternate_evaluator->cache();
    if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) &&
        (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
         (eval_cache->find(args_spec_list) == eval_cache->end()))) {
      MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined.";
      has_undetermined = true;
      break;
    }
  }
  if (has_undetermined == false) {
    MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
    MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
    *continue_flag = true;
    return latest_entry;
  }
@@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
    auto current_inf = std::make_pair(eval, args_spec_list);
    MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
    // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
    auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
    if (it == eval_trace_.rend()) {
      eval_trace_.push_back(current_inf);
      MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
      MS_EXCEPTION_IF_NULL(eval);
      auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
      MS_EXCEPTION_IF_NULL(eval_result->abstract());
      MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
      out_specs.push_back(eval_result->abstract());
      eval_trace_.pop_back();
      if (eval_trace_.empty()) {
        multi_poss_.clear();
      }
    } else if (it != eval_trace_.rbegin()) {
    } else {
      bool continue_flag = false;
      auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
      if (continue_flag) {
        MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString();
        continued_evals_.insert(current_inf);
        continue;
      }
      // Try to travel the latest undetermined.
      if (latest_entry != eval_trace_.rbegin()->first) {
        MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
        MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
        auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
        MS_EXCEPTION_IF_NULL(eval_result->abstract());
        MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
        MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
                      << " return out_spec: " << eval_result->abstract()->ToString();
        return eval_result;
      }
@@ -26,6 +26,7 @@
#include <vector>
#include <utility>
#include <map>
#include <set>
#ifdef DEBUG
#include <stack>
@@ -113,7 +114,8 @@ class AnfNodeConfig : public Config {
  std::string ToString() const override {
    std::ostringstream buffer;
    buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
    buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
           << "), Context: " << context_->ToString();
    return buffer.str();
  }
@@ -173,7 +175,13 @@ struct AnalysisResult {
};
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
struct PartialAppHasher {
  std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
    auto h1 = std::hash<AbstractFunctionPtr>{}(p.first);
    auto h2 = AbstractBasePtrListHash(p.second);
    return h1 ^ h2;
  }
};
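constructors_app_ (declared further down) caches one PartialAppEvaluator per (original function, bound arguments) pair, so repeated partial applications reuse the same evaluator; PartialAppHasher above keys that map by XOR-ing the function hash with the argument-list hash. A minimal sketch of such a pair-keyed cache with a custom hasher; every name here is a stand-in, not MindSpore API:

#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Evaluator {};
using Key = std::pair<std::string, std::vector<int>>;  // (function id, bound args)

struct PairHasher {
  std::size_t operator()(const Key &p) const {
    std::size_t h = std::hash<std::string>{}(p.first);
    for (int a : p.second) {
      h ^= std::hash<int>{}(a);  // XOR-fold, like PartialAppHasher
    }
    return h;
  }
};

// Sketch of the constructors_app_ cache: look up before constructing.
std::shared_ptr<Evaluator> GetPartialEvaluator(
    const Key &key, std::unordered_map<Key, std::shared_ptr<Evaluator>, PairHasher> *cache) {
  auto it = cache->find(key);
  if (it != cache->end()) {
    return it->second;
  }
  auto eval = std::make_shared<Evaluator>();
  (*cache)[key] = eval;
  return eval;
}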
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
 public:
  AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
@@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  const PrimEvaluatorMap &prim_constructors_;
  FuncGraphManagerPtr func_graph_manager_;
  std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
  std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
    constructors_app_;
  AnfNodeConfigMap anfnode_config_map_;
  // Use a list to trace multiple evaluators.
  std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
  std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
  std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> continued_evals_;
  AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                         const ConfigPtrList &args_conf_list);
@@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRef;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
@@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
    // only send string in external
    if (!IsValueNode<StringImm>(node)) {
      // Validate a type.
      MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
      MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
                        << " for node=" << node->DebugString();
    }
  }
  return;
@@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
  if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
      ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
      ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
      ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>()) {
    return;
  }
@@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
}
// Isomorphism
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                            NodeMapEquiv *const equiv_node) {
static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                     NodeMapEquiv *const equiv_node);
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
                     NodeMapEquiv *const equiv_node) {
  if (equiv_node == nullptr) {
    MS_LOG(ERROR) << "Invalid equiv_node";
    return false;
@@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu
    MS_LOG(DEBUG) << "two parameters are not equal.";
    return false;
  }
  if (node1->isa<CNode>() && node2->isa<CNode>()) {
    return SameNode(node1, node2, equiv_func_graph, equiv_node);
  }
  MS_LOG(ERROR) << "type error";
  return false;
}
@@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
  auto fg = std::make_shared<FuncGraph>();
  AnfNodePtrList inputs;
  AnfNodePtrToAnfNodePtrMap eqv;
  if (lst.empty()) {
    MS_LOG(EXCEPTION) << "Input anf node list is empty";
  }
  TraceManager::DebugTrace(
    std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
  auto fg = std::make_shared<FuncGraph>();
  TraceManager::EndTrace();
  AnfNodePtrList inputs;
  AnfNodePtrToAnfNodePtrMap eqv;
  // Merge CNodes into a AnfGraph that represents a linear instruction segment
  for (auto n : lst) {
    if (!n->isa<CNode>()) {
@@ -154,7 +157,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
      (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
                           [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
    }
    TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
    eqv[n] = fg->NewCNode(args);
    TraceManager::EndTrace();
    eqv[n]->set_abstract(n->abstract());
    eqv[n]->set_kernel_info(n->kernel_info_ptr());
  }
@@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  }
  auto other_tensor = dyn_cast<AbstractTensor>(other);
  if (other_tensor == nullptr) {
    auto ref_tensor = dyn_cast<AbstractRef>(other);
    if (ref_tensor != nullptr) {
      return this->Join(ref_tensor->ref());
    }
    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  }
  if (*this == *other) {
@@ -48,7 +48,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
      continue;
    }
    if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
      MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString();
      MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
    }
    rank[node] = todo.size();
    bool cont = false;
@@ -30,6 +30,7 @@
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
#include "utils/hashing.h"
using std::fabs;
@@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr<Scalar>;
class BoolImm : public Scalar {
 public:
  explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash<bool>{}(v_); }
  explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
  ~BoolImm() override = default;
  MS_DECLARE_PARENT(BoolImm, Scalar)
  std::size_t hash() const override { return hash_; }
@@ -91,7 +92,7 @@ class IntergerImm : public Scalar {
class Int8Imm : public IntergerImm {
 public:
  Int8Imm() : IntergerImm(kInt8), v_(0) {}
  explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash<int>{}(v_); }
  explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
  ~Int8Imm() override = default;
  MS_DECLARE_PARENT(Int8Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t)
class Int16Imm : public IntergerImm {
 public:
  Int16Imm() : IntergerImm(kInt16), v_(0) {}
  explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash<int>{}(v_); }
  explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
  ~Int16Imm() override = default;
  MS_DECLARE_PARENT(Int16Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t)
class Int32Imm : public IntergerImm {
 public:
  Int32Imm() : IntergerImm(kInt32), v_(0) {}
  explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash<int>{}(v_); }
  explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
  ~Int32Imm() override = default;
  MS_DECLARE_PARENT(Int32Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t)
class Int64Imm : public IntergerImm {
 public:
  Int64Imm() : IntergerImm(kInt64), v_(0) {}
  explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash<int64_t>{}(v_); }
  explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
  ~Int64Imm() override = default;
  MS_DECLARE_PARENT(Int64Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t)
class UInt8Imm : public IntergerImm {
 public:
  UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
  explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
  explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
    hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
  }
  ~UInt8Imm() override = default;
  MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
class UInt16Imm : public IntergerImm {
 public:
  UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
  explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
  explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
    hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
  }
  ~UInt16Imm() override = default;
  MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
class UInt32Imm : public IntergerImm {
 public:
  UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
  explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
  explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
    hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
  }
  ~UInt32Imm() override = default;
  MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
class UInt64Imm : public IntergerImm {
 public:
  UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
  explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); }
  explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
    hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
  }
  ~UInt64Imm() override = default;
  MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
class FP32Imm : public FloatImm {
 public:
  FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
  explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); }
  explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
  ~FP32Imm() override = default;
  MS_DECLARE_PARENT(FP32Imm, FloatImm)
  std::size_t hash() const override { return hash_; }
@@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
class FP64Imm : public FloatImm {
 public:
  FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
  explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); }
  explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
  ~FP64Imm() override = default;
  MS_DECLARE_PARENT(FP64Imm, FloatImm)
  std::size_t hash() const override { return hash_; }
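Folding tid() into each immediate's hash separates equal payloads of different scalar types: with the old payload-only hashes, BoolImm(true), Int8Imm(1), and Int32Imm(1) could all collide, which matters once value nodes are deduplicated by hash (RemoveValueNodeDuplicationsPass above). A minimal sketch of the before/after behaviour, assuming an initializer-list hash_combine shaped like the one in utils/hashing.h; the tid constants are stand-ins:

#include <cstddef>
#include <functional>
#include <initializer_list>

// Assumed shape of hash_combine from utils/hashing.h: fold a list of hashes.
inline std::size_t hash_combine(std::initializer_list<std::size_t> hs) {
  std::size_t seed = 0;
  for (std::size_t h : hs) {
    seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  return seed;
}

constexpr std::size_t kBoolTid = 1;   // stand-in for BoolImm's tid()
constexpr std::size_t kInt32Tid = 2;  // stand-in for Int32Imm's tid()

int main() {
  // Before: the hash depends on the payload only; on common standard-library
  // implementations (identity hashes for integral types) true and 1 collide.
  std::size_t old_bool = std::hash<bool>{}(true);
  std::size_t old_int = std::hash<int>{}(1);
  // After: the type id is folded in first, separating the two values.
  std::size_t new_bool = hash_combine({kBoolTid, std::hash<bool>{}(true)});
  std::size_t new_int = hash_combine({kInt32Tid, std::hash<int>{}(1)});
  return (old_bool == old_int && new_bool != new_int) ? 0 : 1;
}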
@@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
    return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
  }
};
class TraceSegmentTransform : public TraceInfo {
 public:
  explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
  MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
  ~TraceSegmentTransform() override = default;
  TraceInfoPtr clone() override {
    return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
  }
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_
@@ -0,0 +1,816 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)


def test_while_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return C.grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)


def test_while_endless_case():
| """endless case when optmization""" | |||
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| part = x[idx, :, :] | |||
| out = out + part | |||
| idx = idx + 1 | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| net = MyWhileNet() | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(2), dtype=ms.int32) | |||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_grad(): | |||
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| part = x[idx, :, :] | |||
| max_num = self.max(part) | |||
| x[idx, :, 0:2] = max_num | |||
| out = out + x + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(2), dtype=ms.int32) | |||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_forward_with_const_branch(): | |||
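| """Forward while loop whose body takes a constant-condition if branch.""" | |||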
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| if 2 > 1: | |||
| out = out + self.param | |||
| else: | |||
| out = out + idx + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = while_net | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_opt_endless(): | |||
| """endless during optimization case""" | |||
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| self.addn = P.AddN() | |||
| def construct(self, idx, end, x): | |||
| addn1 = self.addn((x, x, x)) | |||
| out = addn1 | |||
| while idx < end: | |||
| out = self.addn((out, addn1)) | |||
| idx = idx + 1 | |||
| out = self.addn((out, x)) | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return C.grad_all(self.net)(*inputs) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_no_while_call(): | |||
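| """Constant-condition if branch without any while loop, as a control case.""" | |||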
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| if 2 > 1: | |||
| out = out + self.param | |||
| else: | |||
| out = out + idx + self.param | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = while_net | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_grad_with_const_branch(): | |||
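| """Gradient w.r.t. the Parameter through a while loop containing a constant-condition branch.""" | |||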
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| if 2 > 1: | |||
| out = out + self.param | |||
| else: | |||
| out = out + idx + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_for_while_with_param_grad_with_const_branch(): | |||
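| """Gradient through a for loop that restarts a while loop with a constant-condition branch.""" | |||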
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| self.start = Tensor(np.array(0), dtype=ms.int32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| for _ in range(0, 2): | |||
| idx = self.start | |||
| while idx < end: | |||
| if 2 > 1: | |||
| out = out + self.param | |||
| else: | |||
| out = out + idx + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_for_while_with_param_grad_basic(): | |||
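| """Basic gradient through a while loop restarted by an enclosing for loop.""" | |||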
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| self.start = Tensor(np.array(0), dtype=ms.int32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| for _ in range(0, 2): | |||
| idx = self.start | |||
| while idx < end: | |||
| out = out + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_for_while_with_param_grad_normal(): | |||
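| """Like the basic for-while case, but the accumulator is initialized from the input x.""" | |||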
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.reduce = P.ReduceSum() | |||
| self.start = Tensor(np.array(0), dtype=ms.int32) | |||
| def construct(self, idx, end, x): | |||
| out = x | |||
| for _ in range(0, 2): | |||
| idx = self.start | |||
| while idx < end: | |||
| out = out + self.param | |||
| idx = idx + 1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_basic_grad(): | |||
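| """Basic gradient w.r.t. a single Parameter accumulated additively in a while loop.""" | |||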
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.t2 = Tensor(np.array(2), dtype=ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| out = out + self.param | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(3), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_basic_grad_mul(): | |||
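| """Gradient w.r.t. a Parameter that multiplies the accumulator (initialized to ones) in the loop.""" | |||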
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.ones([2, 2, 2]), ms.float32) | |||
| self.t2 = Tensor(np.array(2), dtype=ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| out = out * self.param | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(3), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_basic_grad_two(): | |||
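| """Gradient w.r.t. two Parameters updated in the same while loop.""" | |||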
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.t2 = Tensor(np.array(2), dtype=ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| out = out + self.param + self.weight | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(3), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_basic_grad_three(): | |||
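| """Gradient w.r.t. three Parameters updated in the same while loop.""" | |||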
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") | |||
| self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.t2 = Tensor(np.array(2), dtype=ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| out = out + self.param + self.weight + self.key | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(3), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_if_with_param_grad(): | |||
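| """Gradient through a data-dependent if branch nested inside a while loop.""" | |||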
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| self.t2 = Tensor(np.array(2), dtype=ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| if self.max(out) < self.max(x): | |||
| out = out + self.param * 2 | |||
| else: | |||
| out = out + self.param | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(3), dtype=ms.int32) | |||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_while_with_param_grad_not_enter_while(): | |||
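| """Gradient when the while condition is false on entry, so the loop body never executes.""" | |||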
| class MyWhileNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, idx, end, x): | |||
| out = self.zero | |||
| while idx < end: | |||
| out = out + self.param * 3 | |||
| idx = idx + 1 | |||
| return out + self.param | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, a, b, c): | |||
| return C.grad_by_list(self.net, self.weights)(a, b, c) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| while_net = MyWhileNet() | |||
| net = GradNet(while_net) | |||
| idx = Tensor(np.array(3), dtype=ms.int32) | |||
| end = Tensor(np.array(0), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_with_param_if_by_if_forward(): | |||
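| """Forward pass through two consecutive if/else branches that read a Parameter.""" | |||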
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, a, b, x): | |||
| out = self.zero | |||
| if a < b: | |||
| out = out + x + self.param | |||
| else: | |||
| out = out + x | |||
| if a == b: | |||
| out = out + x * 3 + self.param | |||
| else: | |||
| out = out + x * 2 | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = if_net | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(4), dtype=ms.int32) | |||
| x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_with_param_if_by_if_grad_inputs(): | |||
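| """Gradient w.r.t. all inputs through two consecutive if branches.""" | |||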
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, a, b, x): | |||
| out = self.zero | |||
| if a < b: | |||
| out = out + x + self.param * 4 | |||
| if a == b: | |||
| out = out + x * 3 + self.param * 3 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| def construct(self, *inputs): | |||
| return C.grad_all(self.net)(*inputs) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(0), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_with_param_if_by_if_grad_parameter(): | |||
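| """Gradient w.r.t. the Parameter through two consecutive if branches.""" | |||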
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, a, b, x): | |||
| out = self.zero | |||
| if a < b: | |||
| out = out + x + self.param * 2 | |||
| if a == b: | |||
| out = out + x * 3 + self.param | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||
| end = Tensor(np.array(2), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_with_param_if_by_if_grad_param_execute_null(): | |||
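| """Gradient w.r.t. the Parameter when the single if branch is never taken.""" | |||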
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, a, b, x): | |||
| out = self.zero | |||
| if a < b: | |||
| out = out + x + self.param * 2 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| idx = Tensor(np.array(4), dtype=ms.int32) | |||
| end = Tensor(np.array(0), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_if_by_if_return_inside_grad(): | |||
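| """Gradient when every if branch returns directly from inside construct.""" | |||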
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.max = P.ReduceMax() | |||
| self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") | |||
| self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32) | |||
| def construct(self, a, b, x): | |||
| out = self.zero | |||
| if a < b: | |||
| return out + x + self.param | |||
| if a == b: | |||
| return out + self.param * 2 | |||
| return out + self.param * 3 | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, *inputs): | |||
| return C.grad_by_list(self.net, self.weights)(*inputs) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = GradNet(if_net) | |||
| idx = Tensor(np.array(1), dtype=ms.int32) | |||
| end = Tensor(np.array(0), dtype=ms.int32) | |||
| x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| def test_if_by_if_forward(): | |||
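| """Forward pass chaining three if/else branches built from elementary arithmetic ops.""" | |||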
| class MyIfByIfNet(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.add = P.TensorAdd() | |||
| self.sub = P.Sub() | |||
| self.mul = P.Mul() | |||
| self.div = P.RealDiv() | |||
| def construct(self, a, b, x): | |||
| if a < b: | |||
| a = self.add(a, b) | |||
| else: | |||
| a = self.sub(a, b) | |||
| if a == x: | |||
| a = self.mul(a, b) | |||
| else: | |||
| a = self.div(a, b) | |||
| if b == x: | |||
| b = self.add(a, b) | |||
| else: | |||
| b = self.add(a, x) | |||
| a = a * b | |||
| out = a + b + x | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| if_net = MyIfByIfNet() | |||
| net = if_net | |||
| idx = Tensor(np.array(2), dtype=ms.float32) | |||
| end = Tensor(np.array(3), dtype=ms.float32) | |||
| x = Tensor(np.array(4), dtype=ms.float32) | |||
| net(idx, end, x) | |||
| @@ -58,6 +58,7 @@ add_subdirectory(serving) | |||
| file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "../../../mindspore/core/base/*.cc" | |||
| "../../../mindspore/core/gvar/*.cc" | |||
| "../../../mindspore/core/abstract/*.cc" | |||
| "../../../mindspore/core/ir/*.cc" | |||
| "../../../mindspore/core/utils/*.cc" | |||
| @@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ | |||
| import pipeline_for_compile_grad_ge_graph_for_case_by_case_config | |||
| class InputBackward(nn.Cell): | |||
| def __init__(self, network): | |||
| super(InputBackward, self).__init__() | |||
| @@ -13,7 +13,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| CURRPATH=$(cd $(dirname $0); pwd) | |||
| IGNORE_EXEC="--ignore=$CURRPATH/exec" | |||
| PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd) | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """Generate vm_impl function for array ops""" | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| @@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters | |||
| from .vm_interface import vm | |||
| # pylint: disable=unused-argument | |||
| @@ -181,8 +179,7 @@ def vm_impl_tile(self): | |||
| def vm_impl(x, multiples): | |||
| x = x.asnumpy() | |||
| multiples = multiples.asnumpy() | |||
| out = vm.Tile(x, multiples) | |||
| out = np.tile(x, multiples) | |||
| return Tensor(out) | |||
| return vm_impl | |||
| @@ -255,7 +252,10 @@ def vm_impl_sum(self): | |||
| def vm_impl(x, axis): | |||
| x = x.asnumpy() | |||
| out = vm.sum(x, axis) | |||
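| # In MindSpore, an empty axis tuple means reduce over all dimensions. | |||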
| if axis == (): | |||
| out = np.sum(x) | |||
| else: | |||
| out = np.sum(x, axis=axis) | |||
| return Tensor(np.array(out)) | |||
| return vm_impl | |||
| @@ -291,12 +291,14 @@ def vm_impl_square(self): | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.ZerosLike) | |||
| def vm_impl_zeros_like(self): | |||
| """Generate vm_impl function for ZerosLike""" | |||
| def vm_impl(x): | |||
| return Tensor(np.zeros_like(x.asnumpy())) | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.Partial) | |||
| def vm_impl_partial(self): | |||
| """Generate vm_impl function for Partial""" | |||
| @@ -307,6 +309,7 @@ def vm_impl_partial(self): | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.Depend) | |||
| def vm_impl_depend(self): | |||
| """Generate vm_impl function for Depend""" | |||
| @@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self): | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.ReduceMax) | |||
| def vm_impl_reduce_max(self): | |||
| """Generate vm_impl function for ReduceMean.""" | |||
| def vm_impl(x, axis): | |||
| x = x.asnumpy() | |||
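| # MindSpore uses an empty axis tuple to mean "reduce all"; NumPy expects None instead. | |||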
| if axis == (): | |||
| axis = None | |||
| out = np.amax(x, axis) | |||
| return Tensor(out) | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.Equal) | |||
| def vm_impl_equal(self): | |||