diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt
index 9183b0e6f7..479926c1e7 100644
--- a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt
@@ -25,6 +25,7 @@ endif ()
 add_library(cpp-API OBJECT
     config.cc
     datasets.cc
+    execute.cc
     iterator.cc
     transforms.cc
     samplers.cc
diff --git a/mindspore/ccsrc/minddata/dataset/api/execute.cc b/mindspore/ccsrc/minddata/dataset/api/execute.cc
index 91f486f3f9..7b7e15d638 100644
--- a/mindspore/ccsrc/minddata/dataset/api/execute.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/execute.cc
@@ -15,7 +15,9 @@
  */
 
 #include "minddata/dataset/include/execute.h"
+#ifdef ENABLE_ANDROID
 #include "minddata/dataset/include/de_tensor.h"
+#endif
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
 #ifndef ENABLE_ANDROID
@@ -29,6 +31,7 @@ namespace dataset {
 
 Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}
 
+#ifdef ENABLE_ANDROID
 std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
   // Build the op
   if (op_ == nullptr) {
@@ -52,6 +55,7 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
   return std::make_shared<tensor::DETensor>(std::move(de_output));
 }
+#endif
 
 std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) {
   // Build the op
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
index a0d21dae20..62c49b92f3 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
@@ -298,30 +298,45 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
     *dataset_size = dataset_size_;
     return Status::OK();
   }
-  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree.");
-
-  return child_[0]->GetDatasetSize(dataset_size);
+  if (child_.size() == 1) {
+    return child_[0]->GetDatasetSize(dataset_size);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get the dataset size from the last child.
+    return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
+  } else {
+    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
+  }
 }
 
 // Gets the number of classes
 Status DatasetOp::GetNumClasses(int64_t *num_classes) {
-  if (num_classes_ > 0) {
-    *num_classes = num_classes_;
-    return Status::OK();
-  }
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetNumClasses(num_classes);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetNumClasses shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get num classes from the last child.
+    return child_[child_.size() - 1]->GetNumClasses(num_classes);
   } else {
+    // when num classes isn't found, the default behavior is to return -1
     *num_classes = -1;
-    RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree.");
+    return Status::OK();
   }
 }
 
 Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetClassIndexing(output_class_indexing);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetClassIndexing shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in the front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetClassIndexing(output_class_indexing);
   } else {
-    RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree.");
+    *output_class_indexing = {};
+    RETURN_STATUS_UNEXPECTED("Trying to get class index from leaf node, missing override");
   }
 }
 
@@ -478,17 +493,31 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
   if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
   MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
 }
+
 int64_t DatasetOp::GetTreeBatchSize() {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetTreeBatchSize();
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetBatchSize shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetTreeBatchSize();
+  } else {
+    return 1;
   }
-  return 1;
 }
+
 int64_t DatasetOp::GetTreeRepeatCount() {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetTreeRepeatCount();
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetRepeatCount shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetTreeRepeatCount();
+  } else {
+    return 1;
   }
-  return 1;
 }
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
index c5619eaa6b..4dea9d60e0 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc
@@ -70,15 +70,6 @@ Status TFRecordNode::ValidateParams() {
     return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
   }
 
-  if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) {
-    // This check only makes sense in a non-cache path. We should make sure there is at least one file per
-    // shard in file-based sharding
-    std::string err_msg =
-      "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_);
-    MS_LOG(ERROR) << err_msg;
-    return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
-  }
-
   std::vector<std::string> invalid_files(dataset_files_.size());
   auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
                          [](const std::string &filename) { return !TFReaderOp::ValidateFirstRowCrc(filename); });
diff --git a/mindspore/ccsrc/minddata/dataset/include/execute.h b/mindspore/ccsrc/minddata/dataset/include/execute.h
index b3457aa487..64c1a55c3a 100644
--- a/mindspore/ccsrc/minddata/dataset/include/execute.h
+++ b/mindspore/ccsrc/minddata/dataset/include/execute.h
@@ -20,7 +20,9 @@
 #include <memory>
 #include <vector>
 #include "minddata/dataset/core/constants.h"
+#ifdef ENABLE_ANDROID
 #include "minddata/dataset/include/de_tensor.h"
+#endif
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/include/transforms.h"
 
@@ -35,10 +37,16 @@ class Execute {
   /// \brief Constructor
   explicit Execute(std::shared_ptr<TensorOperation> op);
 
+#ifdef ENABLE_ANDROID
   /// \brief callable function to execute the TensorOperation in eager mode
   /// \param[inout] input - the tensor to be transformed
   /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
+#endif
+
+  /// \brief callable function to execute the TensorOperation in eager mode
+  /// \param[inout] input - the tensor to be transformed
+  /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);
 
  private:
diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt
index 57be1f4457..896c284896 100644
--- a/tests/ut/cpp/dataset/CMakeLists.txt
+++ b/tests/ut/cpp/dataset/CMakeLists.txt
@@ -57,6 +57,7 @@ SET(DE_UT_SRCS
     distributed_sampler_test.cc
     epoch_ctrl_op_test.cc
     equalize_op_test.cc
+    execute_test.cc
     execution_tree_test.cc
     fill_op_test.cc
     global_context_test.cc
diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc
new file mode 100644
index 0000000000..83029d2400
--- /dev/null
+++ b/tests/ut/cpp/dataset/execute_test.cc
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/include/execute.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/include/vision.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestExecute : public UT::CVOP::CVOpCommon { + protected: + MindDataTestExecute() : CVOpCommon() {} + + std::shared_ptr output_tensor_; +}; + +TEST_F(MindDataTestExecute, TestOp1) { + MS_LOG(INFO) << "Doing testCrop."; + // Crop params + std::shared_ptr center_crop = vision::CenterCrop({30}); + std::shared_ptr out_image = Execute(std::move(center_crop))(input_tensor_); + EXPECT_NE(out_image, nullptr); + EXPECT_EQ(30, out_image->shape()[0]); + EXPECT_EQ(30, out_image->shape()[1]); +}