| @@ -47,6 +47,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D | |||
| {kMap, &DEPipeline::ParseMapOp}, | |||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||
| {kZip, &DEPipeline::ParseZipOp}, | |||
| {kRename, &DEPipeline::ParseRenameOp}, | |||
| {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, | |||
| @@ -511,6 +512,17 @@ Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| if (args["count"].is_none()) { | |||
| std::string err_msg = "Error: count is invalid or not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::shared_ptr<SkipOp> op; | |||
| RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>(); | |||
| for (auto arg : args) { | |||
| @@ -42,6 +42,7 @@ enum OpName { | |||
| kBatch, | |||
| kCache, | |||
| kRepeat, | |||
| kSkip, | |||
| kTake, | |||
| kZip, | |||
| kMap, | |||
| @@ -107,6 +108,8 @@ class DEPipeline { | |||
| Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -446,6 +446,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("MINDRECORD", OpName::kMindrecord) | |||
| .value("CACHE", OpName::kCache) | |||
| .value("REPEAT", OpName::kRepeat) | |||
| .value("SKIP", OpName::kSkip) | |||
| .value("TAKE", OpName::kTake) | |||
| .value("ZIP", OpName::kZip) | |||
| .value("MAP", OpName::kMap) | |||
| @@ -32,6 +32,7 @@ | |||
| #include "dataset/engine/datasetops/project_op.h" | |||
| #include "dataset/engine/datasetops/rename_op.h" | |||
| #include "dataset/engine/datasetops/repeat_op.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/datasetops/shuffle_op.h" | |||
| #include "dataset/engine/datasetops/source/generator_op.h" | |||
| #include "dataset/engine/datasetops/source/mindrecord_op.h" | |||
| @@ -11,6 +11,7 @@ add_library(engine-datasetops OBJECT | |||
| project_op.cc | |||
| rename_op.cc | |||
| repeat_op.cc | |||
| skip_op.cc | |||
| shuffle_op.cc | |||
| zip_op.cc | |||
| ) | |||
| @@ -0,0 +1,128 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include "dataset/engine/data_buffer.h" | |||
| #include "dataset/engine/datasetops/skip_op.h" | |||
| #include "dataset/engine/db_connector.h" | |||
| #include "dataset/engine/execution_tree.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Builder constructor. Creates the builder object. | |||
| SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {} | |||
| Status SkipOp::Builder::SanityCheck() const { | |||
| if (build_max_skips_ < 0) { | |||
| std::string err_msg("Skip count must be positive integer or 0."); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // The builder "build" method creates the final object. | |||
| Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *ptr = std::make_shared<SkipOp>(build_max_skips_); | |||
| return Status::OK(); | |||
| } | |||
// Constructor of the SkipOp.
// NOTE(review): PipelineOp(0) presumably sets a zero-size output connector because
// SkipOp runs inlined inside its parent (see operator() below) -- confirm.
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}

// Destructor
SkipOp::~SkipOp() {}
| // A print method typically used for debugging | |||
| void SkipOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call base class printer first | |||
| PipelineOp::Print(out, show_all); | |||
| // Then display our own stuff | |||
| out << "SkipOp:" | |||
| << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_; | |||
| } | |||
// Since the buffer may contain multi rows, this function will drop the rows
// that need to skip in it, and then return the buffer.
// @param p_buffer - output: the first buffer after the skipped rows
// @param worker_id - the worker id
// @param retry_if_eoe - unused here; child fetches always pass true
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  // SkipOp only filters data coming from a child; it cannot produce rows itself.
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
  }
  std::unique_ptr<DataBuffer> buf;
  // Drop first max_skips_ rows
  while (skip_count_ < max_skips_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
    if (buf->eoe() || buf->eof()) {
      // Flag buffer arrived before all skips were consumed (dataset shorter
      // than the skip count); stop dropping and fall through to the handling below.
      break;
    }
    // Consider the rows of buffer more than 1
    TensorRow drop_row;
    int row_num = buf->NumRows();
    for (int i = 0; i < row_num; i++) {
      RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
      // Stop mid-buffer once exactly max_skips_ rows have been dropped,
      // so the remaining rows in this buffer are returned to the caller.
      if (++skip_count_ == max_skips_) {
        break;
      }
    }
  }
  // If buffer is none or the rows of buffer is 0,
  // then get a buffer from child.
  // NOTE(review): an eoe/eof flag buffer presumably also reports NumRows() == 0,
  // so if the loop above broke on eoe, that flag buffer would be replaced by this
  // refetch -- confirm this is intended when the dataset has fewer rows than count.
  if (!buf || buf->NumRows() == 0) {
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
  }
  // Handling eoe and eof
  if (buf->eoe() || buf->eof()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
    if (state_ == OpState::kDeOpIdle) {
      // Idle after eoe: forward the flag buffer to the caller immediately.
      *p_buffer = std::move(buf);
      return Status::OK();
    }
  }
  *p_buffer = std::move(buf);
  return Status::OK();
}
// Base-class override for handling cases when an eoe is received.
Status SkipOp::EoeReceived(int32_t worker_id) {
  // Reset the counter so the next epoch skips its leading rows again.
  skip_count_ = 0;
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to
// launch the functor since this op runs inlined inside another operator. The
// function is overloaded to ensure that it is not called by mistake (it will
// generate an error).
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
// Base-class override for handling cases when an eof is received.
// Eof carries no state to reset here, so this is a log-and-continue no-op.
Status SkipOp::EofReceived(int32_t worker_id) {
  MS_LOG(INFO) << "Skip operator EOF received, do nothing now.";
  return Status::OK();
}
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// SkipOp: an inlined pipeline operator that drops the first N rows produced
// by its single child and forwards everything after them unchanged.
class SkipOp : public PipelineOp {
 public:
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @param count - The number of rows to skip
    // @return This is a constructor.
    explicit Builder(int32_t count);

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @return shared_ptr to the new SkipOp object
    Status Build(std::shared_ptr<SkipOp> *);

   private:
    int32_t build_max_skips_;  // The skip count the caller requested

    // Validates the requested skip count (must be >= 0).
    Status SanityCheck() const;
  };

  // Constructor of the SkipOp.
  // @note The builder class should be used to call it
  // @param count - The number of skips to do
  explicit SkipOp(int32_t count);

  // Destructor
  ~SkipOp();

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // Class functor operator () override.
  // Most dataset ops operate by launching a thread (see ExecutionTree).
  // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
  // functor since this op runs inlined inside another operator. The function is overloaded to
  // ensure that it is not called by mistake (it will generate an error).
  // @return Status - The error code return
  Status operator()() override;

  // This function returns the buffer that is at the top of our output connector. The caller is
  // typically our parent node, when the parent is asking us to provide the next buffer of data.
  // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
  // a buffer from our child.
  // @param p_buffer - output pointer to the buffer that it will fetch.
  // @param worker_id - The worker id
  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
  // @return Status - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // Base-class override for handling cases when an eoe is received.
  // Resets the skip counter so each new epoch skips its leading rows again.
  // @param worker_id - The worker id
  Status EoeReceived(int32_t worker_id) override;

  // Base-class override for handling cases when an eof is received.
  // @param worker_id - The worker id
  Status EofReceived(int32_t worker_id) override;

 private:
  int32_t max_skips_;   // The number of skips that the user requested
  int32_t skip_count_;  // A counter for the current number of executed skips
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ | |||
| @@ -35,7 +35,7 @@ from mindspore._c_expression import typing | |||
| from mindspore import log as logger | |||
| from . import samplers | |||
| from .iterators import DictIterator, TupleIterator | |||
| from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_zip, check_rename, \ | |||
| from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ | |||
| check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ | |||
| check_zip_dataset, check_add_column | |||
| @@ -423,6 +423,25 @@ class Dataset: | |||
| return self | |||
| return RepeatDataset(self, count) | |||
@check_skip
def skip(self, count):
    """
    Skip the first N elements of this dataset.

    Args:
        count (int): Number of leading elements of the dataset to be skipped
            (must be a non-negative integer; validated by @check_skip).

    Returns:
        SkipDataset, a dataset that yields everything after the first `count` elements.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # data is an instance of Dataset object.
        >>> # creates a dataset which skips first 3 elements from data
        >>> data = data.skip(3)
    """
    return SkipDataset(self, count)
| @check_zip_dataset | |||
| def zip(self, datasets): | |||
| """ | |||
| @@ -1081,6 +1100,39 @@ class RepeatDataset(DatasetOp): | |||
| """ | |||
| return self.count | |||
class SkipDataset(DatasetOp):
    """
    The result of applying the Skip operator to the input Dataset.

    Args:
        input_dataset (Dataset): The dataset whose leading rows will be skipped.
        count (int): Number of rows in the dataset to be skipped.
    """

    def __init__(self, input_dataset, count):
        super().__init__()
        self.count = count
        # Wire this node into the dataset graph: record the child and
        # register ourselves as one of its consumers.
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs

    def get_args(self):
        args = super().get_args()
        args["count"] = self.count
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, number of batches.
        """
        size_before_skip = self.input[0].get_dataset_size()
        # Skipping at least as many rows as exist leaves an empty dataset.
        if 0 <= self.count < size_before_skip:
            return size_before_skip - self.count
        return 0
| class ZipDataset(DatasetOp): | |||
| """ | |||
| @@ -127,6 +127,8 @@ class Iterator: | |||
| op_type = OpName.MAP | |||
| elif isinstance(dataset, de.RepeatDataset): | |||
| op_type = OpName.REPEAT | |||
| elif isinstance(dataset, de.SkipDataset): | |||
| op_type = OpName.SKIP | |||
| elif isinstance(dataset, de.StorageDataset): | |||
| op_type = OpName.STORAGE | |||
| elif isinstance(dataset, de.ImageFolderDatasetV2): | |||
| @@ -297,6 +297,9 @@ def create_node(node): | |||
| elif dataset_op == 'RepeatDataset': | |||
| pyobj = de.Dataset().repeat(node.get('count')) | |||
| elif dataset_op == 'SkipDataset': | |||
| pyobj = de.Dataset().skip(node.get('count')) | |||
| elif dataset_op == 'MapDataset': | |||
| tensor_ops = construct_tensor_ops(node.get('operations')) | |||
| pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'), | |||
| @@ -709,6 +709,20 @@ def check_repeat(method): | |||
| return new_method | |||
def check_skip(method):
    """check the input arguments of skip."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)
        skip_count = param_dict.get('count')
        # count must be an int and non-negative; anything else is rejected
        # before the dataset node is ever built.
        check_type(skip_count, 'count', int)
        if skip_count < 0:
            raise ValueError("Skip count must be positive integer or 0.")
        return method(*args, **kwargs)

    return new_method
| def check_zip(method): | |||
| """check the input arguments of zip.""" | |||
| @@ -41,6 +41,7 @@ SET(DE_UT_SRCS | |||
| random_vertical_flip_op_test.cc | |||
| rename_op_test.cc | |||
| repeat_op_test.cc | |||
| skip_op_test.cc | |||
| rescale_op_test.cc | |||
| resize_bilinear_op_test.cc | |||
| resize_op_test.cc | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/circular_pool.h" | |||
| #include "dataset/core/client.h" | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::MsLogLevel::INFO; | |||
| using mindspore::ExceptionType::NoExceptionType; | |||
| using mindspore::LogStream; | |||
| class MindDataTestSkipOp : public UT::DatasetOpTesting {}; | |||
| TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) { | |||
| // Start with an empty execution tree | |||
| auto my_tree = std::make_shared<ExecutionTree>(); | |||
| std::string dataset_path; | |||
| dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; | |||
| std::shared_ptr<TFReaderOp> my_tfreader_op; | |||
| TFReaderOp::Builder builder; | |||
| builder.SetDatasetFilesList({dataset_path}) | |||
| .SetRowsPerBuffer(16) | |||
| .SetWorkerConnectorSize(16) | |||
| .SetNumWorkers(16); | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {}); | |||
| builder.SetDataSchema(std::move(schema)); | |||
| Status rc = builder.Build(&my_tfreader_op); ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssociateNode(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // SkipOp | |||
| std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5); | |||
| rc = my_tree->AssociateNode(skip_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Set children/root layout. | |||
| rc = skip_op->AddChild(my_tfreader_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->AssignRoot(skip_op); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration."; | |||
| rc = my_tree->Prepare(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = my_tree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(my_tree); | |||
| TensorRow tensor_list; | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int row_count = 0; | |||
| while (!tensor_list.empty()) { | |||
| MS_LOG(INFO) << "Row display for row #: " << row_count << "."; | |||
| // Display the tensor by calling the printer on it | |||
| for (int i = 0; i < tensor_list.size(); i++) { | |||
| std::ostringstream ss; | |||
| ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; | |||
| MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; | |||
| } | |||
| rc = di.FetchNextTensorRow(&tensor_list); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| row_count++; | |||
| } | |||
| ASSERT_EQ(row_count, 7); | |||
| } | |||
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
def test_tf_skip():
    """Decode/resize a TFRecord image dataset, skip 2 rows, and count the rest."""
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)

    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)
    data1 = data1.skip(2)

    num_iter = 0
    for item in data1.create_dict_iterator():
        num_iter += 1
    # NOTE(review): the fixture dir name suggests 3 images; 3 - 2 skipped = 1 -- confirm.
    assert num_iter == 1
def generator_md():
    """Yield a 5-row dataset: single-element numpy arrays holding 0..4."""
    for value in range(5):
        yield (np.array([value]),)
def test_generator_skip():
    """Skip within a single epoch: skipping 3 of 5 rows leaves 2."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(3)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
def test_skip_1():
    """Skipping more rows than the dataset holds yields an empty dataset."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be []
    ds1 = ds1.skip(7)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 0
def test_skip_2():
    """skip(0) is a no-op: all 5 rows pass through."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4]
    ds1 = ds1.skip(0)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 5
def test_skip_repeat_1():
    """repeat then skip: the skip drops rows once from the repeated stream."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.repeat(2)

    # Here ds1 should be [3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.skip(3)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 7
def test_skip_repeat_2():
    """skip then repeat: the skip re-applies each epoch, so the tail repeats."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(3)

    # Here ds1 should be [3, 4, 3, 4]
    ds1 = ds1.repeat(2)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 4
def test_skip_repeat_3():
    """repeat -> skip -> repeat: the skipped tail of the doubled stream repeats 3x."""
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    ds1 = ds1.repeat(2)

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(8)

    # Here ds1 should be [3, 4, 3, 4, 3, 4]
    ds1 = ds1.repeat(3)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 6
if __name__ == "__main__":
    # Run every skip-op scenario when invoked directly (outside pytest).
    test_tf_skip()
    test_generator_skip()
    test_skip_1()
    test_skip_2()
    test_skip_repeat_1()
    test_skip_repeat_2()
    test_skip_repeat_3()