|
|
|
@@ -14,18 +14,18 @@ |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h" |
|
|
|
|
|
|
|
#include <list> |
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" |
|
|
|
#include "backend/kernel_compiler/common_utils.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "ir/pattern_matcher.h" |
|
|
|
#include "frontend/operator/ops.h" |
|
|
|
#include "ir/pattern_matcher.h" |
|
|
|
#include "utils/convert_utils.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
|
|
|
|
AnfNodePtr NewCNodeWithInfo(const AnfNodePtrList &inputs, const AnfNodePtr &ori_node) { |
|
|
|
auto func_graph = ori_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -401,10 +401,236 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { |
|
|
|
(FLAG) = true; \ |
|
|
|
} |
|
|
|
|
|
|
|
bool TryTransposeToReshape(const AnfNodePtr &node) { |
|
|
|
auto perm = AnfAlgo::GetNodeAttr<std::vector<int>>(node, "perm"); |
|
|
|
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); |
|
|
|
std::vector<int> remove_one_perm; |
|
|
|
for (auto idx : perm) { |
|
|
|
if (idx < 0 || IntToSize(idx) >= ori_shape.size()) { |
|
|
|
MS_EXCEPTION(ValueError); |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (ori_shape[idx] != 1) { |
|
|
|
remove_one_perm.emplace_back(idx); |
|
|
|
} |
|
|
|
} |
|
|
|
if (remove_one_perm.size() < 2) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
for (size_t idx = 1; idx < remove_one_perm.size(); idx++) { |
|
|
|
if (remove_one_perm[idx] < remove_one_perm[idx - 1]) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr SimplifyTranspose(const AnfNodePtr &node) { |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimTranspose)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (TryTransposeToReshape(node)) { |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReshape), node->cast<CNodePtr>()->input(1)}, node); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr SimplifyMatMul(const AnfNodePtr &node) { |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
PatternNode<AnfNodePtr> x, y; |
|
|
|
auto matmul_transpose_lambda = [&node, &x, &y]() -> AnfNodePtr { |
|
|
|
auto new_matmul = NewCNodeWithInfo({NewValueNode(prim::kPrimMatMul), y.GetNode(node), x.GetNode(node)}, node); |
|
|
|
auto new_abstract = node->abstract()->Clone(); |
|
|
|
auto ori_shape = node->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>(); |
|
|
|
auto shape_value = ori_shape->shape(); |
|
|
|
ShapeVector new_shape_value; |
|
|
|
std::copy(shape_value.rbegin(), shape_value.rend(), std::back_inserter(new_shape_value)); |
|
|
|
auto new_shape = std::make_shared<abstract::Shape>(new_shape_value); |
|
|
|
new_abstract->set_shape(new_shape); |
|
|
|
new_matmul->set_abstract(new_abstract); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimTranspose), new_matmul}, node); |
|
|
|
auto transpose_a = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_a"); |
|
|
|
auto transpose_b = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_b"); |
|
|
|
auto transpose_x1 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x1"); |
|
|
|
auto transpose_x2 = AnfAlgo::GetNodeAttr<ValuePtr>(node, "transpose_x2"); |
|
|
|
auto perm = AnfAlgo::GetNodeAttr<ValuePtr>(node->cast<CNodePtr>()->input(1), "perm"); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_a", transpose_b, new_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_b", transpose_a, new_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x1", transpose_x2, new_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("transpose_x2", transpose_x1, new_matmul); |
|
|
|
AnfAlgo::SetNodeAttr("perm", perm, new_cnode); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
// MatMul(Transpose(x), Transpose(y)) ==> Transpose(MatMul(y, x)) |
|
|
|
MATCH_REPLACE_LAMBDA(node, |
|
|
|
PBinOperation(prim::kPrimMatMul, PUnaryOperation(prim::kPrimTranspose, x), |
|
|
|
PUnaryOperation(prim::kPrimTranspose, y), false), |
|
|
|
matmul_transpose_lambda); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
ShapeVector TransAxisValueToVector(const ValuePtr &value) { |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
ShapeVector axis_vector; |
|
|
|
if (value->isa<Int32Imm>()) { |
|
|
|
axis_vector.emplace_back(GetValue<int>(value)); |
|
|
|
} |
|
|
|
if (value->isa<ValueTuple>() || value->isa<ValueList>()) { |
|
|
|
axis_vector = GetValue<std::vector<int>>(value); |
|
|
|
} |
|
|
|
return axis_vector; |
|
|
|
} |
|
|
|
|
|
|
|
ShapeVector GetNodeShape(const AnfNodePtr &node) { |
|
|
|
auto base_shape = node->Shape()->cast<abstract::ShapePtr>(); |
|
|
|
std::vector<int> shape; |
|
|
|
std::transform(base_shape->shape().begin(), base_shape->shape().end(), std::back_inserter(shape), IntToSize); |
|
|
|
return shape; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::pair<int, int>> GetUnmodifiedDim(const ShapeVector &a, const ShapeVector &b) { |
|
|
|
std::vector<std::pair<int, int>> unmodified; |
|
|
|
for (size_t i = 0, j = 0, patial_a = 1, patial_b = 1;;) { |
|
|
|
if (i >= a.size() && j >= b.size()) { |
|
|
|
break; |
|
|
|
} |
|
|
|
patial_a *= a[i]; |
|
|
|
patial_b *= b[j]; |
|
|
|
if (patial_a == patial_b && a[i] == b[j]) { |
|
|
|
unmodified.emplace_back(std::make_pair(i, j)); |
|
|
|
++i; |
|
|
|
++j; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (patial_a < patial_b && b[j] > a[i]) { |
|
|
|
++i; |
|
|
|
patial_a *= a[i]; |
|
|
|
if (patial_a == patial_b) { |
|
|
|
++i; |
|
|
|
++j; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (patial_a > patial_b && b[j] < a[i]) { |
|
|
|
++j; |
|
|
|
patial_b *= b[j]; |
|
|
|
if (patial_a == patial_b) { |
|
|
|
++i; |
|
|
|
++j; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
return unmodified; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { |
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimReduceMax) && !IsPrimitiveCNode(node, prim::kPrimReduceMin) && |
|
|
|
!IsPrimitiveCNode(node, prim::kPrimReduceSum)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
PatternNode<AnfNodePtr> x; |
|
|
|
auto trans_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto shape = GetNodeShape(node); |
|
|
|
if (shape.size() != 0 && shape.size() != 1) { |
|
|
|
return node; |
|
|
|
} else { |
|
|
|
auto tmp_node = node->cast<CNodePtr>(); |
|
|
|
auto transpose_node = tmp_node->input(1); |
|
|
|
auto transpose_dimensions = GetValue<std::vector<int>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm")); |
|
|
|
ShapeVector new_dimensions; |
|
|
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis")); |
|
|
|
std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions), |
|
|
|
[&transpose_dimensions](const int &dim) { return transpose_dimensions[dim]; }); |
|
|
|
std::sort(new_dimensions.begin(), new_dimensions.end()); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); |
|
|
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); |
|
|
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
}; |
|
|
|
auto reduce_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto tmp_node = node->cast<CNodePtr>(); |
|
|
|
auto arg_node = tmp_node->input(1); |
|
|
|
auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(arg_node, "axis")); |
|
|
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis")); |
|
|
|
ShapeVector new_dimensions; |
|
|
|
for (size_t i = 0; i < arg_dimensions.size(); ++i) { |
|
|
|
for (size_t j = 0; j < reduce_dimensions.size(); ++j) { |
|
|
|
if (reduce_dimensions[j] >= arg_dimensions[i]) { |
|
|
|
++reduce_dimensions[j]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(), |
|
|
|
std::back_inserter(new_dimensions)); |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); |
|
|
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); |
|
|
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); |
|
|
|
return new_cnode; |
|
|
|
}; |
|
|
|
auto reshape_reduce_lamda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { |
|
|
|
auto tmp_node = node->cast<CNodePtr>(); |
|
|
|
auto arg_node = tmp_node->input(1); |
|
|
|
auto input_shape = GetNodeShape(arg_node->cast<CNodePtr>()->input(1)); |
|
|
|
auto re_shape = GetNodeShape(arg_node); |
|
|
|
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis")); |
|
|
|
auto unmodified_dim_pair = GetUnmodifiedDim(input_shape, re_shape); |
|
|
|
std::vector<bool> dim_in_output(re_shape.size(), true); |
|
|
|
std::vector<bool> dim_unmodified(re_shape.size(), false); |
|
|
|
for (auto dim : reduce_dimensions) { |
|
|
|
dim_in_output[dim] = false; |
|
|
|
} |
|
|
|
for (auto pair_dim : unmodified_dim_pair) { |
|
|
|
dim_unmodified[pair_dim.second] = true; |
|
|
|
} |
|
|
|
bool replace = true; |
|
|
|
for (size_t i = 0; i < dim_in_output.size(); ++i) { |
|
|
|
if (dim_in_output[i] && !dim_unmodified[i]) { |
|
|
|
replace = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (replace) { |
|
|
|
ShapeVector un_dimensions; |
|
|
|
for (auto pair_dim : unmodified_dim_pair) { |
|
|
|
if (dim_in_output[pair_dim.second]) { |
|
|
|
un_dimensions.emplace_back(pair_dim.first); |
|
|
|
} |
|
|
|
} |
|
|
|
ShapeVector new_dimensions; |
|
|
|
for (size_t i = 0; i < input_shape.size(); ++i) { |
|
|
|
if (std::find(un_dimensions.begin(), un_dimensions.end(), i) == un_dimensions.end()) { |
|
|
|
new_dimensions.emplace_back(i); |
|
|
|
} |
|
|
|
} |
|
|
|
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); |
|
|
|
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); |
|
|
|
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); |
|
|
|
return new_cnode; |
|
|
|
} |
|
|
|
return node; |
|
|
|
}; |
|
|
|
std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; |
|
|
|
for (auto operation : ReduceOperations) { |
|
|
|
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lamda, |
|
|
|
operation); |
|
|
|
// Reduce(Reduce(A)) = Reduce(A) |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lamda, operation); |
|
|
|
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions |
|
|
|
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lamda, |
|
|
|
operation); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr TrySimplify(const AnfNodePtr &node) { |
|
|
|
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = { |
|
|
|
SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg, |
|
|
|
SimplifyPow, SimplifyRsqrt, SimplifySelect, SimplifySqrt, SimplifySub}; |
|
|
|
SimplifyAdd, SimplifyDiv, SimplifyLog, SimplifyMul, SimplifyNeg, SimplifyPow, SimplifyRsqrt, |
|
|
|
SimplifySelect, SimplifySqrt, SimplifySub, SimplifyTranspose, SimplifyMatMul, SimplifyReduce}; |
|
|
|
for (auto f : SimplifyFuncList) { |
|
|
|
auto ret = f(node); |
|
|
|
if (ret != nullptr) { |
|
|
|
|