diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 4d6edd18cb..4f8493ca7b 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra } oss << "SymInst(%para" << idx << ")"; } else { - MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString(); + MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString(); + oss << "SymInst(cnode_" << sym_node->ToString() << ")"; } return oss.str(); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index d4fe201710..452f1800f3 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { if (!morph->isa()) { return nullptr; } + // for free variable, which may be handled in MapValueObject, just return it + auto node_adjoint_found = anfnode_to_adjoin_.find(morph); + if (node_adjoint_found != anfnode_to_adjoin_.end()) { + return node_adjoint_found->second; + } ScopeGuard scope_guard(morph->scope()); auto cnode_morph = morph->cast(); @@ -502,7 +507,7 @@ void DFunctor::MapFvObject() { if (parent_adjoint != nullptr) { adjoint = std::make_shared(node, parent_adjoint->k(), tape_); } else { - if (is_top_ || node->isa() || !IsInScope(node)) { + if (is_top_ || node->isa()) { // Out of ad scope, add adjoint for free variables. adjoint = std::make_shared(node, node, tape_); UpdateAdjoint(adjoint); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index b41c3081b4..dfef764b8c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() { env_get_item_eliminate_ = MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_ = - MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_bypass_recursive_ = + MakeSubstitution(std::make_shared(true), "incorporate_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + incorporate_env_getitem_ = + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); // Ref eliminate make_ref_eliminate_ = @@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // inline inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + inline_without_move_ = MakeSubstitution(std::make_shared(false), "inline", IsCNodeGraph); replace_applicator_ = MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); specialize_transform_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 5a0f2ed5b7..9a9a1e7a74 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -55,6 +55,7 @@ class OptimizeIRPassLib { SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr new_env_get_item_; SubstitutionPtr incorporate_env_getitem_; + SubstitutionPtr 
incorporate_env_getitem_bypass_recursive_; SubstitutionPtr incorporate_env_getitem_switch_; // Ref eliminate @@ -80,6 +81,7 @@ class OptimizeIRPassLib { // inline SubstitutionPtr inline_; + SubstitutionPtr inline_without_move_; SubstitutionPtr replace_applicator_; SubstitutionPtr specialize_transform_; @@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) { auto inp0 = node->cast()->input(0); return (inp0 != nullptr) && inp0->isa(); } + +// check if the cnode is a switch cnode +inline bool IsCNodeSwitch(const AnfNodePtr &node) { + if (node != nullptr) { + if (node->isa()) { + return IsPrimitiveCNode(node, prim::kPrimSwitch); + } + } + return false; +} } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 1fee007a88..6fa1304dd7 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -29,6 +29,7 @@ #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/optimizer.h" #include "utils/symbolic.h" @@ -59,8 +60,13 @@ class EnvGetitemTransform { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; + if (inputs.size() != 4) { + MS_LOG(WARNING) << "Input size should be 4"; + return nullptr; + } + if (!IsValueNode(inputs[2])) { + MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; + return nullptr; } env = inputs[1]; @@ -91,33 +97,12 @@ class EnvGetitemTransform { class NewEnvGetItem : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; - }; - - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); - if (env_ != nullptr && env_->Len() == 0) { - return y_; - } + PatternNode c1, c2, y; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y, + (IsValueNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) && + (GetValueNode(c1.GetNode(node)))->Len() == 0)); return nullptr; } - - void Visit(const ValueNodePtr &vnode) override { - if (env_ == nullptr) { - env_ = GetValueNode(vnode); - } - } - - void Reset() { - y_ = nullptr; - env_ = nullptr; - } - - private: - AnfNodePtr y_{nullptr}; - EnvInstancePtr env_{nullptr}; }; // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> @@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; + if (inputs.size() != 4) { + MS_LOG(WARNING) << "Input size should be 4"; + return nullptr; + } + if (!IsValueNode(inputs[2])) { + MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; + return nullptr; } env = inputs[1]; @@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller { // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} class IncorporateEnvGetitem : public AnfVisitor { public: - 
IncorporateEnvGetitem() : env_get_item_transform_() {} + explicit IncorporateEnvGetitem(bool bypass_recursive = false) + : env_get_item_transform_(), bypass_recursive_(bypass_recursive) {} ~IncorporateEnvGetitem() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor { auto inputs = inp1->inputs(); auto fg = GetValueNode(inputs[0]); auto new_fg = env_get_item_transform_(fg, key, default_v); - + if (fg->recursive() && bypass_recursive_) { + MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString(); + return nullptr; + } + if (new_fg == nullptr) { + return nullptr; + } std::vector args; args.push_back(NewValueNode(new_fg)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); @@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor { private: bool is_match_{false}; internal::EnvGetitemTransform env_get_item_transform_; + bool bypass_recursive_; }; // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} @@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { auto g2 = GetValueNode(sw->input(3)); auto new_g1 = env_get_item_transform_(g1, key, default_v); auto new_g2 = env_get_item_transform_(g2, key, default_v); - + if (new_g1 == nullptr || new_g2 == nullptr) { + return nullptr; + } auto fg = node->func_graph(); auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index 0be228f44b..ebe4cbe5e5 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } +bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { + bool unique_use = IsUniqueUse(fg, nullptr); + bool is_recursive = fg->recursive(); + if (fg->parent() != nullptr && is_recursive) { + if (fg->parent() == node->func_graph() && unique_use) { + return true; + } + } + return false; +} + // {G, Xs} class InlinerBase : public AnfVisitor { public: - explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} + explicit InlinerBase(std::vector> criterions, bool use_move = true) + : use_move_(use_move), criterions_(criterions) {} ~InlinerBase() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa()) { @@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor { if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { return nullptr; } + // Do not inline GraphKernel to Cell. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { // If the GraphKernel only contains a return node, we make it inlined. @@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor { std::vector params; (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); - - if (IsUniqueUse(fg, nullptr)) { + // compare size to avoid the case that the function has default value after grad. 
+    // for which, after renormalize, the function default value will become an input
+    if (fg->parameters().size() != params.size()) {
+      return nullptr;
+    }
+    if (use_move_ && IsUniqueUse(fg, nullptr)) {
       auto mng = fg->manager();
       MS_EXCEPTION_IF_NULL(mng);
       ReplaceParams(mng, params, fg);
@@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
  private:
   bool is_checked_{false}, is_recursive_{false};
+  bool use_move_;
   std::vector> criterions_;
 };
 
 class Inliner : public InlinerBase {
  public:
-  Inliner()
-      : InlinerBase({
-          {IsUniqueUse, true},
-          {IsTrivial, false},
-          {IsInside, false},
-          {IsCore, false},
-          {NoCriterion, true},
-        }) {}
+  explicit Inliner(bool use_move = true)
+      : InlinerBase(
+          {
+            {IsUniqueUse, true},
+            {IsTrivial, false},
+            {IsInside, false},
+            {IsCore, false},
+            {IsDirectParentCall, false},
+            {NoCriterion, true},
+          },
+          use_move) {}
   ~Inliner() override = default;
 };
+
+class DirectInliner : public InlinerBase {
+ public:
+  explicit DirectInliner(bool use_move = true)
+      : InlinerBase(
+          {
+            {IsDirectParentCall, false},
+          },
+          use_move) {}
+  ~DirectInliner() override = default;
+};
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h
index fc859b213e..ab72c19475 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h
@@ -26,6 +26,30 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
+namespace internal {
+class GetRefValueTransform {
+ public:
+  GetRefValueTransform() {}
+  ~GetRefValueTransform() = default;
+
+  AnfNodePtr operator()(const AnfNodePtr &node) {
+    CNodePtr cnode = node->cast<CNodePtr>();
+    auto inputs = cnode->inputs();
+    auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
+    if (fg->recursive()) {
+      MS_LOG(DEBUG) << "GetRefValue bypassed for recursive fg:" << fg->ToString();
+      return node;
+    }
+    auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
+    auto output = new_fg->output();
+    new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
+    inputs[0] = NewValueNode(new_fg);
+    auto ret_node = cnode->func_graph()->NewCNode(inputs);
+    return ret_node;
+  }
+};
+}  // namespace internal
+
 // {prim::kPrimMakeRef, X, Y, Z} -> Y
 class MakeRefEliminater : public OptimizerCaller {
  public:
@@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
+// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
 class GetMakeRefEliminater : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     PatternNode<AnfNodePtr> x, y, z;
     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
     MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
-
+    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
+    internal::GetRefValueTransform trans;
+    auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
+      auto rep = trans(x.GetNode(node));
+      if (rep != nullptr) {
+        return rep;
+      }
+      return nullptr;
+    };
+    MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
     return nullptr;
   }
 };
diff --git a/mindspore/ccsrc/pipeline/jit/action.cc
b/mindspore/ccsrc/pipeline/jit/action.cc index c5b38fe829..c556b3399f 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); auto bc_ptr = res->results()[kBackend].cast(); auto context_ptr = MsContext::GetInstance(); + std::string backend = MsContext::GetInstance()->backend_policy(); MS_EXCEPTION_IF_NULL(context_ptr); if (CompileGraphs::ContainMixedTarget(func_graph)) { bc_ptr->set_is_multi_graph_sink(false); @@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) { context_ptr->set_loop_sink_flag(false); } else if (context_ptr->execution_mode() != kPynativeMode) { std::string device_target = context_ptr->device_target(); - if (device_target == kAscendDevice) { + if (device_target == kAscendDevice && backend != kMsVm) { bc_ptr->set_is_multi_graph_sink(true); context_ptr->set_is_multi_graph_sink(true); } } - if (IsCtrlSink()) { + if (IsCtrlSink() && backend == kMsConvert) { res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); return true; } @@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) { if (res->results().count(kOutput) == 0) { MS_LOG(EXCEPTION) << "Execute args error"; } - - if (IsCtrlSink()) { + std::string backend = MsContext::GetInstance()->backend_policy(); + if (IsCtrlSink() && backend == kMsConvert) { if (!res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 0c27ba7c48..c047e11335 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -30,6 +30,7 @@ #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/validator.h" +#include "pipeline/jit/remove_value_node_dup.h" #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/cse.h" #include "frontend/optimizer/graph_kernel_reuse.h" @@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_getitem_set_, irpass.incorporate_call_, irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_bypass_recursive_, irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, irpass.depend_value_elim_, }); + opt::OptPassConfig a_after_grad = opt::OptPassConfig({ + irpass.inline_without_move_, + }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.arithmetic_simplify2_, irpass.same_eliminate_, @@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { {"virtual_dataset", virtual_dataset}, {"grad", grad}, {"resolve", resolve_pass}, + {"a_after_grad", a_after_grad}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"cse", opt::OptPassConfig(opt::CSE(false))}, {"a_3", a_3}}); @@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { return map_a; } +OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig c_1 = opt::OptPassConfig({ + // Safe inlining + irpass.inline_, + irpass.partial_eliminate_, + }); + + OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); + + return map_a; +} + OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = - opt::OptPassConfig({irpass.zero_like_fill_zero_, 
irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); + opt::OptPassConfig b_1 = opt::OptPassConfig( + {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, + irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, + irpass.value_based_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, @@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); + g_pass_opts["opt_after_cconv"] = + Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); g_pass_opts["opt_graph_kernel_a"] = Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); g_pass_opts["opt_graph_kernel_b"] = @@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } @@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) { return true; } +bool MergeDupGraphPass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(res->manager()); + if (res->manager()->func_graphs().size() <= 1) { + return true; + } + return MergeDuplicateGraphs(res->manager()); +} + +bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "Remove value node duplications error."; + } + auto manager = res->manager(); + HashCache hash_cache; + HashValue hashes; + // Remove duplicated value nodes across all graphs in manager + for (auto &fg : manager->func_graphs()) { + auto value_nodes = fg->value_nodes(); + for (const auto &value_pair : value_nodes) { + TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); + } + } + return true; +} + bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -340,6 +388,8 @@ std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStru {"clean_after_opta", CleanAfterOptAPass}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}, + {"opt_after_cconv", OptPassAfterCconvGroup}, + {"remove_dup_value", RemoveValueNodeDuplicationsPass}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_control_depend", AddControlDependPass}}; diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc 
index e9467e4aeb..2d390c46a2 100644
--- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc
+++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <numeric>
 #include "pipeline/jit/remove_value_node_dup.h"
 #include "ir/anf.h"
@@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
   // Meet for the first time, append node to bucket.
   bucket.emplace_back(node);
 }
+
+size_t HashOfGraph(const FuncGraphPtr &fg) {
+  std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
+  MS_LOG(DEBUG) << "TopoSort for:" << fg->ToString();
+  std::unordered_map<AnfNodePtr, std::size_t> hashes;
+  auto &params = fg->parameters();
+  for (size_t i = 0; i < params.size(); i++) {
+    hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
+  }
+  for (auto node : toposet) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (hashes.find(node) != hashes.end()) {
+      continue;
+    }
+
+    std::size_t h = 0;
+    if (node->isa<ValueNode>()) {
+      ValueNodePtr value_node = node->cast<ValueNodePtr>();
+      auto value = value_node->value();
+      MS_EXCEPTION_IF_NULL(value);
+      if (IsValueNode<FuncGraph>(value_node)) {
+        auto v_fg = value->cast<FuncGraphPtr>();
+        h = value->hash();
+      } else if (IsValueNode<tensor::Tensor>(value_node)) {
+        // the tensor has same value has been replaced in duplicate value pass,
+        // so we use the value pointer here as an identifier
+        h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
+      } else {
+        h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
+      }
+    } else if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      auto &inputs = cnode->inputs();
+      size_t init = 0;
+      h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
+        return hash_combine(hash, hashes[node_in]);
+      });
+    } else if (node->isa<Parameter>()) {
+      h = node->hash();
+    } else {
+      MS_LOG(ERROR) << "Unknown node type";
+    }
+    hashes[node] = h;
+  }
+  return hashes[fg->get_return()];
+}
+
+bool IsCNodeGraph(const AnfNodePtr &node) {
+  if (node == nullptr || !node->isa<CNode>()) {
+    return false;
+  }
+
+  auto inp0 = node->cast<CNodePtr>()->input(0);
+  return IsValueNode<FuncGraph>(inp0);
+}
+
+bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
+  std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
+  std::unordered_map<FuncGraphPtr, size_t> graph_hash;
+  for (auto fg : manager->func_graphs()) {
+    size_t h = HashOfGraph(fg);
+    graph_hash[fg] = h;
+    if (hash_graphs.find(h) == hash_graphs.end()) {
+      hash_graphs[h] = {fg};
+    } else {
+      hash_graphs[h].push_back(fg);
+    }
+  }
+  FuncGraphPairMapEquiv equiv_graph;
+  NodeMapEquiv equiv_node;
+  for (auto &fg : manager->func_graphs()) {
+    MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
+    for (auto &item : fg->nodes()) {
+      if (!item->isa<CNode>()) {
+        continue;
+      }
+      auto &inputs = item->cast<CNodePtr>()->inputs();
+      for (size_t i = 0; i < inputs.size(); i++) {
+        if (!inputs[i]->isa<ValueNode>()) {
+          continue;
+        }
+        auto value_ptr = GetValueNode(inputs[i]);
+        auto v_fg = value_ptr->cast<FuncGraphPtr>();
+        if (v_fg == nullptr) {
+          continue;
+        }
+        auto &fg_vec = hash_graphs[graph_hash[v_fg]];
+        if (fg_vec.size() > 1) {
+          if (v_fg != fg_vec[0]) {
+            bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
+            if (is_morphic) {
+              auto new_node = NewValueNode(fg_vec[0]);
+              MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
+              manager->Replace(inputs[i], new_node);
+            }
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
+
 }  // namespace pipeline
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h
b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h index fd52924d58..39fcd4472b 100644 --- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h @@ -28,6 +28,10 @@ using HashCache = std::unordered_map>; using HashValue = std::unordered_map; void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); +size_t HashOfGraph(const FuncGraphPtr &fg); +bool IsCNodeGraph(const AnfNodePtr &node); +bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager); + } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 424a057bc3..f6ffda863b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr } const AnfNodePtr &func_node = fg->get_return(); - MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() + MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); AbstractBasePtr ret_base = nullptr; std::vector nodes = FastShadowSort(func_node); for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { const auto &node = *it; AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); - MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); + MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() + << ", node_conf: " << node_conf->ToString(); ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); - MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() - << ", abstract: " << ret_base->ToString(); + MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() + << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); } MS_EXCEPTION_IF_NULL(ret_base); @@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), [](const AbstractBasePtr &arg) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); + if (arg->GetValueTrack() != kAnyValue) { + return arg->Broaden(); + } + return arg; }); - if (func_graph_->joined_shapes_.size() != broaded_list.size()) { - MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size() - << " does not equal to number of original buffer arguments " - << func_graph_->joined_shapes_.size(); - } - for (size_t i = 0; i < broaded_list.size(); ++i) { - broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); + if (func_graph_->joined_shapes_.size() == broaded_list.size()) { + for (size_t i = 0; i < broaded_list.size(); ++i) { + broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); + } } + MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) << ", broaded: " << mindspore::ToString(broaded_list); return broaded_list; @@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); 
func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), - [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); + std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { + if (arg_spec->isa()) { + return arg_spec->cast()->ref()->GetShapeTrack(); + } + return arg_spec->GetShapeTrack(); + }); + joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } return joined_args_spec_list; @@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), - [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); + std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { + if (arg_spec->isa()) { + return arg_spec->cast()->ref()->GetShapeTrack(); + } + return arg_spec->GetShapeTrack(); + }); + joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 338743b1da..e5f9cdb6b2 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { trace::TraceEvalCNodeLeave(); } else { MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() + << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph") << ". 
NodeInfo: " << trace::GetDebugInfo(node->debug_info()); } @@ -301,6 +302,8 @@ void AnalysisEngine::Clear() { anfnode_config_map_.clear(); eval_trace_.clear(); constructors_.clear(); + constructors_app_.clear(); + continued_evals_.clear(); } namespace { @@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptrfn(); EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + auto part_pair = std::make_pair(func_orig, func->args()); + auto itr = constructors_app_.find(part_pair); + if (itr != constructors_app_.end()) { + return itr->second; + } std::shared_ptr partial_evaluator = std::make_shared(evaluator_orig, func->args()); + constructors_app_[part_pair] = partial_evaluator; return partial_evaluator; } @@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { if (fg_eval == nullptr) { return; } + auto fg = fg_eval->func_graph(); MS_EXCEPTION_IF_NULL(fg); - auto undetermined_fgs = fg->recursive_graphs(); + auto undetermined_fgs = fg->recursive(); if (undetermined_fgs) { auto fg_parent = fg->parent(); MS_EXCEPTION_IF_NULL(fg_parent); @@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vectorToString() << " check undetermined."; - if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined."; + auto &alternate_evaluator = multi_poss_[u_eval.first]; + auto &eval_cache = alternate_evaluator->cache(); + if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) && + (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) || + (eval_cache->find(args_spec_list) == eval_cache->end()))) { + MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined."; has_undetermined = true; break; } } if (has_undetermined == false) { - MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + MS_LOG(DEBUG) << eval->ToString() << "has no undetermined."; *continue_flag = true; return latest_entry; } @@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorToString(); - // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); if (it == eval_trace_.rend()) { eval_trace_.push_back(current_inf); - MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); MS_EXCEPTION_IF_NULL(eval); auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); out_specs.push_back(eval_result->abstract()); eval_trace_.pop_back(); if (eval_trace_.empty()) { multi_poss_.clear(); } - } else if (it != eval_trace_.rbegin()) { + } else { bool continue_flag = false; auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); if (continue_flag) { + MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString(); + continued_evals_.insert(current_inf); continue; } // Try to travel the latest undetermined. 
if (latest_entry != eval_trace_.rbegin()->first) { - MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); + MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString(); auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() + MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); return eval_result; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 0ebd9a0af4..7018932898 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -26,6 +26,7 @@ #include #include #include +#include #ifdef DEBUG #include @@ -113,7 +114,8 @@ class AnfNodeConfig : public Config { std::string ToString() const override { std::ostringstream buffer; - buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); + buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId() + << "), Context: " << context_->ToString(); return buffer.str(); } @@ -173,7 +175,13 @@ struct AnalysisResult { }; using EvalTraceRevIter = std::list>::reverse_iterator; - +struct PartialAppHasher { + std::size_t operator()(const std::pair &p) const { + auto h1 = std::hash{}(p.first); + auto h2 = AbstractBasePtrListHash(p.second); + return h1 ^ h2; + } +}; class AnalysisEngine : public std::enable_shared_from_this { public: AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) @@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this { const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; std::unordered_map constructors_; + std::unordered_map, EvaluatorPtr, PartialAppHasher> + constructors_app_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; std::map multi_poss_; + std::set> continued_evals_; AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, const ConfigPtrList &args_conf_list); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 9655f7a659..95a54eebb2 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractRef; using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractSparseTensor; @@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) { // only send string in external if (!IsValueNode(node)) { // Validate a type. 
- MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() + << " for node=" << node->DebugString(); } } return; @@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) { if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa()) { + ptrBase->isa() || ptrBase->isa() || ptrBase->isa()) { return; } diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 70590da753..0d3ab3b651 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple } // Isomorphism -static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, - NodeMapEquiv *const equiv_node) { +static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node); +bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { if (equiv_node == nullptr) { MS_LOG(ERROR) << "Invalid equiv_node"; return false; @@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu MS_LOG(DEBUG) << "two parameters are not equal."; return false; } + if (node1->isa() && node2->isa()) { + return SameNode(node1, node2, equiv_func_graph, equiv_node); + } MS_LOG(ERROR) << "type error"; return false; } diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 141adc1bff..33197876a7 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo } // namespace std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { - auto fg = std::make_shared(); - AnfNodePtrList inputs; - AnfNodePtrToAnfNodePtrMap eqv; if (lst.empty()) { MS_LOG(EXCEPTION) << "Input anf node list is empty"; } + TraceManager::DebugTrace( + std::make_shared(lst[0]->cast()->func_graph()->debug_info())); + auto fg = std::make_shared(); + TraceManager::EndTrace(); + AnfNodePtrList inputs; + AnfNodePtrToAnfNodePtrMap eqv; // Merge CNodes into a AnfGraph that represents a linear instruction segment for (auto n : lst) { if (!n->isa()) { @@ -154,7 +157,9 @@ std::tuple TransformSegmentToAnfGr (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); } + TraceManager::DebugTrace(std::make_shared(n->debug_info())); eqv[n] = fg->NewCNode(args); + TraceManager::EndTrace(); eqv[n]->set_abstract(n->abstract()); eqv[n]->set_kernel_info(n->kernel_info_ptr()); } diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index dab262bc89..0fb6759d95 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { } auto other_tensor = dyn_cast(other); if (other_tensor == nullptr) { + auto ref_tensor = dyn_cast(other); + if (ref_tensor != nullptr) { + return this->Join(ref_tensor->ref()); + } MS_LOG(EXCEPTION) << "Join failed as type 
mismatch, this: " << ToString() << ", other: " << other->ToString(); } if (*this == *other) { diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index cde5eaafba..ccdf8ee1d7 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -48,7 +48,7 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c continue; } if (rank.find(node) != rank.end() && rank[node] != todo.size()) { - MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(); + MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); } rank[node] = todo.size(); bool cont = false; diff --git a/mindspore/core/ir/scalar.h b/mindspore/core/ir/scalar.h index b814a4781d..62c5f35ba5 100644 --- a/mindspore/core/ir/scalar.h +++ b/mindspore/core/ir/scalar.h @@ -30,6 +30,7 @@ #include "base/base.h" #include "ir/dtype.h" #include "ir/dtype/number.h" +#include "utils/hashing.h" using std::fabs; @@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr; class BoolImm : public Scalar { public: - explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash{}(v_); } + explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~BoolImm() override = default; MS_DECLARE_PARENT(BoolImm, Scalar) std::size_t hash() const override { return hash_; } @@ -91,7 +92,7 @@ class IntergerImm : public Scalar { class Int8Imm : public IntergerImm { public: Int8Imm() : IntergerImm(kInt8), v_(0) {} - explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash{}(v_); } + explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int8Imm() override = default; MS_DECLARE_PARENT(Int8Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t) class Int16Imm : public IntergerImm { public: Int16Imm() : IntergerImm(kInt16), v_(0) {} - explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash{}(v_); } + explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int16Imm() override = default; MS_DECLARE_PARENT(Int16Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t) class Int32Imm : public IntergerImm { public: Int32Imm() : IntergerImm(kInt32), v_(0) {} - explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash{}(v_); } + explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int32Imm() override = default; MS_DECLARE_PARENT(Int32Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t) class Int64Imm : public IntergerImm { public: Int64Imm() : IntergerImm(kInt64), v_(0) {} - explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash{}(v_); } + explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int64Imm() override = default; MS_DECLARE_PARENT(Int64Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t) class UInt8Imm : public IntergerImm { public: UInt8Imm() : IntergerImm(kUInt8), v_(0) {} - explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash{}(v_); } + explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { + hash_ = hash_combine({tid(), std::hash{}(v_)}); + } 
  ~UInt8Imm() override = default;
  MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
  std::size_t hash() const override { return hash_; }
@@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
 class UInt16Imm : public IntergerImm {
  public:
   UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
-  explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<uint16_t>{}(v_); }
+  explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
+    hash_ = hash_combine({tid(), std::hash<uint16_t>{}(v_)});
+  }
   ~UInt16Imm() override = default;
   MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
   std::size_t hash() const override { return hash_; }
@@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
 class UInt32Imm : public IntergerImm {
  public:
   UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
-  explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<uint32_t>{}(v_); }
+  explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
+    hash_ = hash_combine({tid(), std::hash<uint32_t>{}(v_)});
+  }
   ~UInt32Imm() override = default;
   MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
   std::size_t hash() const override { return hash_; }
@@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
 class UInt64Imm : public IntergerImm {
  public:
   UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
-  explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); }
+  explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
+    hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
+  }
   ~UInt64Imm() override = default;
   MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
   std::size_t hash() const override { return hash_; }
@@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
 class FP32Imm : public FloatImm {
  public:
   FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
-  explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); }
+  explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
   ~FP32Imm() override = default;
   MS_DECLARE_PARENT(FP32Imm, FloatImm)
   std::size_t hash() const override { return hash_; }
@@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
 class FP64Imm : public FloatImm {
  public:
   FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
-  explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); }
+  explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
   ~FP64Imm() override = default;
   MS_DECLARE_PARENT(FP64Imm, FloatImm)
   std::size_t hash() const override { return hash_; }
diff --git a/mindspore/core/utils/trace_info.h b/mindspore/core/utils/trace_info.h
index fea2cb3ea8..5c9160d7c1 100644
--- a/mindspore/core/utils/trace_info.h
+++ b/mindspore/core/utils/trace_info.h
@@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
     return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
   }
 };
+
+class TraceSegmentTransform : public TraceInfo {
+ public:
+  explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
+  MS_DECLARE_PARENT(TraceSegmentTransform, TraceInfo);
+  ~TraceSegmentTransform() override = default;
+  TraceInfoPtr clone() override {
+    return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
+  }
+};
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_TRACE_INFO_H_
diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py
new file mode 100644
index 0000000000..c3baae1fb5
--- /dev/null
+++ b/tests/st/control/test_cont_grad.py
@@ -0,0 +1,816 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+#
you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test control ops """ +import numpy as np + +from mindspore import dtype as ms +from mindspore import Tensor +from mindspore import context +from mindspore import nn +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.ops import composite as C +from mindspore.ops import operations as P +# from tests.vm_impl.math_ops_vm_impl import * +# from tests.vm_impl.vm_interface import * +# from tests.vm_impl import * +# context.set_context(save_graphs=True) + + +def test_while_forward(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + + def construct(self, idx, end, x): + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + idx = idx + 1 + return x + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + + def construct(self, idx, end, x): + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + idx = idx + 1 + return x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_forward(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + out = out + x + self.param + idx = idx + 1 + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_endless_case(): + """endless case when optmization""" + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), 
ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                part = x[idx, :, :]
+                out = out + part
+                idx = idx + 1
+            return out
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    net = MyWhileNet()
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(2), dtype=ms.int32)
+    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_grad():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                part = x[idx, :, :]
+                max_num = self.max(part)
+                x[idx, :, 0:2] = max_num
+                out = out + x + self.param
+                idx = idx + 1
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(2), dtype=ms.int32)
+    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_forward_with_const_branch():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                if 2 > 1:
+                    out = out + self.param
+                else:
+                    out = out + idx + self.param
+                idx = idx + 1
+            return out
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = while_net
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_opt_endless():
+    """endless during optimization case"""
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+            self.addn = P.AddN()
+
+        def construct(self, idx, end, x):
+            addn1 = self.addn((x, x, x))
+            out = addn1
+            while idx < end:
+                out = self.addn((out, addn1))
+                idx = idx + 1
+            out = self.addn((out, x))
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_all(self.net)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_no_while_call():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            if 2 > 1:
+                out = out + self.param
+            else:
+                out = out + idx + self.param
+            return out
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = while_net
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_grad_with_const_branch():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                if 2 > 1:
+                    out = out + self.param
+                else:
+                    out = out + idx + self.param
+                idx = idx + 1
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_for_while_with_param_grad_with_const_branch():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+            self.start = Tensor(np.array(0), dtype=ms.int32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            for _ in range(0, 2):
+                idx = self.start
+                while idx < end:
+                    if 2 > 1:
+                        out = out + self.param
+                    else:
+                        out = out + idx + self.param
+                    idx = idx + 1
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_for_while_with_param_grad_basic():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+            self.start = Tensor(np.array(0), dtype=ms.int32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            for _ in range(0, 2):
+                idx = self.start
+                while idx < end:
+                    out = out + self.param
+                    idx = idx + 1
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_for_while_with_param_grad_normal():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.reduce = P.ReduceSum()
+            self.start = Tensor(np.array(0), dtype=ms.int32)
+
+        def construct(self, idx, end, x):
+            out = x
+            for _ in range(0, 2):
+                idx = self.start
+                while idx < end:
+                    out = out + self.param
+                    idx = idx + 1
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_basic_grad():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.t2 = Tensor(np.array(2), dtype=ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                out = out + self.param
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(3), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_basic_grad_mul():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
+            self.t2 = Tensor(np.array(2), dtype=ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                out = out * self.param
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(3), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_basic_grad_two():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.t2 = Tensor(np.array(2), dtype=ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                out = out + self.param + self.weight
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(3), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_basic_grad_three():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
+            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.t2 = Tensor(np.array(2), dtype=ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                out = out + self.param + self.weight + self.key
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(3), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_if_with_param_grad():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+            self.t2 = Tensor(np.array(2), dtype=ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                if self.max(out) < self.max(x):
+                    out = out + self.param * 2
+                else:
+                    out = out + self.param
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(3), dtype=ms.int32)
+    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_while_with_param_grad_not_enter_while():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, idx, end, x):
+            out = self.zero
+            while idx < end:
+                out = out + self.param * 3
+                idx = idx + 1
+            return out + self.param
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, a, b, c):
+            return C.grad_by_list(self.net, self.weights)(a, b, c)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor(np.array(3), dtype=ms.int32)
+    end = Tensor(np.array(0), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_with_param_if_by_if_forward():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, a, b, x):
+            out = self.zero
+            if a < b:
+                out = out + x + self.param
+            else:
+                out = out + x
+            if a == b:
+                out = out + x*3 + self.param
+            else:
+                out = out + x*2
+            return out
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = if_net
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(4), dtype=ms.int32)
+    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_with_param_if_by_if_grad_inputs():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, a, b, x):
+            out = self.zero
+            if a < b:
+                out = out + x + self.param * 4
+            if a == b:
+                out = out + x*3 + self.param * 3
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_all(self.net)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = GradNet(if_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(0), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_with_param_if_by_if_grad_parameter():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, a, b, x):
+            out = self.zero
+            if a < b:
+                out = out + x + self.param * 2
+            if a == b:
+                out = out + x*3 + self.param
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            return C.grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = GradNet(if_net)
+    idx = Tensor(np.array(0), dtype=ms.int32)
+    end = Tensor(np.array(2), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_with_param_if_by_if_grad_param_excute_null():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, a, b, x):
+            out = self.zero
+            if a < b:
+                out = out + x + self.param * 2
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            return C.grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = GradNet(if_net)
+    idx = Tensor(np.array(4), dtype=ms.int32)
+    end = Tensor(np.array(0), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_if_by_if_return_inside_grad():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.max = P.ReduceMax()
+            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
+            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
+
+        def construct(self, a, b, x):
+            out = self.zero
+            if a < b:
+                return out + x + self.param
+            if a == b:
+                return out + self.param * 2
+            return out + self.param * 3
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            return C.grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = GradNet(if_net)
+    idx = Tensor(np.array(1), dtype=ms.int32)
+    end = Tensor(np.array(0), dtype=ms.int32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    net(idx, end, x)
+
+
+def test_if_by_if_forward():
+    class MyIfByIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+            self.sub = P.Sub()
+            self.mul = P.Mul()
+            self.div = P.RealDiv()
+
+        def construct(self, a, b, x):
+            if a < b:
+                a = self.add(a, b)
+            else:
+                a = self.sub(a, b)
+            if a == x:
+                a = self.mul(a, b)
+            else:
+                a = self.div(a, b)
+            if b == x:
+                b = self.add(a, b)
+            else:
+                b = self.add(a, x)
+            a = a * b
+            out = a + b + x
+            return out
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+    if_net = MyIfByIfNet()
+    net = if_net
+    idx = Tensor(np.array(2), dtype=ms.float32)
+    end = Tensor(np.array(3), dtype=ms.float32)
+    x = Tensor(np.array(4), dtype=ms.float32)
+    net(idx, end, x)
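Editorial note, not part of the patch: every gradient test above uses the same small wrapper cell, so it is worth spelling out the idiom once. C.grad_by_list differentiates the wrapped network with respect to the Parameters gathered in a ParameterTuple, while C.grad_all (used in test_while_opt_endless and test_with_param_if_by_if_grad_inputs) differentiates with respect to the call inputs. Below is a minimal sketch of the weight-gradient wrapper, assuming the same imports the test module already uses (numpy as np, mindspore as ms, nn, Tensor, Parameter, ParameterTuple, composite as C); SimpleNet and GradWrtWeights are illustrative names only.

    class SimpleNet(nn.Cell):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.param = Parameter(Tensor(np.ones((2, 2)), ms.float32), name="weight")

        def construct(self, x):
            return x * self.param


    class GradWrtWeights(nn.Cell):
        def __init__(self, net):
            super(GradWrtWeights, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x):
            # Returns one gradient per Parameter in self.weights,
            # not gradients of the call inputs.
            return C.grad_by_list(self.net, self.weights)(x)

    grads = GradWrtWeights(SimpleNet())(Tensor(np.ones((2, 2)), ms.float32))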
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index e43a8272ca..71c0f39d36 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -58,6 +58,7 @@ add_subdirectory(serving)
 file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         "../../../mindspore/core/base/*.cc"
+        "../../../mindspore/core/gvar/*.cc"
         "../../../mindspore/core/abstract/*.cc"
         "../../../mindspore/core/ir/*.cc"
         "../../../mindspore/core/utils/*.cc"
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 2eb3584c33..4f3c3302bb 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
 from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
     import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
 
-
 class InputBackward(nn.Cell):
     def __init__(self, network):
         super(InputBackward, self).__init__()
diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh
index 6108e0f475..1a687b9b35 100755
--- a/tests/ut/python/runtest.sh
+++ b/tests/ut/python/runtest.sh
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-
 CURRPATH=$(cd $(dirname $0); pwd)
 IGNORE_EXEC="--ignore=$CURRPATH/exec"
 PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd)
diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py
index 921d5c5182..9f54533213 100644
--- a/tests/vm_impl/array_ops_vm_impl.py
+++ b/tests/vm_impl/array_ops_vm_impl.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """Generate vm_impl function for array ops"""
 import numpy as np
-
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
@@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
 from .vm_interface import vm
-
 
 # pylint: disable=unused-argument
@@ -181,8 +179,7 @@ def vm_impl_tile(self):
     def vm_impl(x, multiples):
         x = x.asnumpy()
-        multiples = multiples.asnumpy()
-        out = vm.Tile(x, multiples)
+        out = np.tile(x, multiples)
         return Tensor(out)
 
     return vm_impl
@@ -255,7 +252,10 @@ def vm_impl_sum(self):
     def vm_impl(x, axis):
         x = x.asnumpy()
-        out = vm.sum(x, axis)
+        if axis == ():
+            out = np.sum(x)
+        else:
+            out = np.sum(x, axis=axis)
         return Tensor(np.array(out))
 
     return vm_impl
@@ -291,12 +291,14 @@ def vm_impl_square(self):
     return vm_impl
 
+
 @vm_impl_getters.register(P.ZerosLike)
 def vm_impl_zeros_like(self):
     """Generate vm_impl function for ZerosLike"""
 
     def vm_impl(x):
         return Tensor(np.zeros_like(x.asnumpy()))
 
+
 @vm_impl_getters.register(P.Partial)
 def vm_impl_partial(self):
     """Generate vm_impl function for Partial"""
@@ -307,6 +309,7 @@ def vm_impl_partial(self):
     return vm_impl
 
+
 @vm_impl_getters.register(P.Depend)
 def vm_impl_depend(self):
     """Generate vm_impl function for Depend"""
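Editorial note, not part of the patch: the reduce-op shims above (ReduceSum) and below (ReduceMax) special-case an empty axis tuple because in these MindSpore reduce ops axis=() means "reduce every dimension", whereas NumPy treats axis=() as "reduce over no axes". The hunks therefore translate () into a full reduction before delegating to NumPy. A small NumPy-only illustration of the mismatch:

    import numpy as np

    x = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
    print(np.sum(x, axis=()).shape)  # (2, 2, 2): NumPy reduces nothing for an empty axis tuple
    print(np.sum(x))                 # 28.0: full reduction, what the ReduceSum shim returns
    print(np.amax(x, axis=None))     # 7.0: the ReduceMax shim maps () to None the same way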
diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py
index d409616436..9a614c9c92 100644
--- a/tests/vm_impl/math_ops_vm_impl.py
+++ b/tests/vm_impl/math_ops_vm_impl.py
@@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self):
     return vm_impl
 
+@vm_impl_getters.register(P.ReduceMax)
+def vm_impl_reduce_max(self):
+    """Generate vm_impl function for ReduceMax."""
+
+    def vm_impl(x, axis):
+        x = x.asnumpy()
+        if axis == ():
+            axis = None
+        out = np.amax(x, axis)
+        return Tensor(out)
+
+    return vm_impl
 
 @vm_impl_getters.register(P.Equal)
 def vm_impl_equal(self):