|
|
|
@@ -37,7 +37,6 @@ AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_ |
|
|
|
} else { |
|
|
|
ResetKernelInfo(new_cnode, UNKNOWN_KERNEL_TYPE); |
|
|
|
} |
|
|
|
|
|
|
|
func_graph->AddNode(new_cnode); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
@@ -287,6 +286,11 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
auto neg_mul_lambda = [&node, &x, &const_1]() -> AnfNodePtr { |
|
|
|
auto new_rhs = const_1.ValueNodeWithOprations(prim::kPrimNeg); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2) |
|
|
|
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda); |
|
|
|
// (x*C1)*C2 ==> x*(C1*C2) |
|
|
|
@@ -308,6 +312,9 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { |
|
|
|
// x*rsqrt(x) ==> sqrt(x) |
|
|
|
MATCH_REPLACE_LAMBDA_IF(node, x * PUnaryOperation(prim::kPrimRsqrt, y), rsqrt_merge_lambda_3, |
|
|
|
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node))); |
|
|
|
// Neg(x) * const | const * Neg(x) = x * (-const) |
|
|
|
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) * const_1, neg_mul_lambda); |
|
|
|
MATCH_REPLACE_LAMBDA(node, const_1 * PUnaryOperation(prim::kPrimNeg, x), neg_mul_lambda); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -375,6 +382,11 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, u.GetNode(node)}, node); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
auto neg_div_lambda = [&node, &x, &const_1]() -> AnfNodePtr { |
|
|
|
auto new_rhs = const_1.ValueNodeWithOprations(prim::kPrimNeg); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
// x/1 ==> x |
|
|
|
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x); |
|
|
|
MATCH_REPLACE(node, x / const_one, x); |
|
|
|
@@ -392,6 +404,8 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { |
|
|
|
!PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node))); |
|
|
|
// x / rsqrt(y) ==> x * sqrt(y) |
|
|
|
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda); |
|
|
|
// Neg(x) / const = x / (-const) |
|
|
|
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) / const_1, neg_div_lambda); |
|
|
|
// // x / const ==> x * (1/const) |
|
|
|
MATCH_REPLACE_LAMBDA(node, x / const_1, div_const); |
|
|
|
// (x/y) / (u/v) ==> (x*v) / (y*u) |
|
|
|
@@ -505,15 +519,17 @@ std::vector<std::pair<int, int>> GetUnmodifiedDim(const ShapeVector &a, const Sh |
|
|
|
if (i >= a.size() && j >= b.size()) { |
|
|
|
break; |
|
|
|
} |
|
|
|
patial_a *= a[i]; |
|
|
|
patial_b *= b[j]; |
|
|
|
if (i == j || patial_a == patial_b) { |
|
|
|
patial_a *= a[i]; |
|
|
|
patial_b *= b[j]; |
|
|
|
} |
|
|
|
if (patial_a == patial_b && a[i] == b[j]) { |
|
|
|
unmodified.emplace_back(std::make_pair(i, j)); |
|
|
|
++i; |
|
|
|
++j; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (patial_a < patial_b && b[j] > a[i]) { |
|
|
|
if (patial_a < patial_b) { |
|
|
|
++i; |
|
|
|
patial_a *= a[i]; |
|
|
|
if (patial_a == patial_b) { |
|
|
|
@@ -522,7 +538,7 @@ std::vector<std::pair<int, int>> GetUnmodifiedDim(const ShapeVector &a, const Sh |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (patial_a > patial_b && b[j] < a[i]) { |
|
|
|
if (patial_a > patial_b) { |
|
|
|
++j; |
|
|
|
patial_b *= b[j]; |
|
|
|
if (patial_a == patial_b) { |
|
|
|
@@ -541,7 +557,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
PatternNode<AnfNodePtr> x; |
|
|
|
auto trans_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto shape = GetNodeShape(node); |
|
|
|
if (shape.size() != 0 && shape.size() != 1) { |
|
|
|
return node; |
|
|
|
@@ -560,7 +576,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
}; |
|
|
|
auto reduce_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto tmp_node = node->cast<CNodePtr>(); |
|
|
|
auto arg_node = tmp_node->input(1); |
|
|
|
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis")); |
|
|
|
@@ -580,7 +596,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { |
|
|
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
auto reshape_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto reshape_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto tmp_node = node->cast<CNodePtr>(); |
|
|
|
auto arg_node = tmp_node->input(1); |
|
|
|
auto input_shape = GetNodeShape(arg_node->cast<CNodePtr>()->input(1)); |
|
|
|
@@ -621,17 +637,27 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
return node; |
|
|
|
}; |
|
|
|
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr { |
|
|
|
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node); |
|
|
|
AnfAlgo::CopyNodeAttr("axis", node, arg_node); |
|
|
|
AnfAlgo::CopyNodeAttr("keep_dims", node, arg_node); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; |
|
|
|
for (auto operation : ReduceOperations) { |
|
|
|
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lamda, |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda, |
|
|
|
operation); |
|
|
|
// Reduce(Reduce(A)) = Reduce(A) |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lamda, operation); |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lambda, operation); |
|
|
|
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lamda, |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda, |
|
|
|
operation); |
|
|
|
} |
|
|
|
// ReduceSum(Neg(x)) = Neg(ReduceSum(x)) |
|
|
|
MATCH_REPLACE_LAMBDA(node, PPrimitive(prim::kPrimReduceSum, PUnaryOperation(prim::kPrimNeg, x)), |
|
|
|
neg_reducesum_lambda); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -769,15 +795,18 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) { |
|
|
|
mng_sub = Manage(sub_graph, false); |
|
|
|
sub_graph->set_manager(mng_sub); |
|
|
|
} |
|
|
|
for (auto node_sub : sub_graph->GetOrderedCnodes()) { |
|
|
|
auto new_node = TrySimplify(node_sub); |
|
|
|
if (new_node != nullptr) { |
|
|
|
PERFORM_REPLACE(node_sub->cast<AnfNodePtr>(), new_node, sub_graph, replaced); |
|
|
|
bool need_traverse = true; |
|
|
|
while (need_traverse) { |
|
|
|
need_traverse = false; |
|
|
|
for (auto node_sub : sub_graph->GetOrderedCnodes()) { |
|
|
|
auto new_node = TrySimplify(node_sub); |
|
|
|
if (new_node != nullptr) { |
|
|
|
PERFORM_REPLACE(node_sub->cast<AnfNodePtr>(), new_node, sub_graph, replaced); |
|
|
|
need_traverse = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto new_node = TrySimplify(node); |
|
|
|
PERFORM_REPLACE(node->cast<AnfNodePtr>(), new_node, func_graph, replaced); |
|
|
|
} |
|
|
|
} |
|
|
|
EliminateEmptyGraph(func_graph); |
|
|
|
|