/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" #include "minddata/dataset/engine/opt/pre/getter_pass.h" using namespace mindspore::dataset; using mindspore::LogStream; using mindspore::MsLogLevel::INFO; class MindDataTestOptimizationPass : public UT::DatasetOpTesting { public: MindDataTestOptimizationPass() = default; void SetUp() override { GlobalInit(); } // this recursive function helps build a ExecutionTree from a IR node, it is copied from TreeAdapter Status DFSBuild(std::shared_ptr ir, std::shared_ptr *op, ExecutionTree *tree) { std::vector> ops = ir->Build(); CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty() && tree != nullptr && op != nullptr, "Fail To Build Tree."); (*op) = ops.front(); RETURN_IF_NOT_OK(tree->AssociateNode(*op)); for (size_t i = 1; i < ops.size(); i++) { RETURN_IF_NOT_OK(tree->AssociateNode(ops[i])); RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i])); } for (std::shared_ptr child_ir : ir->Children()) { std::shared_ptr child_op; RETURN_IF_NOT_OK(DFSBuild(child_ir, &child_op, tree)); RETURN_IF_NOT_OK(ops.back()->AddChild(child_op)); // append children to the last of ops } return Status::OK(); } // this function will build an execution_tree from a root ir node. nullptr will be returned if error occurs std::unique_ptr BuildTree(std::shared_ptr ir) { std::unique_ptr tree = std::make_unique(); std::shared_ptr root; if (DFSBuild(ir, &root, tree.get()).IsError()) return nullptr; if (tree->AssignRoot(root).IsError()) return nullptr; return tree; } }; TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) { MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestOutputShapeAndTypePass."; // config leaf_op, use random_data to avoid I/O std::shared_ptr schema = std::make_shared(); ASSERT_TRUE(schema->add_column("label", "uint32", {})); std::shared_ptr ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2); std::unique_ptr exe_tree = BuildTree(ds->IRNode()); ASSERT_NE(exe_tree, nullptr); // test the optimization pass // OptPass is supposed to remove concat, filter repeat, shuffle skip, take and set the callback of map to empty std::function pass = [](OptPass pre) { // return a new pass, this will override all the existing pre-pass es pre.clear(); pre.push_back(std::make_unique(GetterPass::kOutputShapeAndType)); return pre; }; exe_tree->SetPrePassOverride(pass); ASSERT_OK(exe_tree->PrepareTreePreAction()); std::stringstream ss; // print the tree in std::string as a way to verify that nodes are indeed removed exe_tree->Print(ss); std::string ss_str = ss.str(); // ss_str would look like this // +- ( 0) : [workers: 4] [batch size: 2] // +- ( 2) : [workers: 0 (inlined)] // +- ( 4) : [workers: 4] [total rows: 44] // // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); EXPECT_EQ(ss_str.find("RepeatOp"), ss_str.npos); EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); } TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { MS_LOG(INFO) << "Doing MindDataTestOptimizationPass-MindDataTestDatasetSizePass."; // config leaf_op, use random_data to avoid I/O std::shared_ptr schema = std::make_shared(); ASSERT_TRUE(schema->add_column("label", "uint32", {})); std::shared_ptr ds = RandomData(44, schema)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2); std::unique_ptr exe_tree = BuildTree(ds->IRNode()); ASSERT_NE(exe_tree, nullptr); // test the optimization pass // OptPass is supposed to remove concat, filter repeat, shuffle skip, take and set the callback of map to empty std::function pass = [](OptPass pre) { // return a new pass, this will override all the existing pre-pass es pre.clear(); // remove all existing pre pass pre.push_back(std::make_unique(GetterPass::kDatasetSize)); return pre; }; exe_tree->SetPrePassOverride(pass); ASSERT_OK(exe_tree->PrepareTreePreAction()); std::stringstream ss; // print the tree in std::string as a way to verify that nodes are indeed removed exe_tree->Print(ss); std::string ss_str = ss.str(); // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); }