Browse Source

!15158 fix codex check

From: @lingyunli63
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
627fbb7137
3 changed files with 7 additions and 242 deletions
  1. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc
  2. +1
    -240
      mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
  3. +4
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc

+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc View File

@@ -102,7 +102,8 @@ bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build
return true;
}

kernel::KernelBuildClient *client = GetClient();
auto client = GetClient();
MS_EXCEPTION_IF_NULL(client);
if (!client->AkgStart(PROCESS_NUM, TIME_OUT)) {
MS_LOG(ERROR) << "Akg start failed.";
return false;


+ 1
- 240
mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc View File

@@ -250,178 +250,6 @@ AnfNodePtr SimplifySelect(const AnfNodePtr &node) {
return nullptr;
}

AnfNodePtr SimplifyMul(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimMul)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y;
PConstant<AnfNodePtr> const_1(node), const_2(node);

auto const_dup_lambda = [&node, &x, &y, &const_1, &const_2]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), new_lhs, new_rhs}, node);
return new_cnode;
};
auto const_dup_lambda2 = [&node, &x, &const_1, &const_2]() -> AnfNodePtr {
auto new_rhs = const_1.MulByPatternConst(const_2, x.GetNode(node));
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto exp_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimAdd), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
return new_cnode;
};
auto sqrt_merge_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), node_tmp}, node);
return new_cnode;
};
auto rsqrt_merge_lambda = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimReciprocal), x.GetNode(node)}, node);
return new_cnode;
};
auto rsqrt_merge_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), node_tmp}, node);
return new_cnode;
};
auto rsqrt_merge_lambda_3 = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
return new_cnode;
};
auto neg_mul_lambda = [&node, &x, &const_1]() -> AnfNodePtr {
auto new_rhs = const_1.ValueNodeWithOprations(prim::kPrimNeg);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
// (x*C1)*(y*C2) ==> (x*y)*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * (const_2 * y), const_dup_lambda);
// (x*C1)*C2 ==> x*(C1*C2)
MATCH_REPLACE_LAMBDA(node, (const_1 * x) * const_2, const_dup_lambda2);
// exp(x)*exp(y) ==> exp(x+y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) * PUnaryOperation(prim::kPrimExp, y), exp_merge_lambda);
// sqrt(x)*sqrt(x) ==> x
MATCH_REPLACE_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y), x,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// sqrt(x)*sqrt(y) ==> sqrt(x*y)
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimSqrt, x) * PUnaryOperation(prim::kPrimSqrt, y),
sqrt_merge_lambda, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// rsqrt(x)*rsqrt(x) ==> 1/x
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y),
rsqrt_merge_lambda, PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// rsqrt(x)*rsqrt(y) ==> rsqrt(x*y)
MATCH_REPLACE_LAMBDA_IF(node, PUnaryOperation(prim::kPrimRsqrt, x) * PUnaryOperation(prim::kPrimRsqrt, y),
rsqrt_merge_lambda_2, !PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x*rsqrt(x) ==> sqrt(x)
MATCH_REPLACE_LAMBDA_IF(node, x * PUnaryOperation(prim::kPrimRsqrt, y), rsqrt_merge_lambda_3,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// Neg(x) * const | const * Neg(x) = x * (-const)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) * const_1, neg_mul_lambda);
MATCH_REPLACE_LAMBDA(node, const_1 * PUnaryOperation(prim::kPrimNeg, x), neg_mul_lambda);
return nullptr;
}

AnfNodePtr SimplifyDiv(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
return nullptr;
}
PatternNode<AnfNodePtr> x, y, u, v;
PConstant<AnfNodePtr> const_1(node), const_2(node);
PConstant<AnfNodePtr> const_one(node, false, 1);
PConstant<AnfNodePtr> const_one_scalar(node, false, 1, true);

auto div_exp_lambda_1 = [&node, &x, &y]() -> AnfNodePtr {
auto node_tmp = NewCNodeWithInfo({NewValueNode(prim::kPrimSub), x.GetNode(node), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_tmp}, node);
return new_cnode;
};
auto div_exp_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_neg = NewCNodeWithInfo({NewValueNode(prim::kPrimNeg), y.GetNode(node)}, node);
auto node_exp = NewCNodeWithInfo({NewValueNode(prim::kPrimExp), node_neg}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_exp}, node);
return new_cnode;
};
auto div_pow_const = [&node, &x, &y, &const_1]() -> AnfNodePtr {
auto new_const = const_1.ValueNodeWithOprations(prim::kPrimNeg);
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimPow), y.GetNode(node), new_const}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto div_sqrt_lambda_1 = [&node, &x]() -> AnfNodePtr {
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), x.GetNode(node)}, node);
return new_cnode;
};
auto div_sqrt_lambda_2 = [&node, &x, &y]() -> AnfNodePtr {
auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimRsqrt), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node);
return new_cnode;
};
auto div_const = [&node, &x, &const_1]() -> AnfNodePtr {
auto new_const = const_1.ValueNodeWithOprations(prim::kPrimReciprocal);
if (new_const == nullptr) {
return nullptr;
}
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), new_const}, node);
return new_cnode;
};
auto div_rsqrt_lambda = [&node, &x, &y]() -> AnfNodePtr {
auto node_rsqrt = NewCNodeWithInfo({NewValueNode(prim::kPrimSqrt), y.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), node_rsqrt}, node);
return new_cnode;
};
auto div_lambda_1 = [&node, &x, &y, &u, &v]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node);
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, new_rhs}, node);
return new_cnode;
};
auto div_lambda_2 = [&node, &x, &y, &u]() -> AnfNodePtr {
auto new_rhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), y.GetNode(node), u.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
auto div_lambda_3 = [&node, &x, &u, &v]() -> AnfNodePtr {
auto new_lhs = NewCNodeWithInfo({NewValueNode(prim::kPrimMul), x.GetNode(node), v.GetNode(node)}, node);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), new_lhs, u.GetNode(node)}, node);
return new_cnode;
};
auto neg_div_lambda = [&node, &x, &const_1]() -> AnfNodePtr {
auto new_rhs = const_1.ValueNodeWithOprations(prim::kPrimNeg);
auto new_cnode = NewCNodeWithInfo({NewValueNode(prim::kPrimRealDiv), x.GetNode(node), new_rhs}, node);
return new_cnode;
};
// x/1 ==> x
MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarDiv, x, const_one_scalar, false), x);
MATCH_REPLACE(node, x / const_one, x);
// e^x/e^y ==> e^(x-y)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimExp, x) / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_1);
// x / e^y ==> x * e^(-y)
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimExp, y), div_exp_lambda_2);
// x / y^const ==> x * y^(-const)
MATCH_REPLACE_LAMBDA(node, x / PBinOperation(prim::kPrimPow, y, const_1), div_pow_const);
// x / sqrt(x) ==> sqrt(x)
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_1,
PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x / sqrt(y) ==> x * rsqrt(y)
MATCH_REPLACE_LAMBDA_IF(node, x / PUnaryOperation(prim::kPrimSqrt, y), div_sqrt_lambda_2,
!PIsEqual<AnfNodePtr>()(x.GetNode(node), y.GetNode(node)));
// x / rsqrt(y) ==> x * sqrt(y)
MATCH_REPLACE_LAMBDA(node, x / PUnaryOperation(prim::kPrimRsqrt, y), div_rsqrt_lambda);
// Neg(x) / const = x / (-const)
MATCH_REPLACE_LAMBDA(node, PUnaryOperation(prim::kPrimNeg, x) / const_1, neg_div_lambda);
// // x / const ==> x * (1/const)
MATCH_REPLACE_LAMBDA(node, x / const_1, div_const);
// (x/y) / (u/v) ==> (x*v) / (y*u)
MATCH_REPLACE_LAMBDA(node, (x / y) / (u / v), div_lambda_1);
// (x/y) / u ==> x / (y*u)
MATCH_REPLACE_LAMBDA(node, (x / y) / u, div_lambda_2);
// x / (u/v) ==> (x*v) / u
MATCH_REPLACE_LAMBDA(node, x / (u / v), div_lambda_3);
return nullptr;
}

