@@ -16,6 +16,7 @@
#include <iostream>
#include <utility>
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
@@ -26,7 +27,10 @@
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status SkipOp::Builder::SanityCheck() const {
  if (build_max_skips_ < 0) {
@@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<SkipOp>(build_max_skips_);
  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
  return Status::OK();
}

// Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}

// Destructor
SkipOp::~SkipOp() {}
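Worth noting for reviewers: the Builder now seeds `builder_op_connector_size_` from the global `ConfigManager`, so `Build()` can pass a real connector size to the new two-argument `SkipOp` constructor. A minimal usage sketch, assuming the surrounding MindSpore dataset headers and an initialized `GlobalContext` (variable names here are illustrative only):

```cpp
// Sketch: build a SkipOp that skips the first 5 rows; the connector size
// is picked up from ConfigManager inside the Builder constructor.
std::shared_ptr<SkipOp> skip_op;
Status rc = SkipOp::Builder(5).Build(&skip_op);
if (!rc.IsOk()) {
  MS_LOG(ERROR) << "Failed to build SkipOp.";
}
```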
@@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
      << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
}

// Since the buffer may contain multi rows, this function will drop the rows
// that need to skip in it, and then return the buffer.
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
  }

  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));

  // Drop first max_skips_ rows
  while (skip_count_ < max_skips_) {
    if (buf->eoe() || buf->eof()) {
      break;
    }
    // Consider the rows of buffer more than 1
    TensorRow drop_row;
    int row_num = buf->NumRows();
    int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
    skip_count_ += drop_num;
    for (int i = 0; i < drop_num; i++) {
      RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
    }
    if (buf->NumRows() == 0) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
    }
  }

  // Handling eoe
  if (buf->eoe()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
  }

  // Handling eof
  if (buf->eof()) {
    RETURN_IF_NOT_OK(EofReceived(worker_id));
  }

  *p_buffer = std::move(buf);
  return Status::OK();
}

// Base-class override for handling cases when an eoe is received.
Status SkipOp::EoeReceived(int32_t worker_id) {
  skip_count_ = 0;
@@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
  return Status::OK();
}

// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// However, the SkipOp is defined as a inlined operator, so it is invalid to
// launch the functor since this op runs inlined inside another operator. The
// function is overloaded to ensure that it is not called by mistake (it will
// generate an error).
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }

// Main entry point for skip
Status SkipOp::operator()() {
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> curr_buffer;
  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
  while (curr_buffer->eof() == false) {
    // Reset count
    skip_count_ = 0;
    while (curr_buffer->eoe() == false) {
      // Drop first count rows
      while (skip_count_ < max_skips_) {
        if (curr_buffer->eoe() || curr_buffer->eof()) {
          break;
        }
        // The buffer may hold more than one row, so drop only as many as still needed
        TensorRow drop_row;
        int row_num = curr_buffer->NumRows();
        int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
        skip_count_ += drop_num;
        for (int i = 0; i < drop_num; i++) {
          RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
        }
        if (curr_buffer->NumRows() == 0) {
          RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
        }
      }
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
    }
    // We got an eoe; keep reading until we get an eof
    MS_LOG(DEBUG) << "Skip operator EOE Received.";
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
  }
  MS_LOG(DEBUG) << "Skip operator EOF Received.";
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  return Status::OK();
}
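The only subtle step in the loop above is the drop count: each incoming buffer is either dropped whole (while its rows still fit under `max_skips_`) or trimmed by just the remainder. A small standalone sketch of that arithmetic, with made-up buffer sizes, shows how a skip of 5 plays out across 3-row buffers:

```cpp
#include <initializer_list>
#include <iostream>

int main() {
  const int max_skips = 5;         // total rows to skip
  int skip_count = 0;
  for (int row_num : {3, 3, 3}) {  // three incoming buffers of 3 rows each
    // Same expression as in SkipOp::operator()(): drop the whole buffer while we
    // are still short of max_skips, otherwise drop only the remaining difference.
    int drop_num = row_num + skip_count < max_skips ? row_num : max_skips - skip_count;
    skip_count += drop_num;
    std::cout << "dropped " << drop_num << ", kept " << (row_num - drop_num) << std::endl;
  }
  // Prints: "dropped 3, kept 0", "dropped 2, kept 1", "dropped 0, kept 3".
  // The real operator exits its loop once skip_count reaches max_skips.
  return 0;
}
```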
// Base-class override for handling cases when an eof is received.
Status SkipOp::EofReceived(int32_t worker_id) {
@@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
   private:
    int32_t build_max_skips_;
    int32_t builder_op_connector_size_;

    Status SanityCheck() const;
  };

@@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
  // Constructor of the SkipOp.
  // @note The builder class should be used to call it
  // @param count - The number of skips to do
  explicit SkipOp(int32_t count);
  explicit SkipOp(int32_t count, int32_t op_connector_size);

  // Destructor
  ~SkipOp();

@@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
  void Print(std::ostream &out, bool show_all) const override;

  // Class functor operator () override.
  // Most dataset ops operate by launching a thread (see ExecutionTree).
  // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
  // functor since this op runs inlined inside another operator. The function is overloaded to
  // ensure that it is not called by mistake (it will generate an error).
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code return
  Status operator()() override;

  // This function returns the buffer that is at the top of our output connector. The caller is
  // typically our parent node, when the parent is asking us to provide the next buffer of data.
  // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
  // a buffer from our child.
  // @param p_buffer - output pointer to the buffer that it will fetch.
  // @param worker_id - The worker id
  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
  // @return Status - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // Base-class override for handling cases when an eoe is received.
  // @param worker_id - The worker id
  Status EoeReceived(int32_t worker_id) override;
@@ -17,6 +17,7 @@
from abc import abstractmethod
import copy
import weakref
from importlib import import_module

from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName
@@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName
from mindspore import log as logger
from . import datasets as de

try:
    context = import_module("mindspore.context")
except ModuleNotFoundError:
    context = None

ITERATORS_LIST = list()


def _cleanup():
    """Release all the Iterator."""
    for itr_ref in ITERATORS_LIST:
        itr = itr_ref()
        if itr is not None:
            itr.release()
        if context:
            device_type = context.get_context("device_target")
            if device_type == "GPU":
                itr_ref.release()
            else:
                itr = itr_ref()
                if itr is not None:
                    itr.release()
        else:
            itr = itr_ref()
            if itr is not None:
                itr.release()


def alter_tree(node):
@@ -85,7 +101,14 @@ class Iterator:
    """

    def __init__(self, dataset):
        ITERATORS_LIST.append(weakref.ref(self))
        if context:
            device_type = context.get_context("device_target")
            if device_type == "GPU":
                ITERATORS_LIST.append(self)
            else:
                ITERATORS_LIST.append(weakref.ref(self))
        else:
            ITERATORS_LIST.append(weakref.ref(self))
        # create a copy of tree and work on it.
        self.dataset = copy.deepcopy(dataset)
        self.dataset = alter_tree(self.dataset)
@@ -76,8 +76,13 @@ class BoundingBoxEncode(PrimitiveWithInfer):
        Tensor, encoded bounding boxes.

    Examples:
        >>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
        >>> groundtruth_box = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
        >>> boundingbox_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        >>> delta_box = boundingbox_encode(anchor_box, groundtruth_box)
        >>> boundingbox_encode(anchor_box, groundtruth_box)
        [[5.0000000e-01 5.0000000e-01 -6.5504000e+04 6.9335938e-01]
         [-1.0000000e+00 2.5000000e-01 0.0000000e+00 4.0551758e-01]]
    """

    @prim_attr_register
@@ -118,9 +123,14 @@ class BoundingBoxDecode(PrimitiveWithInfer):
        Tensor, decoded boxes.

    Examples:
        >>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
        >>> deltas = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
        >>> boundingbox_decode = P.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
        >>>                                          max_shape=(768, 1280), wh_ratio_clip=0.016)
        >>> bbox = boundingbox_decode(anchor_box, deltas)
        >>> boundingbox_decode(anchor_box, deltas)
        [[4.1953125 0. 0. 5.1953125]
         [2.140625 0. 3.859375 60.59375]]
    """

    @prim_attr_register
@@ -0,0 +1,198 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore import context
import mindspore.common.dtype as mstype
import os
import numpy as np
import mindspore.ops.functional as F
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
from resnet import resnet50
import random
import time

random.seed(1)
np.random.seed(1)
ds.config.set_seed(1)

data_home = "/home/workspace/mindspore_dataset"


def create_dataset(repeat_num=1, training=True, batch_size=32):
    data_dir = data_home + "/cifar-10-batches-bin"
    if not training:
        data_dir = data_home + "/cifar-10-verify-bin"
    data_set = ds.Cifar10Dataset(data_dir)

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop(
        (32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    # interpolation default BILINEAR
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize(
        (0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=1000)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set


class CrossEntropyLoss(nn.Cell):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
        loss = self.cross_entropy(logits, label)[0]
        loss = self.mean(loss, (-1,))
        return loss


class LossGet(Callback):
    def __init__(self, per_print_times=1):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss))

    def get_loss(self):
        return self._loss


def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=device_id)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)
    context.set_context(mode=context.GRAPH_MODE)
    net = resnet50(batch_size, num_classes)
    loss = CrossEntropyLoss()
    opt = Momentum(filter(lambda x: x.requires_grad,
                          net.get_parameters()), 0.01, 0.9)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
    batch_num = dataset.get_dataset_size()
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
                                 config=config_ck)
    loss_cb = LossGet()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])


def eval(batch_size, num_classes):
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True, device_id=0)
    context.set_context(enable_loop_sink=True)
    context.set_context(enable_mem_reuse=True)

    net = resnet50(batch_size, num_classes)
    loss = CrossEntropyLoss()
    opt = Momentum(filter(lambda x: x.requires_grad,
                          net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    eval_dataset = create_dataset(1, training=False)
    res = model.eval(eval_dataset)
    print("result: ", res)
    return res


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet_cifar_1p():
    device_num = 1
    epoch_size = 1
    num_classes = 10
    batch_size = 32
    device_id = 0
    train_process(device_id, epoch_size, num_classes, device_num, batch_size)
    time.sleep(3)
    acc = eval(batch_size, num_classes)
    os.chdir("../")
    os.system("rm -rf " + str(device_id))
    print("End training...")
    assert (acc['acc'] > 0.35)
@@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
  ASSERT_TRUE(rc.IsOk());

  // SkipOp
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);

  rc = my_tree->AssociateNode(skip_op);
  ASSERT_TRUE(rc.IsOk());
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -51,7 +50,7 @@ def generator_md():


def test_generator_skip():
    ds1 = ds.GeneratorDataset(generator_md, ["data"])
    ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)

    # Here ds1 should be [3, 4]
    ds1 = ds1.skip(3)
@@ -60,6 +59,7 @@ def test_generator_skip():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
    assert buf == [3, 4]


def test_skip_1():
@@ -72,6 +72,7 @@ def test_skip_1():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 0
    assert buf == []


def test_skip_2():
@@ -84,6 +85,7 @@ def test_skip_2():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 5
    assert buf == [0, 1, 2, 3, 4]


def test_skip_repeat_1():
@@ -99,6 +101,7 @@ def test_skip_repeat_1():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 7
    assert buf == [3, 4, 0, 1, 2, 3, 4]


def test_skip_repeat_2():
@@ -114,6 +117,7 @@ def test_skip_repeat_2():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 4
    assert buf == [3, 4, 3, 4]


def test_skip_repeat_3():
@@ -132,6 +136,62 @@ def test_skip_repeat_3():
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 6
    assert buf == [3, 4, 3, 4, 3, 4]


def test_skip_take_1():
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [0, 1, 2, 3]
    ds1 = ds1.take(4)

    # Here ds1 should be [2, 3]
    ds1 = ds1.skip(2)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
    assert buf == [2, 3]


def test_skip_take_2():
    ds1 = ds.GeneratorDataset(generator_md, ["data"])

    # Here ds1 should be [2, 3, 4]
    ds1 = ds1.skip(2)

    # Here ds1 should be [2, 3]
    ds1 = ds1.take(2)

    buf = []
    for data in ds1:
        buf.append(data[0][0])
    assert len(buf) == 2
    assert buf == [2, 3]


def generator_1d():
    for i in range(64):
        yield (np.array([i]), )


def test_skip_filter_1():
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    dataset = dataset.skip(5)
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)

    buf = []
    for item in dataset:
        buf.append(item[0][0])
    assert buf == [5, 6, 7, 8, 9, 10]


def test_skip_filter_2():
    dataset = ds.GeneratorDataset(generator_1d, ['data'])
    dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    dataset = dataset.skip(5)

    buf = []
    for item in dataset:
        buf.append(item[0][0])
    assert buf == [5, 6, 7, 8, 9, 10]


if __name__ == "__main__":
@@ -142,3 +202,7 @@ if __name__ == "__main__":
    test_skip_repeat_1()
    test_skip_repeat_2()
    test_skip_repeat_3()
    test_skip_take_1()
    test_skip_take_2()
    test_skip_filter_1()
    test_skip_filter_2()