Browse Source

fix unsorted_segment_sum_fission pass

tags/v1.1.0
huanghui 5 years ago
parent
commit
1c6c280da7
2 changed files with 30 additions and 18 deletions
  1. +22
    -14
      mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc
  2. +8
    -4
      tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc

+ 22
- 14
mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc View File

@@ -74,6 +74,26 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice);
return slice;
}

// Decides whether an UnsortedSegmentSum node should go through the fission rewrite.
//
// Returns true only when all of the following hold:
//   * origin_node->size() equals kUnsortedSegmentSumInputNum + 1;
//   * the inferred shapes of input 0 and input 1 are both non-empty;
//   * the last dimension of input 0's shape is 1;
//   * input 0 has strictly more dimensions than input 1.
// Otherwise returns false. A null origin_node is rejected by MS_EXCEPTION_IF_NULL.
bool CheckInputs(const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);

  // Guard: the node must carry the expected number of entries.
  const bool input_num_ok = (origin_node->size() == kUnsortedSegmentSumInputNum + 1);
  if (!input_num_ok) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
                  << ". CNode= " << origin_node->DebugString();
    return false;
  }

  // Both shapes are fetched up front, before any of the shape checks below.
  const auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  const auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);

  // Guard: an empty shape on either input disqualifies the node (no log is emitted here).
  if (input0_shape.empty()) {
    return false;
  }
  if (input1_shape.empty()) {
    return false;
  }

  // Guard: only a trailing dimension of exactly 1 on input 0 is accepted.
  const auto last_dim = input0_shape.back();
  if (last_dim != 1) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
                  << last_dim;
    return false;
  }

  // Final condition: input 0's rank must exceed input 1's rank.
  const auto input0_rank = input0_shape.size();
  const auto input1_rank = input1_shape.size();
  return input0_rank > input1_rank;
}
} // namespace

const BaseRef UnsortSegmentSumFission::DefinePattern() const {
@@ -88,19 +108,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(node);
auto origin_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_node);
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
<< ". CNode= " << origin_node->DebugString();
return nullptr;
}
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
if (input0_shape.size() < 2) {
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize";
return nullptr;
}
if (input0_shape[input0_shape.size() - 1] != 1) {
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
<< input0_shape[input0_shape.size() - 1];
if (!CheckInputs(origin_node)) {
return nullptr;
}
size_t pad_dim_size;
@@ -110,7 +118,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
} else if (input_dtype == kNumberTypeFloat16) {
pad_dim_size = 16;
} else {
MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change";
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
return nullptr;
}



+ 8
- 4
tests/ut/cpp/pre_activate/ascend/ir_fission/unsorted_segment_sum_fission_test.cc View File

@@ -32,9 +32,11 @@ class TestHWUnsortedSegmentSumFission : public BackendCommon {
TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before1");
EXPECT_NE(g, nullptr);
std::vector<int64_t> shp_x{16, 1};
std::vector<int64_t> shp_x{3, 39, 1};
std::vector<int64_t> shp_y{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
@@ -50,9 +52,11 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_unsorted_segment_sum_fission", "before2");
EXPECT_NE(g, nullptr);
std::vector<int64_t> shp_x{16, 2};
std::vector<int64_t> shp_x{3, 39, 2};
std::vector<int64_t> shp_y{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();


Loading…
Cancel
Save