diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc index ca2cc18fbd..ecf67a1b45 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc @@ -102,7 +102,8 @@ bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector &build return true; } - kernel::KernelBuildClient *client = GetClient(); + auto client = GetClient(); + MS_EXCEPTION_IF_NULL(client); if (!client->AkgStart(PROCESS_NUM, TIME_OUT)) { MS_LOG(ERROR) << "Akg start failed."; return false; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc index 13084e7392..6a2d276805 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc @@ -250,7 +250,7 @@ AnfNodePtr SimplifySelect(const AnfNodePtr &node) { return nullptr; } -AnfNodePtr SimplifyMul(const AnfNodePtr &node) { +AnfNodePtr SimplifyMul1(const AnfNodePtr &node) { if (!IsPrimitiveCNode(node, prim::kPrimMul)) { return nullptr; } @@ -278,6 +278,28 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), node_tmp}, node); return new_cnode; }; + // (x*C1)*(y*C2) ==> (x*y)*(C1*C2) + MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda); + // (x*C1)*C2 ==> x*(C1*C2) + MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2); + // exp(x)*exp(y) ==> exp(x+y) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda); + // sqrt(x)*sqrt(x) ==> x + MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + // sqrt(x)*sqrt(y) ==> sqrt(x*y) + MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), + sqrt_merge_lambda, !PIsEqual()(x.GetNode(node), y.GetNode(node))); + return nullptr; +} + +AnfNodePtr SimplifyMul2(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMul)) { + return nullptr; + } + PatternNode x, y; + PConstant const_1(node), const_2(node); + auto rsqrt_merge_lambda = [&node, &x]() -> AnfNodePtr { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node); return new_cnode; @@ -296,18 +318,6 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node); return new_cnode; }; - // (x*C1)*(y*C2) ==> (x*y)*(C1*C2) - MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda); - // (x*C1)*C2 ==> x*(C1*C2) - MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2); - // exp(x)*exp(y) ==> exp(x+y) - MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda); - // sqrt(x)*sqrt(x) ==> x - MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x, - PIsEqual()(x.GetNode(node), y.GetNode(node))); - // sqrt(x)*sqrt(y) ==> sqrt(x*y) - MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), - sqrt_merge_lambda, !PIsEqual()(x.GetNode(node), y.GetNode(node))); // rsqrt(x)*rsqrt(x) ==> 1/x MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y), rsqrt_merge_lambda, PIsEqual()(x.GetNode(node), y.GetNode(node))); @@ -323,12 +333,12 @@ AnfNodePtr SimplifyMul(const AnfNodePtr &node) { return nullptr; } -AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { +AnfNodePtr SimplifyDiv1(const AnfNodePtr &node) { if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) { return nullptr; } PatternNode x, y, u, v; - PConstant const_1(node), const_2(node); + PConstant const_1(node); PConstant const_one(node, false, 1); PConstant const_one_scalar(node, false, 1, true); @@ -353,6 +363,28 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node); return new_cnode; }; + // x/1 ==> x + MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x); + MATCH_REPLACE(node, x / const_one, x); + // e^x/e^y ==> e^(x-y) + MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1); + // x / e^y ==> x * e^(-y) + MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2); + // x / y^const ==> x * y^(-const) + MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const); + // x / sqrt(x) ==> sqrt(x) + MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1, + PIsEqual()(x.GetNode(node), y.GetNode(node))); + return nullptr; +} + +AnfNodePtr SimplifyDiv2(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) { + return nullptr; + } + PatternNode x, y, u, v; + PConstant const_1(node); + auto div_sqrt_lambda_2 = [&node, &x, &y]() -> AnfNodePtr { auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), y.GetNode(node)}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node); @@ -377,6 +409,25 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, new_rhs}, node); return new_cnode; }; + // x / sqrt(y) ==> x * rsqrt(y) + MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2, + !PIsEqual()(x.GetNode(node), y.GetNode(node))); + // x / rsqrt(y) ==> x * sqrt(y) + MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda); + // // x / const ==> x * (1/const) + MATCH_REPLACE_LAMBDA(node, x / const_1, div_const); + // (x/y) / (u/v) ==> (x*v) / (y*u) + MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1); + return nullptr; +} + +AnfNodePtr SimplifyDiv3(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) { + return nullptr; + } + PatternNode x, y, u, v; + PConstant const_1(node), const_2(node); + auto div_lambda_2 = [&node, &x, &y, &u]() -> AnfNodePtr { auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node); auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node); @@ -392,29 +443,8 @@ AnfNodePtr SimplifyDiv(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node); return new_cnode; }; - // x/1 ==> x - MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x); - MATCH_REPLACE(node, x / const_one, x); - // e^x/e^y ==> e^(x-y) - MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1); - // x / e^y ==> x * e^(-y) - MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2); - // x / y^const ==> x * y^(-const) - MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const); - // x / sqrt(x) ==> sqrt(x) - MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1, - PIsEqual()(x.GetNode(node), y.GetNode(node))); - // x / sqrt(y) ==> x * rsqrt(y) - MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2, - !PIsEqual()(x.GetNode(node), y.GetNode(node))); - // x / rsqrt(y) ==> x * sqrt(y) - MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda); // Neg(x) / const = x / (-const) MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) / const_1, neg_div_lambda); - // // x / const ==> x * (1/const) - MATCH_REPLACE_LAMBDA(node, x / const_1, div_const); - // (x/y) / (u/v) ==> (x*v) / (y*u) - MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1); // (x/y) / u ==> x / (y*u) MATCH_REPLACE_LAMBDA(node, (x / y) / u, div_lambda_2); // x / (u/v) ==> (x*v) / u @@ -556,52 +586,22 @@ std::vector> GetUnmodifiedDim(const ShapeVector &a, return unmodified; } -AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimReduceMax) && !IsPrimitiveCNode(node, prim::kPrimReduceMin) && - !IsPrimitiveCNode(node, prim::kPrimReduceSum)) { +std::list RedOps = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; + +bool IsRedOps(const AnfNodePtr &node) { + if (std::any_of(RedOps.begin(), RedOps.end(), + [&node](const PrimitivePtr &ops) { return IsPrimitiveCNode(node, ops); })) { + return true; + } + return false; +} + +// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions +AnfNodePtr SimplifyReduce1(const AnfNodePtr &node) { + if (!IsRedOps(node)) { return nullptr; } PatternNode x; - auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { - auto shape = GetNodeShape(node); - if (shape.size() != 0 && shape.size() != 1) { - return nullptr; - } else { - auto tmp_node = node->cast(); - auto transpose_node = tmp_node->input(1); - auto transpose_dimensions = - GetValue>(AnfAlgo::GetNodeAttr(transpose_node, "perm")); - ShapeVector new_dimensions; - auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(tmp_node, "axis")); - std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions), - [&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; }); - std::sort(new_dimensions.begin(), new_dimensions.end()); - auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); - AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); - AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); - return new_cnode; - } - }; - auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { - auto tmp_node = node->cast(); - auto arg_node = tmp_node->input(1); - auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(arg_node, "axis")); - auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(tmp_node, "axis")); - ShapeVector new_dimensions; - for (size_t i = 0; i < arg_dimensions.size(); ++i) { - for (size_t j = 0; j < reduce_dimensions.size(); ++j) { - if (reduce_dimensions[j] >= arg_dimensions[i]) { - ++reduce_dimensions[j]; - } - } - } - std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(), - std::back_inserter(new_dimensions)); - auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); - AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); - AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); - return new_cnode; - }; auto reshape_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { auto tmp_node = node->cast(); auto arg_node = tmp_node->input(1); @@ -643,6 +643,37 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { } return nullptr; }; + for (auto op : RedOps) { + MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(op, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda, op); + } + return nullptr; +} + +AnfNodePtr SimplifyReduce2(const AnfNodePtr &node) { + if (!IsRedOps(node)) { + return nullptr; + } + PatternNode x; + auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { + auto tmp_node = node->cast(); + auto arg_node = tmp_node->input(1); + auto arg_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(arg_node, "axis")); + auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(tmp_node, "axis")); + ShapeVector new_dimensions; + for (size_t i = 0; i < arg_dimensions.size(); ++i) { + for (size_t j = 0; j < reduce_dimensions.size(); ++j) { + if (reduce_dimensions[j] >= arg_dimensions[i]) { + ++reduce_dimensions[j]; + } + } + } + std::merge(arg_dimensions.begin(), arg_dimensions.end(), reduce_dimensions.begin(), reduce_dimensions.end(), + std::back_inserter(new_dimensions)); + auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); + AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); + AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); + return new_cnode; + }; auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr { auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node); AnfAlgo::CopyNodeAttr("axis", node, arg_node); @@ -650,16 +681,9 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), arg_node}, node); return new_cnode; }; - std::list ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}; - for (auto operation : ReduceOperations) { - // Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector - MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda, - operation); + for (auto operation : RedOps) { // Reduce(Reduce(A)) = Reduce(A) MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lambda, operation); - // Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions - MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda, - operation); } // ReduceSum(Neg(x)) = Neg(ReduceSum(x)) MATCH_REPLACE_LAMBDA(node, PPrimitive(prim::kPrimReduceSum, PUnaryOperation(prim::kPrimNeg, x)), @@ -667,8 +691,41 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) { return nullptr; } +// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector +AnfNodePtr SimplifyReduce3(const AnfNodePtr &node) { + if (!IsRedOps(node)) { + return nullptr; + } + PatternNode x; + auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr { + auto shape = GetNodeShape(node); + if (shape.size() != 0 && shape.size() != 1) { + return nullptr; + } else { + auto tmp_node = node->cast(); + auto transpose_node = tmp_node->input(1); + auto transpose_dimensions = + GetValue>(AnfAlgo::GetNodeAttr(transpose_node, "perm")); + ShapeVector new_dimensions; + auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr(tmp_node, "axis")); + std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions), + [&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; }); + std::sort(new_dimensions.begin(), new_dimensions.end()); + auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node); + AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode); + AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode); + return new_cnode; + } + }; + for (auto operation : RedOps) { + MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda, + operation); + } + return nullptr; +} + AnfNodePtr TrySimplify(const AnfNodePtr &node) { - std::list> SimplifyFuncList = {SimplifyReduce}; + std::list> SimplifyFuncList = {SimplifyReduce1}; for (auto f : SimplifyFuncList) { auto ret = f(node); if (ret != nullptr) {