@@ -25,6 +25,7 @@ endif ()
 add_library(cpp-API OBJECT
     config.cc
     datasets.cc
+    execute.cc
     iterator.cc
     transforms.cc
     samplers.cc
@@ -15,7 +15,9 @@
  */
 #include "minddata/dataset/include/execute.h"
+#ifdef ENABLE_ANDROID
 #include "minddata/dataset/include/de_tensor.h"
+#endif
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
 #ifndef ENABLE_ANDROID
@@ -29,6 +31,7 @@ namespace dataset {

 Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}

+#ifdef ENABLE_ANDROID
 std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) {
   // Build the op
   if (op_ == nullptr) {
@@ -52,6 +55,7 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS
   }
   return std::make_shared<tensor::DETensor>(std::move(de_output));
 }
+#endif

 std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) {
   // Build the op
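
Note on the host-side overload kept outside the ENABLE_ANDROID guard: Execute is a one-shot functor, so eager use is construct-then-call. A minimal sketch of that call shape, mirroring the unit test added at the end of this change; the helper name and the pre-decoded input tensor are assumptions, not part of this patch:

    #include <memory>
    #include "minddata/dataset/include/execute.h"
    #include "minddata/dataset/include/tensor.h"
    #include "minddata/dataset/include/transforms.h"
    #include "minddata/dataset/include/vision.h"

    using namespace mindspore::dataset;

    // Hypothetical helper: input is assumed to be an already-decoded HWC image tensor.
    std::shared_ptr<Tensor> EagerCenterCrop30(std::shared_ptr<Tensor> input) {
      std::shared_ptr<TensorOperation> center_crop = vision::CenterCrop({30});
      // Execute builds the underlying TensorOp and runs Compute() once; nullptr on failure.
      return Execute(std::move(center_crop))(input);
    }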

@@ -298,30 +298,45 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
     *dataset_size = dataset_size_;
     return Status::OK();
   }
-  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree.");
-  return child_[0]->GetDatasetSize(dataset_size);
+  if (child_.size() == 1) {
+    return child_[0]->GetDatasetSize(dataset_size);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get the dataset size from the last child.
+    return child_[child_.size() - 1]->GetDatasetSize(dataset_size);
+  } else {
+    RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override");
+  }
 }

 // Gets the number of classes
 Status DatasetOp::GetNumClasses(int64_t *num_classes) {
-  if (!child_.empty()) {
+  if (num_classes_ > 0) {
+    *num_classes = num_classes_;
+    return Status::OK();
+  }
+  if (child_.size() == 1) {
     return child_[0]->GetNumClasses(num_classes);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetNumClasses shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get num classes from the last child.
+    return child_[child_.size() - 1]->GetNumClasses(num_classes);
   } else {
+    // when num classes isn't found, the default behavior is to return -1
     *num_classes = -1;
-    RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree.");
+    return Status::OK();
   }
 }

 Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetClassIndexing(output_class_indexing);
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetClassIndexing shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in the front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetClassIndexing(output_class_indexing);
   } else {
-    RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree.");
+    *output_class_indexing = {};
+    RETURN_STATUS_UNEXPECTED("Trying to get class index from leaf node, missing override");
   }
 }

@@ -478,17 +493,31 @@ void DatasetOp::UpdateRepeatAndEpochCounter() {
   if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++;
   MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_;
 }

 int64_t DatasetOp::GetTreeBatchSize() {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetTreeBatchSize();
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetBatchSize shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetTreeBatchSize();
+  } else {
+    return 1;
   }
-  return 1;
 }

 int64_t DatasetOp::GetTreeRepeatCount() {
-  if (!child_.empty()) {
+  if (child_.size() == 1) {
     return child_[0]->GetTreeRepeatCount();
+  } else if (child_.size() > 1) {
+    // It is okay for dataset to have more than 1 child, GetRepeatCount shouldn't fail in this case.
+    // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will
+    // always be in front of the child_ structure, so we get data from the last child.
+    return child_[child_.size() - 1]->GetTreeRepeatCount();
+  } else {
+    return 1;
   }
-  return 1;
 }
 }  // namespace dataset
 }  // namespace mindspore
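
Review note: GetDatasetSize, GetNumClasses, GetClassIndexing, GetTreeBatchSize, and GetTreeRepeatCount now share one traversal rule: a single child is followed directly, multiple children (cache lookup/merge operators injected in front) fall through to the last child, and a leaf either errors out or returns a default. A toy sketch of that recursion shape; the Node type and method name are hypothetical, not code from this change:

    #include <memory>
    #include <vector>

    // Toy model of the shared traversal rule used by the DatasetOp getters above.
    struct Node {
      std::vector<std::shared_ptr<Node>> children;

      int64_t TreeBatchSize() const {
        if (children.size() == 1) {
          return children[0]->TreeBatchSize();      // pass-through node
        } else if (children.size() > 1) {
          return children.back()->TreeBatchSize();  // skip injected cache ops in front
        }
        return 1;  // leaf default, mirroring GetTreeBatchSize()
      }
    };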

@@ -70,15 +70,6 @@ Status TFRecordNode::ValidateParams() {
     return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
   }

-  if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) {
-    // This check only makes sense in a non-cache path. We should make sure there is at least one file per
-    // shard in file-based sharding
-    std::string err_msg =
-      "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_);
-    MS_LOG(ERROR) << err_msg;
-    return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg);
-  }
-
   std::vector<std::string> invalid_files(dataset_files_.size());
   auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
                          [](const std::string &filename) { return !TFReaderOp::ValidateFirstRowCrc(filename); });

@@ -20,7 +20,9 @@
 #include <vector>
 #include <memory>
 #include "minddata/dataset/core/constants.h"
+#ifdef ENABLE_ANDROID
 #include "minddata/dataset/include/de_tensor.h"
+#endif
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/include/transforms.h"

@@ -35,10 +37,16 @@ class Execute {
   /// \brief Constructor
   explicit Execute(std::shared_ptr<TensorOperation> op);

+#ifdef ENABLE_ANDROID
   /// \brief callable function to execute the TensorOperation in eager mode
   /// \param[inout] input - the tensor to be transformed
   /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
+#endif
+
+  /// \brief callable function to execute the TensorOperation in eager mode
+  /// \param[inout] input - the tensor to be transformed
+  /// \return - the output tensor, nullptr if Compute fails
   std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);

  private:
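
Design note: the ENABLE_ANDROID guard keeps the tensor::MSTensor overload (and its de_tensor.h dependency) out of host builds, while the dataset::Tensor overload stays available everywhere. On-device callers use the same functor shape; a sketch only, assuming the mindspore::dataset names are in scope and ms_input is an MSTensor supplied by the lite runtime:

    #ifdef ENABLE_ANDROID
    // Hypothetical helper: ms_input is an assumed, externally created MSTensor.
    std::shared_ptr<tensor::MSTensor> CropOnDevice(std::shared_ptr<tensor::MSTensor> ms_input) {
      // Same construct-then-call shape as the host path; nullptr if Compute() fails.
      return Execute(vision::CenterCrop({30}))(ms_input);
    }
    #endif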

@@ -57,6 +57,7 @@ SET(DE_UT_SRCS
     distributed_sampler_test.cc
     epoch_ctrl_op_test.cc
     equalize_op_test.cc
+    execute_test.cc
     execution_tree_test.cc
     fill_op_test.cc
     global_context_test.cc

@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "common/common.h"
+#include "common/cvop_common.h"
+#include "minddata/dataset/include/execute.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/vision.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+using mindspore::MsLogLevel::INFO;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::LogStream;
+
+class MindDataTestExecute : public UT::CVOP::CVOpCommon {
+ protected:
+  MindDataTestExecute() : CVOpCommon() {}
+
+  std::shared_ptr<Tensor> output_tensor_;
+};
+
+TEST_F(MindDataTestExecute, TestOp1) {
+  MS_LOG(INFO) << "Doing testCrop.";
+  // Crop params
+  std::shared_ptr<TensorOperation> center_crop = vision::CenterCrop({30});
+  std::shared_ptr<Tensor> out_image = Execute(std::move(center_crop))(input_tensor_);
+  EXPECT_NE(out_image, nullptr);
+  EXPECT_EQ(30, out_image->shape()[0]);
+  EXPECT_EQ(30, out_image->shape()[1]);
+}
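
Aside, under the usual googletest setup implied by TEST_F: once execute_test.cc is registered in DE_UT_SRCS above, this case can be run in isolation by passing --gtest_filter='MindDataTestExecute.*' to the resulting test binary.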