| @@ -69,7 +69,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // ops eliminate | // ops eliminate | ||||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | ||||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); | {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); | ||||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile); | |||||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); | |||||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | ||||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | ||||
| transpose_eliminate_ = | transpose_eliminate_ = | ||||
| @@ -29,8 +29,9 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| // check if node is value tuple and all one. e.g. (1, 1, 1) | // check if node is value tuple and all one. e.g. (1, 1, 1) | ||||
| // {PrimTile, X, MultiOne} | |||||
| class TileMultiplyByOne : public AnfVisitor { | |||||
| // {PrimTile, X, MultiOne} -> X | |||||
| // {PrimTile, X, Empty} -> X | |||||
| class TileEliminater : public AnfVisitor { | |||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| @@ -44,7 +45,7 @@ class TileMultiplyByOne : public AnfVisitor { | |||||
| auto value = GetValueNode(tuple_); | auto value = GetValueNode(tuple_); | ||||
| auto elements = GetValue<std::vector<int>>(value); | auto elements = GetValue<std::vector<int>>(value); | ||||
| if (elements.empty()) { | if (elements.empty()) { | ||||
| return nullptr; | |||||
| return x_; | |||||
| } | } | ||||
| auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; }); | auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; }); | ||||