#define PERFORM_REPLACE(OldNode, NewNode, Graph, FLAG) \
if ((NewNode) != nullptr) { \
(Graph)->manager()->Replace((OldNode), (NewNode)); \
@@ -562,26 +390,6 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
return nullptr;
}
PatternNode<AnfNodePtr> x;
auto trans_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
auto shape = GetNodeShape(node);
if (shape.size() != 0 && shape.size() != 1) {
return nullptr;
} else {
auto tmp_node = node->cast<CNodePtr>();
auto transpose_node = tmp_node->input(1);
auto transpose_dimensions =
GetValue<std::vector<int64_t>>(AnfAlgo::GetNodeAttr<ValuePtr>(transpose_node, "perm"));
ShapeVector new_dimensions;
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
std::transform(reduce_dimensions.begin(), reduce_dimensions.end(), std::back_inserter(new_dimensions),
[&transpose_dimensions](const int64_t &dim) { return transpose_dimensions[dim]; });
std::sort(new_dimensions.begin(), new_dimensions.end());
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
return new_cnode;
}
};
auto reduce_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
auto tmp_node = node->cast<CNodePtr>();
auto arg_node = tmp_node->input(1);
@@ -602,47 +410,6 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
return new_cnode;
};
auto reshape_reduce_lambda = [&node, &x](PrimitivePtr &operation) -> AnfNodePtr {
auto tmp_node = node->cast<CNodePtr>();
auto arg_node = tmp_node->input(1);
auto input_shape = GetNodeShape(arg_node->cast<CNodePtr>()->input(1));
auto re_shape = GetNodeShape(arg_node);
auto reduce_dimensions = TransAxisValueToVector(AnfAlgo::GetNodeAttr<ValuePtr>(tmp_node, "axis"));
auto unmodified_dim_pair = GetUnmodifiedDim(input_shape, re_shape);
std::vector<bool> dim_in_output(re_shape.size(), true);
std::vector<bool> dim_unmodified(re_shape.size(), false);
for (auto dim : reduce_dimensions) {
dim_in_output[dim] = false;
}
for (auto pair_dim : unmodified_dim_pair) {
dim_unmodified[pair_dim.second] = true;
}
bool replace = true;
for (size_t i = 0; i < dim_in_output.size(); ++i) {
if (dim_in_output[i] && !dim_unmodified[i]) {
replace = false;
}
}
if (replace) {
ShapeVector un_dimensions;
for (auto pair_dim : unmodified_dim_pair) {
if (dim_in_output[pair_dim.second]) {
un_dimensions.emplace_back(pair_dim.first);
}
}
ShapeVector new_dimensions;
for (size_t i = 0; i < input_shape.size(); ++i) {
if (std::find(un_dimensions.begin(), un_dimensions.end(), i) == un_dimensions.end()) {
new_dimensions.emplace_back(i);
}
}
auto new_cnode = NewCNodeWithInfo({NewValueNode(operation), x.GetNode(node)}, node);
AnfAlgo::SetNodeAttr("axis", MakeValue(new_dimensions), new_cnode);
AnfAlgo::CopyNodeAttr("keep_dims", node, new_cnode);
return new_cnode;
}
return nullptr;
};
auto neg_reducesum_lambda = [&node, &x]() -> AnfNodePtr {
auto arg_node = NewCNodeWithInfo({NewValueNode(prim::kPrimReduceSum), x.GetNode(node)}, node);
AnfAlgo::CopyNodeAttr("axis", node, arg_node);
@@ -652,14 +419,8 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
};
std::list<PrimitivePtr> ReduceOperations = {prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin};
for (auto operation : ReduceOperations) {
// Reduce(Transpose(A)) = Reduce(A) if result is a scalar or vector
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimTranspose, x)), trans_reduce_lambda,
operation);
// Reduce(Reduce(A)) = Reduce(A)
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(operation, x)), reduce_reduce_lambda, operation);
// Reduce(Reshape(A)) = Reduce(A) if reduce dimensions is not in reshape dimensions
MATCH_REPLACE_LAMBDA_FLAG(node, PPrimitive(operation, PPrimitive(prim::kPrimReshape, x)), reshape_reduce_lambda,
operation);
}
// ReduceSum(Neg(x)) = Neg(ReduceSum(x))
MATCH_REPLACE_LAMBDA(node, PPrimitive(prim::kPrimReduceSum, PUnaryOperation(prim::kPrimNeg, x)),
@@ -668,7 +429,7 @@ AnfNodePtr SimplifyReduce(const AnfNodePtr &node) {
}

AnfNodePtr TrySimplify(const AnfNodePtr &node) {
std::list<std::function<AnfNodePtr(AnfNodePtr)>> SimplifyFuncList = {SimplifyReduce};
std::list<std::function<AnfNodePtr(const AnfNodePtr &)>> SimplifyFuncList = {SimplifyReduce};
for (auto f : SimplifyFuncList) {
auto ret = f(node);
if (ret != nullptr) {


+ 4
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc View File

@@ -15,10 +15,13 @@
*/
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"

#include <memory>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
#include <string>



Loading…
Cancel
Save