@@ -54,6 +54,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
   {kGenerator, &DEPipeline::ParseGeneratorOp},
   {kTfReader, &DEPipeline::ParseTFReaderOp},
   {kProject, &DEPipeline::ParseProjectOp},
+  {kTake, &DEPipeline::ParseTakeOp},
   {kImageFolder, &DEPipeline::ParseImageFolderOp},
   {kMnist, &DEPipeline::ParseMnistOp},
   {kManifest, &DEPipeline::ParseManifestOp},
@@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp
   return Status::OK();
 }
 
-DsOpPtr DEPipeline::ParseTakeOp(const py::dict &args) const { return DsOpPtr(); }
+Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
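+  // Note: count == -1 ("take all") is short-circuited in the Python layer, so it never reaches this parser.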
| if (args["count"].is_none()) { | |||||
| std::string err_msg = "Error: count is invalid or not set."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| std::shared_ptr<TakeOp> op; | |||||
| RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); | |||||
| *ptr = op; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | ||||
| std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>(); | std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>(); | ||||
@@ -116,7 +116,7 @@ class DEPipeline {
   Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
-  DsOpPtr ParseTakeOp(const py::dict &args) const;
+  Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
 
   Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@@ -38,6 +38,7 @@
 #include "dataset/engine/datasetops/source/mindrecord_op.h"
 #include "dataset/engine/datasetops/source/storage_op.h"
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/take_op.h"
 #include "dataset/engine/datasetops/zip_op.h"
 #include "dataset/engine/execution_tree.h"
 #include "dataset/util/status.h"
@@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT
     parallel_op.cc
     pipeline_op.cc
     batch_op.cc
-    batch_op.cc
     device_queue_op.cc
     map_op.cc
     project_op.cc
    rename_op.cc
     repeat_op.cc
     skip_op.cc
+    take_op.cc
     shuffle_op.cc
     zip_op.cc
 )
@@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
   // If buffer is none or the rows of buffer is 0,
   // then get a buffer from child.
   if (!buf || buf->NumRows() == 0) {
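+    // If we already hold an EOF buffer, forward it as-is instead of pulling from the drained child again.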
+    if (buf && buf->eof()) {
+      *p_buffer = std::move(buf);
+      return Status::OK();
+    }
     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
   }
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <utility>
+
+#include "common/utils.h"
+#include "dataset/engine/data_buffer.h"
+#include "dataset/engine/datasetops/take_op.h"
+#include "dataset/engine/db_connector.h"
+#include "dataset/engine/execution_tree.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {}
+
+Status TakeOp::Builder::SanityCheck() const {
+  if (build_max_takes_ <= 0) {
+    std::string err_msg("Take count must be greater than 0.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<TakeOp>(build_max_takes_);
+  return Status::OK();
+}
+
+// Constructor of the TakeOp.
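+// The 0 passed to PipelineOp is the op connector size: TakeOp is an inlined operator,
+// so it owns no output connector of its own.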
+TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
+
+// A print method typically used for debugging
+void TakeOp::Print(std::ostream &out, bool show_all) const {
+  // Call base class printer first
+  PipelineOp::Print(out, show_all);
+  // Then display our own stuff
+  out << "TakeOp:"
+      << "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_;
+}
+
+// This function will be called multiple times to return buffers. It stops handing out rows
+// once the required max take count has been reached or an EOF buffer is received.
+Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
+  if (child_.empty()) {
+    RETURN_STATUS_UNEXPECTED("TakeOp cannot be a leaf node.");
+  }
+
+  std::unique_ptr<DataBuffer> buf;
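+  // last_repeat is true when this op is not under a RepeatOp at all, or when the ancestor
+  // RepeatOp is on its final pass; the child only needs to be drained when another pass follows.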
+  bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat);
+  if (take_count_ == max_takes_) {
+    if (state_ == OpState::kDeOpRunning) {
+      MS_LOG(INFO) << "Max count reached; pushing back an EOE buffer.";
+      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+      *p_buffer = std::move(eoe_buffer);
+      state_ = OpState::kDeOpIdle;
+
+      // Reset the count and drain the leftover rows from the child
+      if (!last_repeat) {
+        take_count_ = 0;
+        RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+        while (!buf->eoe() && !buf->eof()) {
+          RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+        }
+      }
+    } else {
+      MS_LOG(INFO) << "Max count reached; pushing back an EOF buffer.";
+      auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+      *p_buffer = std::move(eof_buffer);
+      take_count_ = 0;
+    }
+    return Status::OK();
+  }
+
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+
+  // On EOE, reset the count so the next repeat pass starts fresh, then propagate the EOE
+  if (buf->eoe()) {
+    take_count_ = 0;
+    *p_buffer = std::move(buf);
+    return Status::OK();
+  }
+
+  // Propagate EOF unchanged
+  if (buf->eof()) {
+    *p_buffer = std::move(buf);
+    return Status::OK();
+  }
+
+  // Trim/forward the buffer while the take count is still below the maximum
+  if (take_count_ < max_takes_) {
+    RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer));
+  }
+  return Status::OK();
+}
+
+// FillBuffer prepares the buffer for returning, trimming it to the remaining row budget when needed
+Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
+  int32_t buffer_size = (*buffer)->NumRows();
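+  // If the whole buffer still fits under the limit, forward it untouched; otherwise
+  // pop rows one by one until the limit is reached and return the trimmed buffer.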
+  if (take_count_ + buffer_size < max_takes_) {
+    *data_buffer = std::move(*buffer);
+    take_count_ = take_count_ + buffer_size;
+  } else {
+    MS_LOG(INFO) << "In last buffer: Push one buffer.";
+    std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
+    while (take_count_ < max_takes_) {
+      TensorRow new_row;
+      RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
+      take_count_++;
+      new_tensor_table->push_back(new_row);
+    }
+    (*buffer)->set_tensor_table(std::move(new_tensor_table));
+    *data_buffer = std::move(*buffer);
+  }
+  return Status::OK();
+}
+
+// Class functor operator () override.
+// Most dataset ops operate by launching a thread (see ExecutionTree).
+// However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
+// functor since this op runs inlined inside another operator. The function is overloaded to
+// ensure that it is not called by mistake (it will generate an error).
+Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
+
+Status TakeOp::PrepareNodePostAction() {
+  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
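+  // Register on the tree's repeat stack so that an ancestor RepeatOp, if any, can flag this op as repeated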
+  tree_->AddToRepeatStack(shared_from_this());
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class TakeOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the TakeOp is used to help manage all of the arguments
+  // for constructing it. This take op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @param count - The number of takes to do
+    // @return This is a constructor.
+    explicit Builder(int32_t count);
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new TakeOp object
+    Status Build(std::shared_ptr<TakeOp> *);
+
+   private:
+    int32_t build_max_takes_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the TakeOp.
+  // @note The builder class should be used to call it
+  // @param count - The number of takes to do
+  explicit TakeOp(int32_t count);
+
+  // Destructor
+  ~TakeOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param ro - reference to the TakeOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // Most dataset ops operate by launching a thread (see ExecutionTree).
+  // However, the TakeOp is defined as an inlined operator, so it is invalid to launch the
+  // functor since this op runs inlined inside another operator. The function is overloaded to
+  // ensure that it is not called by mistake (it will generate an error).
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Gets a buffer from the child node. The caller is typically our parent node.
+  // @note This function sets the `retry_if_eoe` flag when popping from the child connector. This way,
+  // this function will retry to pop the connector again and will get the non-EOE buffer if any.
+  // @param p_buffer - output pointer to the buffer that it will fetch.
+  // @param worker_id - The worker id
+  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
+  // @return Status - The error code return
+  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
+ private:
+  int32_t max_takes_;   // The number of takes that the user requested
+  int32_t take_count_;  // A counter for the current number of executed takes
+
+  Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
@@ -36,7 +36,7 @@ from mindspore import log as logger
 from . import samplers
 from .iterators import DictIterator, TupleIterator
 from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \
-    check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
+    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
     check_zip_dataset, check_add_column
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
@@ -442,6 +442,33 @@ class Dataset:
         """
         return SkipDataset(self, count)
 
+    @check_take
+    def take(self, count=-1):
+        """
+        Takes at most the given number of elements from the dataset.
+
+        Note:
+            1. If count is greater than the number of elements in the dataset or equal to -1,
+               all the elements in the dataset will be taken.
+            2. The order of take and batch matters. If take comes before batch, it takes the given
+               number of rows; otherwise, it takes the given number of batches.
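+               For example, on a dataset of 6 rows, take(4) followed by batch(2) yields 2 batches,
+               while batch(2) followed by take(4) yields all 3 existing batches.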
+
+        Args:
+            count (int, optional): Number of elements to be taken from the dataset (default=-1).
+
+        Returns:
+            TakeDataset, dataset taken.
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of Dataset object.
+            >>> # Create a dataset containing at most 50 elements.
+            >>> data = data.take(50)
+        """
+        if count == -1:
+            return self
+
+        return TakeDataset(self, count)
+
     @check_zip_dataset
     def zip(self, datasets):
         """
@@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp):
         """
         return self.count
 
+
 class SkipDataset(DatasetOp):
     """
     The result of applying Skip operator to the input Dataset.
@@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp):
         output_size = child_size - self.count
         return output_size
 
+
+class TakeDataset(DatasetOp):
+    """
+    The result of applying the Take operator to the input Dataset.
+
+    Args:
+        input_dataset (Dataset): Input Dataset from which elements will be taken.
+        count (int): Number of elements to be taken from the dataset.
+    """
+
+    def __init__(self, input_dataset, count):
+        super().__init__()
+        self.count = count
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        self._input_indexs = input_dataset.input_indexs
+
+    def get_args(self):
+        args = super().get_args()
+        args["count"] = self.count
+        return args
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Returns:
+            Number, number of batches.
+        """
+        child_size = self.input[0].get_dataset_size()
+        if child_size < self.count:
+            return child_size
+        return self.count
+
+
 class ZipDataset(DatasetOp):
     """
     The result of applying Zip operator to the input Dataset.
@@ -129,6 +129,8 @@ class Iterator:
             op_type = OpName.REPEAT
         elif isinstance(dataset, de.SkipDataset):
             op_type = OpName.SKIP
+        elif isinstance(dataset, de.TakeDataset):
+            op_type = OpName.TAKE
         elif isinstance(dataset, de.StorageDataset):
             op_type = OpName.STORAGE
         elif isinstance(dataset, de.ImageFolderDatasetV2):
@@ -304,6 +304,9 @@ def create_node(node):
     elif dataset_op == 'SkipDataset':
         pyobj = de.Dataset().skip(node.get('count'))
 
+    elif dataset_op == 'TakeDataset':
+        pyobj = de.Dataset().take(node.get('count'))
+
     elif dataset_op == 'MapDataset':
         tensor_ops = construct_tensor_ops(node.get('operations'))
         pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
@@ -602,7 +602,7 @@ def check_batch_size(batch_size):
 def check_count(count):
     check_type(count, 'count', int)
     if (count <= 0 and count != -1) or count > INT32_MAX:
-        raise ValueError("repeat count should be either -1 or positive integer.")
+        raise ValueError("count should be either -1 or a positive integer.")
@@ -709,6 +709,7 @@ def check_repeat(method):
     return new_method
 
+
 def check_skip(method):
     """check the input arguments of skip."""
     @wraps(method)
@@ -724,6 +725,21 @@ def check_skip(method):
     return new_method
 
+
+def check_take(method):
+    """check the input arguments of take."""
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+        count = param_dict.get('count')
+        check_count(count)
+        return method(*args, **kwargs)
+
+    return new_method
+
+
 def check_zip(method):
     """check the input arguments of zip."""
     @wraps(method)
@@ -759,6 +775,7 @@ def check_zip_dataset(method):
     return new_method
 
+
 def check_rename(method):
     """check the input arguments of rename."""
     @wraps(method)
@@ -64,6 +64,7 @@ SET(DE_UT_SRCS
     voc_op_test.cc
     cifar_op_test.cc
     celeba_op_test.cc
+    take_op_test.cc
     )
 
 add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "common/common.h"
+#include "common/utils.h"
+#include "dataset/core/client.h"
+#include "gtest/gtest.h"
+#include "utils/log_adapter.h"
+
+namespace common = mindspore::common;
+
+using namespace mindspore::dataset;
+using mindspore::MsLogLevel::INFO;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::LogStream;
+
+class MindDataTestTakeOp : public UT::DatasetOpTesting {};
+
+TEST_F(MindDataTestTakeOp, TestTakeProject) {
+  // Start with an empty execution tree
+  auto my_tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path;
+  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
+
+  // TFReaderOp
+  std::shared_ptr<TFReaderOp> my_tfreader_op;
+  TFReaderOp::Builder builder;
+  builder.SetDatasetFilesList({dataset_path})
+    .SetRowsPerBuffer(16)
+    .SetWorkerConnectorSize(16)
+    .SetNumWorkers(16);
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
+  builder.SetDataSchema(std::move(schema));
+  Status rc = builder.Build(&my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  // TakeOp
+  std::shared_ptr<TakeOp> my_take_op;
+  TakeOp::Builder builder_take(5);
+  rc = builder_take.Build(&my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = my_tree->AssociateNode(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssociateNode(my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  // Set children/root layout.
+  rc = my_take_op->AddChild(my_tfreader_op);
+  ASSERT_TRUE(rc.IsOk());
+  rc = my_tree->AssignRoot(my_take_op);
+  ASSERT_TRUE(rc.IsOk());
+
+  MS_LOG(INFO) << "Launching tree and begin iteration.";
+  rc = my_tree->Prepare();
+  ASSERT_TRUE(rc.IsOk());
+
+  rc = my_tree->Launch();
+  ASSERT_TRUE(rc.IsOk());
+
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(my_tree);
+  TensorRow tensor_list;
+  rc = di.FetchNextTensorRow(&tensor_list);
+  ASSERT_TRUE(rc.IsOk());
+
+  int row_count = 0;
+  while (!tensor_list.empty()) {
+    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";
+
+    // Display the tensor by calling the printer on it
+    for (size_t i = 0; i < tensor_list.size(); i++) {
+      std::ostringstream ss;
+      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
+      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
+    }
+
+    rc = di.FetchNextTensorRow(&tensor_list);
+    ASSERT_TRUE(rc.IsOk());
+    row_count++;
+  }
+
+  ASSERT_EQ(row_count, 5);
+}
@@ -0,0 +1,317 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+
+import mindspore.dataset as ds
+from mindspore import log as logger
+
+
+# Generator dataset with 3 rows; the values are 0, 1, 2
+def generator():
+    for i in range(3):
+        yield (np.array([i]),)
+
+
+# Generator dataset with 10 rows; the values are 0, 1, ..., 9
+def generator_10():
+    for i in range(10):
+        yield (np.array([i]),)
+
+
+def test_take_01():
+    """
+    Test take: the source has 3 rows and we take 1 row; in this case we will not meet EOE or EOF
+    """
+    logger.info("test_take_01")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 0 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_02():
+    """
+    Test take: the source has 3 rows and we take 2 rows; in this case we will meet EOE
+    """
+    logger.info("test_take_02")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_03():
+    """
+    Test take: the source has 3 rows and we take 3 rows; in this case we will meet EOE and EOF
+    """
+    logger.info("test_take_03")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(3)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_04():
+    """
+    Test take: the source has 3 rows and we take 4 rows, which is more than the total number of rows
+    """
+    logger.info("test_take_04")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(4)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_05():
+    """
+    Test take: there is no repeat op
+    """
+    logger.info("test_take_05")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_06():
+    """
+    Test take: repeat is applied before take
+    """
+    logger.info("test_take_06")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.repeat(2)
+    data1 = data1.take(4)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_07():
+    """
+    Test take: take is applied before batch, so take(N) refers to N rows
+    """
+    logger.info("test_take_07")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(2)
+    data1 = data1.batch(2)
+    assert sum([1 for _ in data1]) == 1
+
+
+def test_take_08():
+    """
+    Test take: take is applied after batch, so take(N) refers to N batches
+    """
+    logger.info("test_take_08")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.batch(2)
+    data1 = data1.take(2)
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_09():
+    """
+    Test take: take count is -1, which reads the whole dataset; take is applied after repeat
+    """
+    logger.info("test_take_09")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.repeat(2)
+    data1 = data1.take(-1)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_10():
+    """
+    Test take: take count is -1, which reads the whole dataset; take is applied before repeat
+    """
+    logger.info("test_take_10")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(-1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert i % 3 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 6
+
+
+def test_take_11():
+    """
+    Test take: batch first, then do the repeat and take operations
+    """
+    logger.info("test_take_11")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+    data1 = data1.take(-1)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 * (i % 2) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_12():
+    """
+    Test take: take first, then do the batch and repeat operations
+    """
+    logger.info("test_take_12")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(2)
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 0 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_13():
+    """
+    Test take: skip first, then do the take, batch and repeat operations
+    """
+    logger.info("test_take_13")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.skip(2)
+    data1 = data1.take(-1)
+    data1 = data1.batch(2)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_14():
+    """
+    Test take: take first, then do the batch, skip and repeat operations
+    """
+    logger.info("test_take_14")
+    data1 = ds.GeneratorDataset(generator, ["data"])
+    data1 = data1.take(-1)
+    data1 = data1.batch(2)
+    data1 = data1.skip(1)
+    data1 = data1.repeat(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert 2 == d[0][0]
+
+    assert sum([1 for _ in data1]) == 2
+
+
+def test_take_15():
+    """
+    Test take: with a larger dataset, take a part of it, then do a skip operation
+    """
+    logger.info("test_take_15")
+    data1 = ds.GeneratorDataset(generator_10, ["data"])
+    data1 = data1.take(6)
+    data1 = data1.skip(2)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert (i + 2) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 4
+
+
+def test_take_16():
+    """
+    Test take: with a larger dataset, skip a part of it, then do a take operation
+    """
+    logger.info("test_take_16")
+    data1 = ds.GeneratorDataset(generator_10, ["data"])
+    data1 = data1.skip(3)
+    data1 = data1.take(5)
+
+    # Here i refers to the index, d refers to the data element
+    for i, d in enumerate(data1):
+        assert (i + 3) == d[0][0]
+
+    assert sum([1 for _ in data1]) == 5
+
+
+if __name__ == '__main__':
+    test_take_01()
+    test_take_02()
+    test_take_03()
+    test_take_04()
+    test_take_05()
+    test_take_06()
+    test_take_07()
+    test_take_08()
+    test_take_09()
+    test_take_10()
+    test_take_11()
+    test_take_12()
+    test_take_13()
+    test_take_14()
+    test_take_15()
+    test_take_16()
+    logger.info('== test take operation finished ==')