From d5adfa52100405774c840f2ec9266a83a02ecc27 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Wed, 29 Apr 2020 10:48:15 +0800 Subject: [PATCH 1/4] add accuracy for resnet cifar --- tests/st/tbe_networks/test_resnet_cifar_1p.py | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 tests/st/tbe_networks/test_resnet_cifar_1p.py diff --git a/tests/st/tbe_networks/test_resnet_cifar_1p.py b/tests/st/tbe_networks/test_resnet_cifar_1p.py new file mode 100644 index 0000000000..058ec3aeec --- /dev/null +++ b/tests/st/tbe_networks/test_resnet_cifar_1p.py @@ -0,0 +1,198 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model +from mindspore import context +import mindspore.common.dtype as mstype +import os +import numpy as np +import mindspore.ops.functional as F +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as vision +from resnet import resnet50 +import random +import time + +random.seed(1) +np.random.seed(1) +ds.config.set_seed(1) + +data_home = "/home/workspace/mindspore_dataset" + + +def create_dataset(repeat_num=1, training=True, batch_size=32): + data_dir = data_home + "/cifar-10-batches-bin" + if not training: + data_dir = data_home + "/cifar-10-verify-bin" + data_set = ds.Cifar10Dataset(data_dir) + + resize_height = 224 + resize_width = 224 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + random_crop_op = vision.RandomCrop( + (32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) + rescale_op = vision.Rescale(rescale, shift) + normalize_op = vision.Normalize( + (0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) + changeswap_op = vision.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, + changeswap_op] + + # apply map operations on images + data_set = data_set.map(input_columns="label", operations=type_cast_op) + data_set = data_set.map(input_columns="image", operations=c_trans) + + # apply shuffle operations + data_set = data_set.shuffle(buffer_size=1000) + + # apply batch operations + data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + return data_set + + +class CrossEntropyLoss(nn.Cell): + def __init__(self): + super(CrossEntropyLoss, self).__init__() + self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean() + self.one_hot = P.OneHot() + self.one = Tensor(1.0, mstype.float32) + self.zero = Tensor(0.0, mstype.float32) + + def construct(self, logits, label): + label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero) + loss = self.cross_entropy(logits, label)[0] + loss = self.mean(loss, (-1,)) + return loss + + +class LossGet(Callback): + def __init__(self, per_print_times=1): + super(LossGet, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self._loss = 0.0 + + def step_end(self, run_context): + cb_params = run_context.original_args() + loss = cb_params.net_outputs + + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] + + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) + + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + + if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): + raise ValueError("epoch: {} step: {}. Invalid loss, terminating training." + .format(cb_params.cur_epoch_num, cur_step_in_epoch)) + if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: + self._loss = loss + print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss)) + + def get_loss(self): + return self._loss + + +def train_process(device_id, epoch_size, num_classes, device_num, batch_size): + os.system("mkdir " + str(device_id)) + os.chdir(str(device_id)) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(enable_task_sink=True, device_id=device_id) + context.set_context(enable_loop_sink=True) + context.set_context(enable_mem_reuse=True) + context.set_context(mode=context.GRAPH_MODE) + net = resnet50(batch_size, num_classes) + loss = CrossEntropyLoss() + opt = Momentum(filter(lambda x: x.requires_grad, + net.get_parameters()), 0.01, 0.9) + + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + + dataset = create_dataset(epoch_size, training=True, batch_size=batch_size) + batch_num = dataset.get_dataset_size() + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./", + config=config_ck) + loss_cb = LossGet() + model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb]) + + +def eval(batch_size, num_classes): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(enable_task_sink=True, device_id=0) + context.set_context(enable_loop_sink=True) + context.set_context(enable_mem_reuse=True) + + net = resnet50(batch_size, num_classes) + loss = CrossEntropyLoss() + opt = Momentum(filter(lambda x: x.requires_grad, + net.get_parameters()), 0.01, 0.9) + + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt" + param_dict = load_checkpoint(checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + eval_dataset = create_dataset(1, training=False) + res = model.eval(eval_dataset) + print("result: ", res) + return res + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resnet_cifar_1p(): + device_num = 1 + epoch_size = 1 + num_classes = 10 + batch_size = 32 + device_id = 0 + train_process(device_id, epoch_size, num_classes, device_num, batch_size) + time.sleep(3) + acc = eval(batch_size, num_classes) + os.chdir("../") + os.system("rm -rf " + str(device_id)) + print("End training...") + assert (acc['acc'] > 0.35) From e31db0e1f7eed759d55fdaa8d52cd73f22175c48 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Wed, 29 Apr 2020 17:51:02 +0800 Subject: [PATCH 2/4] minddata fix gpu issue --- mindspore/dataset/engine/iterators.py | 31 +++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index ebee204b37..2cf95aa086 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -17,6 +17,7 @@ from abc import abstractmethod import copy import weakref +from importlib import import_module from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import OpName @@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName from mindspore import log as logger from . import datasets as de +try: + context = import_module("mindspore.context") +except ModuleNotFoundError: + context = None + ITERATORS_LIST = list() def _cleanup(): + """Release all the Iterator.""" for itr_ref in ITERATORS_LIST: - itr = itr_ref() - if itr is not None: - itr.release() + if context: + device_type = context.get_context("device_target") + if device_type == "GPU": + itr_ref.release() + else: + itr = itr_ref() + if itr is not None: + itr.release() + else: + itr = itr_ref() + if itr is not None: + itr.release() def alter_tree(node): @@ -85,7 +101,14 @@ class Iterator: """ def __init__(self, dataset): - ITERATORS_LIST.append(weakref.ref(self)) + if context: + device_type = context.get_context("device_target") + if device_type == "GPU": + ITERATORS_LIST.append(self) + else: + ITERATORS_LIST.append(weakref.ref(self)) + else: + ITERATORS_LIST.append(weakref.ref(self)) # create a copy of tree and work on it. self.dataset = copy.deepcopy(dataset) self.dataset = alter_tree(self.dataset) From 5426899569535db22b4988d6fd5f3837e290ccf3 Mon Sep 17 00:00:00 2001 From: lihongkang Date: Tue, 28 Apr 2020 19:24:33 +0800 Subject: [PATCH 3/4] update mindspore/ops/operations/other_ops.py. update mindspore/ops/operations/other_ops.py. --- mindspore/ops/operations/other_ops.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 5e66050d9a..f2c0fccca9 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -76,8 +76,13 @@ class BoundingBoxEncode(PrimitiveWithInfer): Tensor, encoded bounding boxes. Examples: + >>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32) + >>> groundtruth_box = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32) >>> boundingbox_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)) - >>> delta_box = boundingbox_encode(anchor_box, groundtruth_box) + >>> boundingbox_encode(anchor_box, groundtruth_box) + [[5.0000000e-01 5.0000000e-01 -6.5504000e+04 6.9335938e-01] + [-1.0000000e+00 2.5000000e-01 0.0000000e+00 4.0551758e-01]] + """ @prim_attr_register @@ -118,9 +123,14 @@ class BoundingBoxDecode(PrimitiveWithInfer): Tensor, decoded boxes. Examples: + >>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32) + >>> deltas = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32) >>> boundingbox_decode = P.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), >>> max_shape=(768, 1280), wh_ratio_clip=0.016) - >>> bbox = boundingbox_decode(anchor_box, deltas) + >>> boundingbox_decode(anchor_box, deltas) + [[4.1953125 0. 0. 5.1953125] + [2.140625 0. 3.859375 60.59375]] + """ @prim_attr_register From 34bfa2f7c9199c078665107e80f7ab9a2a5d4e48 Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Wed, 29 Apr 2020 17:18:12 +0800 Subject: [PATCH 4/4] fix skip --- .../dataset/engine/datasetops/skip_op.cc | 100 ++++++++---------- .../ccsrc/dataset/engine/datasetops/skip_op.h | 19 +--- tests/ut/cpp/dataset/skip_op_test.cc | 2 +- tests/ut/python/dataset/test_skip.py | 68 +++++++++++- 4 files changed, 118 insertions(+), 71 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index d851f2c699..a7b642d9d1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -16,6 +16,7 @@ #include #include +#include "dataset/core/config_manager.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/db_connector.h" @@ -26,7 +27,10 @@ namespace mindspore { namespace dataset { // Builder constructor. Creates the builder object. -SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {} +SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} Status SkipOp::Builder::SanityCheck() const { if (build_max_skips_ < 0) { @@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const { // The builder "build" method creates the final object. Status SkipOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_skips_); + *ptr = std::make_shared(build_max_skips_, builder_op_connector_size_); return Status::OK(); } // Constructor of the SkipOp. -SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {} +SkipOp::SkipOp(int32_t count, int32_t op_connector_size) + : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} // Destructor SkipOp::~SkipOp() {} @@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const { << "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_; } -// Since the buffer may contain multi rows, this function will drop the rows -// that need to skip in it, and then return the buffer. -Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - if (child_.empty()) { - RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node."); - } - - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - - // Drop first max_skips_ rows - while (skip_count_ < max_skips_) { - if (buf->eoe() || buf->eof()) { - break; - } - - // Consider the rows of buffer more than 1 - TensorRow drop_row; - int row_num = buf->NumRows(); - int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; - skip_count_ += drop_num; - for (int i = 0; i < drop_num; i++) { - RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); - } - if (buf->NumRows() == 0) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - } - } - - // Handling eoe - if (buf->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - } - - // Handling eof - if (buf->eof()) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - - *p_buffer = std::move(buf); - return Status::OK(); -} - // Base-class override for handling cases when an eoe is received. Status SkipOp::EoeReceived(int32_t worker_id) { skip_count_ = 0; @@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) { return Status::OK(); } -// Class functor operator () override. -// Most dataset ops operate by launching a thread (see ExecutionTree). -// However, the SkipOp is defined as a inlined operator, so it is invalid to -// launch the functor since this op runs inlined inside another operator. The -// function is overloaded to ensure that it is not called by mistake (it will -// generate an error). -Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); } +// main entry point for skip +Status SkipOp::operator()() { + TaskManager::FindMe()->Post(); + std::unique_ptr curr_buffer; + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + while (curr_buffer->eof() == false) { + // Reset count + skip_count_ = 0; + while (curr_buffer->eoe() == false) { + // Drop first count rows + while (skip_count_ < max_skips_) { + if (curr_buffer->eoe() || curr_buffer->eof()) { + break; + } + // Consider the rows of buffer more than one + TensorRow drop_row; + int row_num = curr_buffer->NumRows(); + int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; + skip_count_ += drop_num; + for (int i = 0; i < drop_num; i++) { + RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); + } + if (curr_buffer->NumRows() == 0) { + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + // we got eoe, now try again until we got eof + MS_LOG(DEBUG) << "Skip operator EOE Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + + MS_LOG(DEBUG) << "Skip operator EOF Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} // Base-class override for handling cases when an eof is received. Status SkipOp::EofReceived(int32_t worker_id) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h index 0ae520c3ad..a16b82ed21 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h @@ -42,6 +42,7 @@ class SkipOp : public PipelineOp { private: int32_t build_max_skips_; + int32_t builder_op_connector_size_; Status SanityCheck() const; }; @@ -49,7 +50,7 @@ class SkipOp : public PipelineOp { // Constructor of the SkipOp. // @note The builder class should be used to call it // @param count - The number of skips to do - explicit SkipOp(int32_t count); + explicit SkipOp(int32_t count, int32_t op_connector_size); // Destructor ~SkipOp(); @@ -60,23 +61,11 @@ class SkipOp : public PipelineOp { void Print(std::ostream &out, bool show_all) const override; // Class functor operator () override. - // Most dataset ops operate by launching a thread (see ExecutionTree). - // However, the SkipOp is defined as a inlined operator, so it is invalid to launch the - // functor since this op runs inlined inside another operator. The function is overloaded to - // ensure that it is not called by mistake (it will generate an error). + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work // @return Status - The error code return Status operator()() override; - // This function returns the buffer that is at the top of our output connector. The caller is - // typically our parent node, when the parent is asking us to provide the next buffer of data. - // Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get - // a buffer from our child. - // @param p_buffer - output pointer to the buffer that it will fetch. - // @param worker_id - The worker id - // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - // @return Status - The error code return - Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; - // Base-class override for handling cases when an eoe is received. // @param worker_id - The worker id Status EoeReceived(int32_t worker_id) override; diff --git a/tests/ut/cpp/dataset/skip_op_test.cc b/tests/ut/cpp/dataset/skip_op_test.cc index c2168b24d4..697745512d 100644 --- a/tests/ut/cpp/dataset/skip_op_test.cc +++ b/tests/ut/cpp/dataset/skip_op_test.cc @@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) { ASSERT_TRUE(rc.IsOk()); // SkipOp - std::shared_ptr skip_op = std::make_shared(5); + std::shared_ptr skip_op = std::make_shared(5, 2); rc = my_tree->AssociateNode(skip_op); ASSERT_TRUE(rc.IsOk()); diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index 59893f6ded..ccbf40a55b 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - import numpy as np import mindspore.dataset.transforms.vision.c_transforms as vision @@ -51,7 +50,7 @@ def generator_md(): def test_generator_skip(): - ds1 = ds.GeneratorDataset(generator_md, ["data"]) + ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4) # Here ds1 should be [3, 4] ds1 = ds1.skip(3) @@ -60,6 +59,7 @@ def test_generator_skip(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 2 + assert buf == [3, 4] def test_skip_1(): @@ -72,6 +72,7 @@ def test_skip_1(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 0 + assert buf == [] def test_skip_2(): @@ -84,6 +85,7 @@ def test_skip_2(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 5 + assert buf == [0, 1, 2, 3, 4] def test_skip_repeat_1(): @@ -99,6 +101,7 @@ def test_skip_repeat_1(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 7 + assert buf == [3, 4, 0, 1, 2, 3, 4] def test_skip_repeat_2(): @@ -114,6 +117,7 @@ def test_skip_repeat_2(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 4 + assert buf == [3, 4, 3, 4] def test_skip_repeat_3(): @@ -132,6 +136,62 @@ def test_skip_repeat_3(): for data in ds1: buf.append(data[0][0]) assert len(buf) == 6 + assert buf == [3, 4, 3, 4, 3, 4] + +def test_skip_take_1(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + # Here ds1 should be [0, 1, 2, 3] + ds1 = ds1.take(4) + + # Here ds1 should be [2, 3] + ds1 = ds1.skip(2) + + buf = [] + for data in ds1: + buf.append(data[0][0]) + assert len(buf) == 2 + assert buf == [2, 3] + +def test_skip_take_2(): + ds1 = ds.GeneratorDataset(generator_md, ["data"]) + + # Here ds1 should be [2, 3, 4] + ds1 = ds1.skip(2) + + # Here ds1 should be [2, 3] + ds1 = ds1.take(2) + + buf = [] + for data in ds1: + buf.append(data[0][0]) + assert len(buf) == 2 + assert buf == [2, 3] + + +def generator_1d(): + for i in range(64): + yield (np.array([i]), ) + +def test_skip_filter_1(): + dataset = ds.GeneratorDataset(generator_1d, ['data']) + dataset = dataset.skip(5) + dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + + buf = [] + for item in dataset: + buf.append(item[0][0]) + assert buf == [5, 6, 7, 8, 9, 10] + +def test_skip_filter_2(): + dataset = ds.GeneratorDataset(generator_1d, ['data']) + dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + dataset = dataset.skip(5) + + buf = [] + for item in dataset: + buf.append(item[0][0]) + assert buf == [5, 6, 7, 8, 9, 10] if __name__ == "__main__": @@ -142,3 +202,7 @@ if __name__ == "__main__": test_skip_repeat_1() test_skip_repeat_2() test_skip_repeat_3() + test_skip_take_1() + test_skip_take_2() + test_skip_filter_1() + test_skip_filter_2()