|
|
|
@@ -20,22 +20,21 @@ |
|
|
|
#include <vector> |
|
|
|
#include <algorithm> |
|
|
|
|
|
|
|
#include "optimizer/optimizer.h" |
|
|
|
#include "optimizer/irpass.h" |
|
|
|
#include "ir/visitor.h" |
|
|
|
#include "ir/func_graph.h" |
|
|
|
#include "ir/func_graph_cloner.h" |
|
|
|
#include "operator/ops.h" |
|
|
|
#include "ir/optimizer_caller.h" |
|
|
|
#include "ir/pattern_matcher.h" |
|
|
|
#include "operator/ops.h" |
|
|
|
#include "optimizer/irpass.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
// {prim::kPrimSwitch, true, X, Y} |
|
|
|
// {prim::kPrimSwitch, false, X, Y} |
|
|
|
class SwitchSimplify { |
|
|
|
class SwitchSimplify : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br; |
|
|
|
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { |
|
|
|
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node))); |
|
|
|
@@ -54,9 +53,9 @@ class SwitchSimplify { |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => |
|
|
|
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} |
|
|
|
class FloatTupleGetItemSwitch { |
|
|
|
class FloatTupleGetItemSwitch : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x; |
|
|
|
MATCH_REPLACE_IF(node, |
|
|
|
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), |
|
|
|
@@ -69,9 +68,9 @@ class FloatTupleGetItemSwitch { |
|
|
|
|
|
|
|
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => |
|
|
|
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} |
|
|
|
class FloatEnvGetItemSwitch { |
|
|
|
class FloatEnvGetItemSwitch : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2; |
|
|
|
MATCH_REPLACE_IF(node, |
|
|
|
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), |
|
|
|
@@ -93,9 +92,9 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN |
|
|
|
} // namespace internal |
|
|
|
|
|
|
|
// {{prim::kPrimSwitch, X, G1, G2}, Xs} |
|
|
|
class ConvertSwitchReplacement { |
|
|
|
class ConvertSwitchReplacement : public OptimizerCaller { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
if (!node->isa<CNode>() || node->func_graph() == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|