Merge pull request !2276 from fary86/fix_eliminate_get_reftags/v0.5.0-beta
| @@ -81,6 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Ref eliminate | // Ref eliminate | ||||
| make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | ||||
| get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", | |||||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||||
| get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", | get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", | ||||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | ||||
| @@ -57,6 +57,7 @@ class OptimizeIRPassLib { | |||||
| // Ref eliminate | // Ref eliminate | ||||
| SubstitutionPtr make_ref_eliminate_; | SubstitutionPtr make_ref_eliminate_; | ||||
| SubstitutionPtr get_ref_param_eliminate_; | |||||
| SubstitutionPtr get_make_ref_eliminate_; | SubstitutionPtr get_make_ref_eliminate_; | ||||
| SubstitutionPtr replace_refkey_by_param_; | SubstitutionPtr replace_refkey_by_param_; | ||||
| SubstitutionPtr replace_old_param_; | SubstitutionPtr replace_old_param_; | ||||
| @@ -46,6 +46,26 @@ class MakeRefEliminater : public AnfVisitor { | |||||
| AnfNodePtr y_{nullptr}; | AnfNodePtr y_{nullptr}; | ||||
| }; | }; | ||||
| // {prim::kPrimGetRefValue, Parameter} -> Parameter | |||||
| // {prim::kPrimGetRefOrigin, Parameter} -> Parameter | |||||
| class GetRefParamEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| x_ = nullptr; | |||||
| AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node); | |||||
| if (x_ != nullptr) { | |||||
| return x_; | |||||
| } | |||||
| AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node); | |||||
| return x_; | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { x_ = node; } | |||||
| private: | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
| // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X | // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X | ||||
| // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y | // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y | ||||
| // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z | // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| #include "pipeline/resource.h" | #include "pipeline/resource.h" | ||||
| @@ -175,6 +176,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | ||||
| } | } | ||||
| } | } | ||||
| @@ -150,6 +150,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| irpass.get_ref_param_eliminate_, | |||||
| }); | }); | ||||
| OptPassGroupMap map({ | OptPassGroupMap map({ | ||||
| {"b_1", b_1}, | {"b_1", b_1}, | ||||
| @@ -0,0 +1,31 @@ | |||||
| import numpy as np | |||||
| from mindspore import context, nn, Tensor, Parameter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=False) | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, data): | |||||
| super(Net, self).__init__() | |||||
| self.start = Tensor(0, dtype=mstype.int32) | |||||
| self.end = Tensor(2, dtype=mstype.int32) | |||||
| self.max_output = Parameter(data, "output_x") | |||||
| self.upd = P.ScatterNdUpdate() | |||||
| self.zero = Tensor(np.ones([1], dtype=np.int32)) | |||||
| def construct(self, inputs): | |||||
| idx = self.start | |||||
| end = self.end | |||||
| while idx < end: | |||||
| xi = inputs[idx, :, :] | |||||
| self.upd(self.max_output, idx + self.zero, xi) | |||||
| idx = idx + 1 | |||||
| return self.max_output + 0 | |||||
| def test_x(): | |||||
| x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) | |||||
| net = Net(x) | |||||
| net(x) | |||||