Browse Source

!23399 fix bug of tbe check

Merge pull request !23399 from hwjiaorui/tbe-check
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
2f66938553
2 changed files with 13 additions and 7 deletions
  1. +9
    -3
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc
  2. +4
    -4
      tests/ut/cpp/tbe/tbe_json_creator_test.cc

+ 9
- 3
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc View File

@@ -389,7 +389,10 @@ void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_ou
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
shape = ori_shape;
shape = TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(anf_node, node_out_idx);
if (shape.empty()) {
shape.emplace_back(1);
}
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
format = TbeAdapter::FormatPass(format, ori_shape.size());
@@ -407,11 +410,14 @@ void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t re
nlohmann::json *input_desc) {
MS_EXCEPTION_IF_NULL(anf_node);
GenDesJsonCommon(input_desc);
auto shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
auto shape = TbeJsonUtils::GetInputDeviceShapeForTbeBuild(anf_node, real_input_index);
if (shape.empty()) {
shape.emplace_back(1);
}
auto ori_shape = shape;

auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = AnfAlgo::GetInputFormat(anf_node, real_input_index);


+ 4
- 4
tests/ut/cpp/tbe/tbe_json_creator_test.cc View File

@@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 4297213426602035622U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 5131870964632527075U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6011570351795510237U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 6011570351795510237U);
}
@@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 3804649253898608226U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 5736923382341947495U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4580870880229487185U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4580870880229487185U);
}
@@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13058640182660031121U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 5110289197661808901U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4729701784171992376U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4729701784171992376U);
}
@@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 1114128635775386802U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 2636386772926575020U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 9247341733773157591U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 9247341733773157591U);
}


Loading…
Cancel
Save