|
|
@@ -23,33 +23,33 @@ |
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace opt { |
|
|
namespace opt { |
|
|
const BaseRef RemoveReshapePair::DefinePattern() const { |
|
|
const BaseRef RemoveReshapePair::DefinePattern() const { |
|
|
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name()); |
|
|
|
|
|
VectorRef reshape({prim_reshape, input_varptr_}); |
|
|
|
|
|
|
|
|
|
|
|
return VectorRef({prim::kPrimReshape, reshape}); |
|
|
|
|
|
|
|
|
VarPtr X = std::make_shared<Var>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(X); |
|
|
|
|
|
return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
const EquivPtr &equiv) const { |
|
|
const EquivPtr &equiv) const { |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
MS_EXCEPTION_IF_NULL(equiv); |
|
|
auto manager = func_graph->manager(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
|
|
auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); |
|
|
auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); |
|
|
MS_EXCEPTION_IF_NULL(reshape_op_1); |
|
|
MS_EXCEPTION_IF_NULL(reshape_op_1); |
|
|
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly |
|
|
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly |
|
|
auto users = manager->node_users()[reshape_op_1]; |
|
|
|
|
|
if (users.size() > 1) { |
|
|
|
|
|
|
|
|
if (IsUsedByOthers(func_graph, reshape_op_1)) { |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); |
|
|
auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); |
|
|
MS_EXCEPTION_IF_NULL(reshape_op_2); |
|
|
MS_EXCEPTION_IF_NULL(reshape_op_2); |
|
|
users = manager->node_users()[reshape_op_2]; |
|
|
|
|
|
if (users.size() > 1) { |
|
|
|
|
|
|
|
|
if (IsUsedByOthers(func_graph, reshape_op_2)) { |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
auto input_node = reshape_op_2->input(1); |
|
|
|
|
|
return input_node; |
|
|
|
|
|
|
|
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); |
|
|
|
|
|
auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); |
|
|
|
|
|
if (input_shape == output_shape) { |
|
|
|
|
|
auto input_node = reshape_op_2->input(1); |
|
|
|
|
|
return input_node; |
|
|
|
|
|
} |
|
|
|
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
} // namespace opt |
|
|
} // namespace opt |
|
|
} // namespace mindspore |
|
|
} // namespace mindspore |