| @@ -54,6 +54,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kGenerator, &DEPipeline::ParseGeneratorOp}, | |||
| {kTfReader, &DEPipeline::ParseTFReaderOp}, | |||
| {kProject, &DEPipeline::ParseProjectOp}, | |||
| {kTake, &DEPipeline::ParseTakeOp}, | |||
| {kImageFolder, &DEPipeline::ParseImageFolderOp}, | |||
| {kMnist, &DEPipeline::ParseMnistOp}, | |||
| {kManifest, &DEPipeline::ParseManifestOp}, | |||
| @@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp | |||
| return Status::OK(); | |||
| } | |||
| DsOpPtr DEPipeline::ParseTakeOp(const py::dict &args) const { return DsOpPtr(); } | |||
| Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| if (args["count"].is_none()) { | |||
| std::string err_msg = "Error: count is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::shared_ptr<TakeOp> op; | |||
| RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>(); | |||
| @@ -116,7 +116,7 @@ class DEPipeline { | |||
| Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| DsOpPtr ParseTakeOp(const py::dict &args) const; | |||
| Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -38,6 +38,7 @@ | |||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "dataset/engine/datasetops/source/storage_op.h" | |||
| #include "dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "dataset/engine/datasetops/take_op.h" | |||
| #include "dataset/engine/datasetops/zip_op.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "dataset/util/status.h" | |||
| @@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT | |||
| parallel_op.cc | |||
| pipeline_op.cc | |||
| batch_op.cc | |||
| batch_op.cc | |||
| device_queue_op.cc | |||
| map_op.cc | |||
| project_op.cc | |||
| rename_op.cc | |||
| repeat_op.cc | |||
| skip_op.cc | |||
| take_op.cc | |||
| shuffle_op.cc | |||
| zip_op.cc | |||
| ) | |||
| @@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work | |||
| // If buffer is none or the rows of buffer is 0, | |||
| // then get a buffer from child. | |||
| if (!buf || buf->NumRows() == 0) { | |||
| if (buf && buf->eof()) { | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| } | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| @@ -0,0 +1,146 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <utility> | |||
| #include "common/utils.h" | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/datasetops/take_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder constructor. Creates the builder object. | |||
| TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {} | |||
| Status TakeOp::Builder::SanityCheck() const { | |||
| if (build_max_takes_ <= 0) { | |||
| std::string err_msg("Take count must be greater than 0."); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // The builder "build" method creates the final object. | |||
| Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<TakeOp>(build_max_takes_); | |||
| return Status::OK(); | |||
| } | |||
| // Constructor of the TakeOp. | |||
| TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {} | |||
| // A print method typically used for debugging | |||
| void TakeOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call base class printer first | |||
| PipelineOp::Print(out, show_all); | |||
| // Then display our own stuff | |||
| out << "TakeOp:" | |||
| << "\nCurrent take count: " << take_count_ << "\nMax take count: " << max_takes_; | |||
| } | |||
| // This function will be call muti times to returns the buffer, when meet required max take count or meet | |||
| // EOF buffer then this will stop. | |||
| Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) { | |||
| if (child_.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node."); | |||
| } | |||
| std::unique_ptr<DataBuffer> buf; | |||
| bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); | |||
| if (take_count_ == max_takes_) { | |||
| if (state_ == OpState::kDeOpRunning) { | |||
| MS_LOG(INFO) << "meet max count and push-back eoe buffer."; | |||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| *p_buffer = std::move(eoe_buffer); | |||
| state_ = OpState::kDeOpIdle; | |||
| // Reset the count and drain | |||
| if (!last_repeat) { | |||
| take_count_ = 0; | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| while (!buf->eoe() && !buf->eof()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(INFO) << "meet max count and push-back eof buffer."; | |||
| auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||
| *p_buffer = std::move(eof_buffer); | |||
| take_count_ = 0; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); | |||
| // Loop until non EOE is received | |||
| if (buf->eoe()) { | |||
| take_count_ = 0; | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| } | |||
| // Check if the last buf is next eof | |||
| if (buf->eof()) { | |||
| *p_buffer = std::move(buf); | |||
| return Status::OK(); | |||
| } | |||
| // Get buffer and push back when take_count is still small | |||
| if (take_count_ < max_takes_) { | |||
| RETURN_IF_NOT_OK(FillBuffer(&buf, p_buffer)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Function FillBuffer mainly prepare the buffer for returning | |||
| Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) { | |||
| int32_t buffer_size = (*buffer)->NumRows(); | |||
| if (take_count_ + buffer_size < max_takes_) { | |||
| *data_buffer = std::move(*buffer); | |||
| take_count_ = take_count_ + buffer_size; | |||
| } else { | |||
| MS_LOG(INFO) << "In last buffer: Push one buffer."; | |||
| std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>(); | |||
| while (take_count_ < max_takes_) { | |||
| TensorRow new_row; | |||
| RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row)); | |||
| take_count_++; | |||
| new_tensor_table->push_back(new_row); | |||
| } | |||
| (*buffer)->set_tensor_table(std::move(new_tensor_table)); | |||
| *data_buffer = std::move(*buffer); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Class functor operator () override. | |||
| // Most dataset ops operate by launching a thread (see ExecutionTree). | |||
| // However, the TakeOp is defined as a inlined operator, so it is invalid to launch the | |||
| // functor since this op runs inlined inside another operator. The function is overloaded to | |||
| // ensure that it is not called by mistake (it will generate an error). | |||
| Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); } | |||
| Status TakeOp::PrepareNodePostAction() { | |||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | |||
| tree_->AddToRepeatStack(shared_from_this()); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TakeOp : public PipelineOp { | |||
| public: | |||
| // The nested builder class inside of the TakeOp is used to help manage all of the arguments | |||
| // for constructing it. This take op is very simple though, so this builder is really just | |||
| // provided for a consistent look and feel for creators of Dataset operators overall. | |||
| class Builder { | |||
| public: | |||
| // Builder constructor. Creates the builder object. | |||
| // @note No default args | |||
| // @param count - The number of takes to do | |||
| // @return This is a constructor. | |||
| explicit Builder(int32_t count); | |||
| // Default destructor | |||
| ~Builder() = default; | |||
| // The builder "build" method creates the final object. | |||
| // @return shared_ptr to the new StorageOp object | |||
| Status Build(std::shared_ptr<TakeOp> *); | |||
| private: | |||
| int32_t build_max_takes_; | |||
| Status SanityCheck() const; | |||
| }; | |||
| // Constructor of the TakeOp. | |||
| // @note The builder class should be used to call it | |||
| // @param count - The number of takes to do | |||
| explicit TakeOp(int32_t count); | |||
| // Destructor | |||
| ~TakeOp() = default; | |||
| // A print method typically used for debugging | |||
| // @param out - The output stream to write output to | |||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // << Stream output operator overload | |||
| // @notes This allows you to write the debug print info using stream operators | |||
| // @param out - reference to the output stream being overloaded | |||
| // @param ro - reference to the TakeOp to display | |||
| // @return - the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) { | |||
| ro.Print(out, false); | |||
| return out; | |||
| } | |||
| // Class functor operator () override. | |||
| // Most dataset ops operate by launching a thread (see ExecutionTree). | |||
| // However, the TakeOp is defined as a inlined operator, so it is invalid to launch the | |||
| // functor since this op runs inlined inside another operator. The function is overloaded to | |||
| // ensure that it is not called by mistake (it will generate an error). | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Gets a buffer from the child node. The caller is typically our parent node. | |||
| // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way, | |||
| // this function will retry to pop the connector again and will get the non-EOE buffer if any. | |||
| // @param p_buffer - output pointer to the buffer that it will fetch. | |||
| // @param worker_id - The worker id | |||
| // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | |||
| // @return Status - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call it's superclass version first | |||
| // before providing their own implementations. | |||
| Status PrepareNodePostAction() override; | |||
| private: | |||
| int32_t max_takes_; // The number of takes that the user requested | |||
| int32_t take_count_; // A counter for the current number of executed takes | |||
| Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ | |||
| @@ -36,7 +36,7 @@ from mindspore import log as logger | |||
| from . import samplers | |||
| from .iterators import DictIterator, TupleIterator | |||
| from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ | |||
| check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ | |||
| check_zip_dataset, check_add_column | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| @@ -442,6 +442,33 @@ class Dataset: | |||
| """ | |||
| return SkipDataset(self, count) | |||
@check_take
def take(self, count=-1):
    """
    Take at most the given number of elements from the dataset.

    Note:
        1. If count is greater than the number of elements in the dataset, or equal to -1,
           all the elements in the dataset will be taken.
        2. The relative order of take and batch matters. If take comes before the batch
           operation it limits the number of rows, otherwise it limits the number of batches.

    Args:
        count (int, optional): Number of elements to be taken from the dataset (default=-1).

    Returns:
        TakeDataset, dataset taken (this dataset itself when count is -1).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # data is an instance of Dataset object.
        >>> # creates a dataset containing 50 elements.
        >>> data = data.take(50)
    """
    # -1 means "no limit", so no Take node is inserted into the pipeline
    take_everything = count == -1
    return self if take_everything else TakeDataset(self, count)
| @check_zip_dataset | |||
| def zip(self, datasets): | |||
| """ | |||
| @@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp): | |||
| """ | |||
| return self.count | |||
| class SkipDataset(DatasetOp): | |||
| """ | |||
| The result of applying Skip operator to the input Dataset. | |||
| @@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp): | |||
| output_size = child_size - self.count | |||
| return output_size | |||
class TakeDataset(DatasetOp):
    """
    The result of applying Take operator to the input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to take elements from.
        count (int): Number of elements to be taken from the dataset.
    """

    def __init__(self, input_dataset, count):
        super().__init__()
        self.count = count
        # Link this node and its input to each other
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs

    def get_args(self):
        # Start from the base-class arguments and add the take count
        take_args = super().get_args()
        take_args["count"] = self.count
        return take_args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, number of batches.
        """
        # The output can never be larger than the input nor larger than the take count
        return min(self.count, self.input[0].get_dataset_size())
| class ZipDataset(DatasetOp): | |||
| """ | |||
| The result of applying Zip operator to the input Dataset. | |||
| @@ -129,6 +129,8 @@ class Iterator: | |||
| op_type = OpName.REPEAT | |||
| elif isinstance(dataset, de.SkipDataset): | |||
| op_type = OpName.SKIP | |||
| elif isinstance(dataset, de.TakeDataset): | |||
| op_type = OpName.TAKE | |||
| elif isinstance(dataset, de.StorageDataset): | |||
| op_type = OpName.STORAGE | |||
| elif isinstance(dataset, de.ImageFolderDatasetV2): | |||
| @@ -304,6 +304,9 @@ def create_node(node): | |||
| elif dataset_op == 'SkipDataset': | |||
| pyobj = de.Dataset().skip(node.get('count')) | |||
| elif dataset_op == 'TakeDataset': | |||
| pyobj = de.Dataset().take(node.get('count')) | |||
| elif dataset_op == 'MapDataset': | |||
| tensor_ops = construct_tensor_ops(node.get('operations')) | |||
| pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'), | |||
| @@ -602,7 +602,7 @@ def check_batch_size(batch_size): | |||
def check_count(count):
    """
    Validate a count argument.

    Args:
        count (int): the value to check; valid values are -1 or a positive integer not above INT32_MAX.

    Raises:
        ValueError: if count is zero, negative other than -1, or greater than INT32_MAX.
    """
    check_type(count, 'count', int)
    if (count <= 0 and count != -1) or count > INT32_MAX:
        # A single generic message: the stale "repeat count" raise that preceded it made this one
        # unreachable and named only one of the operations validated here.
        raise ValueError("count should be either -1 or positive integer.")
| def check_columns(columns, name): | |||
| @@ -709,6 +709,7 @@ def check_repeat(method): | |||
| return new_method | |||
| def check_skip(method): | |||
| """check the input arguments of skip.""" | |||
| @wraps(method) | |||
| @@ -724,6 +725,21 @@ def check_skip(method): | |||
| return new_method | |||
def check_take(method):
    """check the input arguments of take."""

    @wraps(method)
    def new_method(*args, **kwargs):
        # Resolve the call arguments by name and validate 'count' before delegating
        check_count(make_param_dict(method, args, kwargs).get('count'))
        return method(*args, **kwargs)

    return new_method
| def check_zip(method): | |||
| """check the input arguments of zip.""" | |||
| @wraps(method) | |||
| @@ -759,6 +775,7 @@ def check_zip_dataset(method): | |||
| return new_method | |||
| def check_rename(method): | |||
| """check the input arguments of rename.""" | |||
| @wraps(method) | |||
| @@ -64,6 +64,7 @@ SET(DE_UT_SRCS | |||
| voc_op_test.cc | |||
| cifar_op_test.cc | |||
| celeba_op_test.cc | |||
| take_op_test.cc | |||
| ) | |||
| add_executable(de_ut_tests ${DE_UT_SRCS}) | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "common/common.h" | |||
| #include "common/utils.h" | |||
| #include "dataset/core/client.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace common = mindspore::common; | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestTakeOp : public UT::DatasetOpTesting {}; | |||
| TEST_F(MindDataTestTakeOp, TestTakeProject) { | |||
| // Start with an empty execution tree | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| std::string dataset_path; | |||
| dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | |||
| // TFReaderOp | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| TFReaderOp::Builder builder; | |||
| builder.SetDatasetFilesList({dataset_path}) | |||
| .SetRowsPerBuffer(16) | |||
| .SetWorkerConnectorSize(16) | |||
| .SetNumWorkers(16); | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {}); | |||
| builder.SetDataSchema(std::move(schema)); | |||
| Status rc = builder.Build(&my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // TakeOp | |||
| std::shared_ptr<TakeOp> my_take_op; | |||
| TakeOp::Builder builder_take(5); | |||
| rc = builder_take.Build(&my_take_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(my_take_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Set children/root layout. | |||
| rc = my_take_op->AddChild(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssignRoot(my_take_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration."; | |||
| rc = my_tree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(my_tree); | |||
| TensorRow tensor_list; | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int row_count = 0; | |||
| while (!tensor_list.empty()) { | |||
| MS_LOG(INFO) << "Row display for row #: " << row_count << "."; | |||
| // Display the tensor by calling the printer on it | |||
| for (int i = 0; i < tensor_list.size(); i++) { | |||
| std::ostringstream ss; | |||
| ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; | |||
| MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; | |||
| } | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| row_count++; | |||
| } | |||
| ASSERT_EQ(row_count, 5); | |||
| } | |||
| @@ -0,0 +1,317 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| import numpy as np | |||
# In generator dataset: Number of rows is 3, its value is 0, 1, 2
def generator():
    """Yield three one-column rows whose single array holds 0, 1 and 2 respectively."""
    row_id = 0
    while row_id < 3:
        yield (np.array([row_id]),)
        row_id += 1
# In generator dataset: Number of rows is 10, its value is 0, 1, 2 ... 9
def generator_10():
    """Yield ten one-column rows whose single array holds 0 through 9 respectively."""
    yield from ((np.array([value]),) for value in range(10))
def test_take_01():
    """
    Test take: the source has 3 rows and 1 row is taken, so take stops before eoe and eof arrive
    """
    logger.info("test_take_01")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(1).repeat(2)

    # Each repetition must deliver only the first source row (value 0)
    for row in data1:
        assert row[0][0] == 0

    # Second pass over the pipeline: 1 row per repetition, 2 repetitions
    total_rows = sum(1 for _ in data1)
    assert total_rows == 2
def test_take_02():
    """
    Test take: the source has 3 rows and 2 rows are taken, in this case: will meet eoe
    """
    logger.info("test_take_02")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(2).repeat(2)

    # Rows cycle through 0, 1 once per repetition
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 2

    # Second pass over the pipeline: 2 rows per repetition, 2 repetitions
    total_rows = sum(1 for _ in data1)
    assert total_rows == 4
def test_take_03():
    """
    Test take: the source has 3 rows and 3 rows are taken, in this case: will meet eoe and eof
    """
    logger.info("test_take_03")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(3).repeat(2)

    # Rows cycle through 0, 1, 2 once per repetition
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 3

    # Second pass over the pipeline: 3 rows per repetition, 2 repetitions
    total_rows = sum(1 for _ in data1)
    assert total_rows == 6
def test_take_04():
    """
    Test take: the source has 3 rows and 4 rows are requested, which is more than the total rows
    """
    logger.info("test_take_04")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(4).repeat(2)

    # Only the 3 available rows come through, cycling 0, 1, 2 per repetition
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 3

    # Second pass over the pipeline: 3 rows per repetition, 2 repetitions
    total_rows = sum(1 for _ in data1)
    assert total_rows == 6
def test_take_05():
    """
    Test take: there is no repeat op
    """
    logger.info("test_take_05")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(2)

    # The first two source rows arrive in order: 0, 1
    for idx, row in enumerate(data1):
        assert row[0][0] == idx

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 2
def test_take_06():
    """
    Test take: repeat is before take
    """
    logger.info("test_take_06")
    data1 = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(4)

    # The repeated stream is 0, 1, 2, 0, 1, 2; take keeps the first 4 of it
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 3

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 4
def test_take_07():
    """
    Test take: take is before batch, so the N in take(N) counts rows
    """
    logger.info("test_take_07")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(2).batch(2)

    # 2 rows taken, batched by 2, give a single batch
    total_batches = sum(1 for _ in data1)
    assert total_batches == 1
def test_take_08():
    """
    Test take: take is after batch, so the N in take(N) counts batches
    """
    logger.info("test_take_08")
    data1 = ds.GeneratorDataset(generator, ["data"]).batch(2).take(2)

    # 3 rows batched by 2 give 2 batches, and both are taken
    total_batches = sum(1 for _ in data1)
    assert total_batches == 2
def test_take_09():
    """
    Test take: take count is -1, and read the whole dataset, take after repeat
    """
    logger.info("test_take_09")
    data1 = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(-1)

    # take(-1) keeps everything: 0, 1, 2 for each of the 2 repetitions
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 3

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 6
def test_take_10():
    """
    Test take: take count is -1, and read the whole dataset, take before repeat
    """
    logger.info("test_take_10")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(-1).repeat(2)

    # take(-1) keeps everything: 0, 1, 2 for each of the 2 repetitions
    for idx, row in enumerate(data1):
        assert row[0][0] == idx % 3

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 6
def test_take_11():
    """
    Test take: batch first, then do repeat and take operation
    """
    logger.info("test_take_11")
    data1 = ds.GeneratorDataset(generator, ["data"]).batch(2).repeat(2).take(-1)

    # Batches alternate between one starting at 0 and one starting at 2
    for idx, batch in enumerate(data1):
        assert batch[0][0] == 2 * (idx % 2)

    # Second pass over the pipeline: 2 batches per repetition, 2 repetitions
    total_batches = sum(1 for _ in data1)
    assert total_batches == 4
def test_take_12():
    """
    Test take: take first, then do batch and repeat operation
    """
    logger.info("test_take_12")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(2).batch(2).repeat(2)

    # Each repetition holds one batch, and it starts with row 0
    for batch in data1:
        assert batch[0][0] == 0

    # Second pass over the pipeline: 1 batch per repetition, 2 repetitions
    total_batches = sum(1 for _ in data1)
    assert total_batches == 2
def test_take_13():
    """
    Test take: skip first, then do take, batch and repeat operation
    """
    logger.info("test_take_13")
    data1 = ds.GeneratorDataset(generator, ["data"]).skip(2).take(-1).batch(2).repeat(2)

    # After skipping 2 rows only the row with value 2 is left in each repetition
    for batch in data1:
        assert batch[0][0] == 2

    # Second pass over the pipeline: 1 batch per repetition, 2 repetitions
    total_batches = sum(1 for _ in data1)
    assert total_batches == 2
def test_take_14():
    """
    Test take: take first, then do batch, skip and repeat operation
    """
    logger.info("test_take_14")
    data1 = ds.GeneratorDataset(generator, ["data"]).take(-1).batch(2).skip(1).repeat(2)

    # Skipping the first batch leaves the batch that holds the row with value 2
    for batch in data1:
        assert batch[0][0] == 2

    # Second pass over the pipeline: 1 batch per repetition, 2 repetitions
    total_batches = sum(1 for _ in data1)
    assert total_batches == 2
def test_take_15():
    """
    Test take: large amount data, take a part, then do skip operation
    """
    logger.info("test_take_15")
    data1 = ds.GeneratorDataset(generator_10, ["data"]).take(6).skip(2)

    # Rows 0..5 are taken and the first 2 of them skipped, leaving 2, 3, 4, 5
    for idx, row in enumerate(data1):
        assert row[0][0] == idx + 2

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 4
def test_take_16():
    """
    Test take: large amount data, skip a part, then do take operation
    """
    logger.info("test_take_16")
    data1 = ds.GeneratorDataset(generator_10, ["data"]).skip(3).take(5)

    # Rows 0..2 are skipped and the next 5 taken, leaving 3, 4, 5, 6, 7
    for idx, row in enumerate(data1):
        assert row[0][0] == idx + 3

    # Second pass over the pipeline
    total_rows = sum(1 for _ in data1)
    assert total_rows == 5
if __name__ == '__main__':
    # Run every take test in numeric order when the file is executed as a script
    all_take_tests = (
        test_take_01, test_take_02, test_take_03, test_take_04,
        test_take_05, test_take_06, test_take_07, test_take_08,
        test_take_09, test_take_10, test_take_11, test_take_12,
        test_take_13, test_take_14, test_take_15, test_take_16,
    )
    for take_test in all_take_tests:
        take_test()
    logger.info('== test take operation finished ==')