| @@ -25,6 +25,7 @@ endif () | |||
| add_library(cpp-API OBJECT | |||
| config.cc | |||
| datasets.cc | |||
| execute.cc | |||
| iterator.cc | |||
| transforms.cc | |||
| samplers.cc | |||
| @@ -15,7 +15,9 @@ | |||
| */ | |||
| #include "minddata/dataset/include/execute.h" | |||
| #ifdef ENABLE_ANDROID | |||
| #include "minddata/dataset/include/de_tensor.h" | |||
| #endif | |||
| #include "minddata/dataset/include/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -29,6 +31,7 @@ namespace dataset { | |||
| Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {} | |||
| #ifdef ENABLE_ANDROID | |||
| std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MSTensor> input) { | |||
| // Build the op | |||
| if (op_ == nullptr) { | |||
| @@ -52,6 +55,7 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS | |||
| } | |||
| return std::make_shared<tensor::DETensor>(std::move(de_output)); | |||
| } | |||
| #endif | |||
| std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) { | |||
| // Build the op | |||
| @@ -298,30 +298,45 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree."); | |||
| return child_[0]->GetDatasetSize(dataset_size); | |||
| if (child_.size() == 1) { | |||
| return child_[0]->GetDatasetSize(dataset_size); | |||
| } else if (child_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetDatasetSize shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in front of the child_ structure, so we get the dataset size from the last child. | |||
| return child_[child_.size() - 1]->GetDatasetSize(dataset_size); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); | |||
| } | |||
| } | |||
| // Gets the number of classes | |||
| Status DatasetOp::GetNumClasses(int64_t *num_classes) { | |||
| if (num_classes_ > 0) { | |||
| *num_classes = num_classes_; | |||
| return Status::OK(); | |||
| } | |||
| if (!child_.empty()) { | |||
| if (child_.size() == 1) { | |||
| return child_[0]->GetNumClasses(num_classes); | |||
| } else if (child_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetNumClasses shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in front of the child_ structure, so we get num classes from the last child. | |||
| return child_[child_.size() - 1]->GetNumClasses(num_classes); | |||
| } else { | |||
| // when num classes isn't found, the default behavior is to return -1 | |||
| *num_classes = -1; | |||
| RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree."); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| Status DatasetOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { | |||
| if (!child_.empty()) { | |||
| if (child_.size() == 1) { | |||
| return child_[0]->GetClassIndexing(output_class_indexing); | |||
| } else if (child_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetClassIndexing shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in the front of the child_ structure, so we get data from the last child. | |||
| return child_[child_.size() - 1]->GetClassIndexing(output_class_indexing); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Can't get the class index for the current tree."); | |||
| *output_class_indexing = {}; | |||
| RETURN_STATUS_UNEXPECTED("Trying to get class index from leaf node, missing override"); | |||
| } | |||
| } | |||
| @@ -478,17 +493,31 @@ void DatasetOp::UpdateRepeatAndEpochCounter() { | |||
| if (op_current_repeats_ % op_num_repeats_per_epoch_ == 0) op_current_epochs_++; | |||
| MS_LOG(DEBUG) << Name() << " current repeats: " << op_current_repeats_ << ", current epochs: " << op_current_epochs_; | |||
| } | |||
| int64_t DatasetOp::GetTreeBatchSize() { | |||
| if (!child_.empty()) { | |||
| if (child_.size() == 1) { | |||
| return child_[0]->GetTreeBatchSize(); | |||
| } else if (child_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetBatchSize shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in front of the child_ structure, so we get data from the last child. | |||
| return child_[child_.size() - 1]->GetTreeBatchSize(); | |||
| } else { | |||
| return 1; | |||
| } | |||
| return 1; | |||
| } | |||
| int64_t DatasetOp::GetTreeRepeatCount() { | |||
| if (!child_.empty()) { | |||
| if (child_.size() == 1) { | |||
| return child_[0]->GetTreeRepeatCount(); | |||
| } else if (child_.size() > 1) { | |||
| // It is okay for dataset to have more than 1 child, GetRepeatCount shouldn't fail in this case. | |||
| // This is done mostly for cache, which injects cache lookup/merge operators. Cache path will | |||
| // always be in front of the child_ structure, so we get data from the last child. | |||
| return child_[child_.size() - 1]->GetTreeRepeatCount(); | |||
| } else { | |||
| return 1; | |||
| } | |||
| return 1; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -70,15 +70,6 @@ Status TFRecordNode::ValidateParams() { | |||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| if (cache_ == nullptr && !shard_equal_rows_ && dataset_files_.size() < num_shards_) { | |||
| // This check only makes sense in a non-cache path. We should make sure there is at least one file per | |||
| // shard in file-based sharding | |||
| std::string err_msg = | |||
| "TFRecordNode: Invalid number of dataset files, should at least be " + std::to_string(num_shards_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| std::vector<std::string> invalid_files(dataset_files_.size()); | |||
| auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(), | |||
| [](const std::string &filename) { return !TFReaderOp::ValidateFirstRowCrc(filename); }); | |||
| @@ -20,7 +20,9 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "minddata/dataset/core/constants.h" | |||
| #ifdef ENABLE_ANDROID | |||
| #include "minddata/dataset/include/de_tensor.h" | |||
| #endif | |||
| #include "minddata/dataset/include/tensor.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| @@ -35,10 +37,16 @@ class Execute { | |||
| /// \brief Constructor | |||
| explicit Execute(std::shared_ptr<TensorOperation> op); | |||
| #ifdef ENABLE_ANDROID | |||
| /// \brief callable function to execute the TensorOperation in eager mode | |||
| /// \param[inout] input - the tensor to be transformed | |||
| /// \return - the output tensor, nullptr if Compute fails | |||
| std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input); | |||
| #endif | |||
| /// \brief callable function to execute the TensorOperation in eager mode | |||
| /// \param[inout] input - the tensor to be transformed | |||
| /// \return - the output tensor, nullptr if Compute fails | |||
| std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input); | |||
| private: | |||
| @@ -57,6 +57,7 @@ SET(DE_UT_SRCS | |||
| distributed_sampler_test.cc | |||
| epoch_ctrl_op_test.cc | |||
| equalize_op_test.cc | |||
| execute_test.cc | |||
| execution_tree_test.cc | |||
| fill_op_test.cc | |||
| global_context_test.cc | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "common/cvop_common.h" | |||
| #include "minddata/dataset/include/execute.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| #include "minddata/dataset/include/vision.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestExecute : public UT::CVOP::CVOpCommon { | |||
| protected: | |||
| MindDataTestExecute() : CVOpCommon() {} | |||
| std::shared_ptr<Tensor> output_tensor_; | |||
| }; | |||
| TEST_F(MindDataTestExecute, TestOp1) { | |||
| MS_LOG(INFO) << "Doing testCrop."; | |||
| // Crop params | |||
| std::shared_ptr<TensorOperation> center_crop = vision::CenterCrop({30}); | |||
| std::shared_ptr<Tensor> out_image = Execute(std::move(center_crop))(input_tensor_); | |||
| EXPECT_NE(out_image, nullptr); | |||
| EXPECT_EQ(30, out_image->shape()[0]); | |||
| EXPECT_EQ(30, out_image->shape()[1]); | |||
| } | |||