From 848e07d022d04d57f8fd00a77139722c46539c95 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Thu, 4 Jun 2020 10:30:48 -0400 Subject: [PATCH] initial commit, start of BucketBatchByLengthOp c implementation done, just need to call batch/pad added python api and validator added pybind/de_pipeline stuff, fixed some compile errors, figure out how null py::function works added tiny bit of doc integrated with static batch methods fixed some bugs some more bug fixes and cleanup ci fix fix ci ci fix fix ci added test_cases and debugged addressed code review comments addressed code review comments ci fix ci fix addressed code review comments addressed code review comments --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 135 +++++-- mindspore/ccsrc/dataset/api/de_pipeline.h | 3 + .../ccsrc/dataset/api/python_bindings.cc | 1 + .../dataset/engine/datasetops/CMakeLists.txt | 1 + .../dataset/engine/datasetops/batch_op.h | 33 +- .../datasetops/bucket_batch_by_length_op.cc | 242 ++++++++++++ .../datasetops/bucket_batch_by_length_op.h | 153 +++++++ mindspore/dataset/engine/datasets.py | 117 +++++- mindspore/dataset/engine/iterators.py | 2 + mindspore/dataset/engine/validators.py | 61 +++ .../dataset/test_bucket_batch_by_length.py | 373 ++++++++++++++++++ 11 files changed, 1060 insertions(+), 61 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h create mode 100644 tests/ut/python/dataset/test_bucket_batch_by_length.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index a596d339ec..54c998a92d 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -19,62 +19,65 @@ #include #include "common/utils.h" -#include "dataset/kernels/py_func_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" #include "dataset/core/tensor.h" #include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/source/clue_op.h" +#include "dataset/engine/datasetops/source/coco_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "dataset/engine/datasetops/source/manifest_op.h" +#include "dataset/engine/datasetops/source/mnist_op.h" #include "dataset/engine/datasetops/source/random_data_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/filter_op.h" +#include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/kernels/py_func_op.h" +#include "dataset/util/random.h" +#include "dataset/util/status.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_distributed_sample.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "utils/log_adapter.h" #include "pybind11/stl.h" +#include 
"utils/log_adapter.h" namespace mindspore { namespace dataset { using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *); -static std::unordered_map g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp}, - {kShuffle, &DEPipeline::ParseShuffleOp}, - {kMindrecord, &DEPipeline::ParseMindRecordOp}, - {kMap, &DEPipeline::ParseMapOp}, - {kFilter, &DEPipeline::ParseFilterOp}, - {kBatch, &DEPipeline::ParseBatchOp}, - {kBarrier, &DEPipeline::ParseBarrierOp}, - {kRepeat, &DEPipeline::ParseRepeatOp}, - {kSkip, &DEPipeline::ParseSkipOp}, - {kZip, &DEPipeline::ParseZipOp}, - {kConcat, &DEPipeline::ParseConcatOp}, - {kRename, &DEPipeline::ParseRenameOp}, - {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, - {kGenerator, &DEPipeline::ParseGeneratorOp}, - {kTfReader, &DEPipeline::ParseTFReaderOp}, - {kProject, &DEPipeline::ParseProjectOp}, - {kTake, &DEPipeline::ParseTakeOp}, - {kImageFolder, &DEPipeline::ParseImageFolderOp}, - {kMnist, &DEPipeline::ParseMnistOp}, - {kManifest, &DEPipeline::ParseManifestOp}, - {kVoc, &DEPipeline::ParseVOCOp}, - {kCoco, &DEPipeline::ParseCocoOp}, - {kCifar10, &DEPipeline::ParseCifar10Op}, - {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}, - {kRandomData, &DEPipeline::ParseRandomDataOp}, - {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, - {kClue, &DEPipeline::ParseClueOp}}; +static std::unordered_map g_parse_op_func_ = { + {kStorage, &DEPipeline::ParseStorageOp}, + {kShuffle, &DEPipeline::ParseShuffleOp}, + {kMindrecord, &DEPipeline::ParseMindRecordOp}, + {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, + {kBatch, &DEPipeline::ParseBatchOp}, + {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, + {kRepeat, &DEPipeline::ParseRepeatOp}, + {kSkip, &DEPipeline::ParseSkipOp}, + {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, + {kRename, &DEPipeline::ParseRenameOp}, + {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, + {kGenerator, &DEPipeline::ParseGeneratorOp}, + {kTfReader, &DEPipeline::ParseTFReaderOp}, + {kProject, &DEPipeline::ParseProjectOp}, + {kTake, &DEPipeline::ParseTakeOp}, + {kImageFolder, &DEPipeline::ParseImageFolderOp}, + {kMnist, &DEPipeline::ParseMnistOp}, + {kManifest, &DEPipeline::ParseManifestOp}, + {kVoc, &DEPipeline::ParseVOCOp}, + {kCoco, &DEPipeline::ParseCocoOp}, + {kCifar10, &DEPipeline::ParseCifar10Op}, + {kCifar100, &DEPipeline::ParseCifar100Op}, + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kRandomData, &DEPipeline::ParseRandomDataOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}, + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr return Status::OK(); } +Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr) { + std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", + "bucket_batch_sizes"}; + for (auto name : mandatory_arguments) { + if (args[name.c_str()].is_none()) { + std::string err_msg = "Error: " + name + " is not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + std::shared_ptr builder = std::make_shared( + ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), + ToIntVector(args[mandatory_arguments[2].c_str()])); + + for (auto arg : args) { + 
std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "length_dependent_columns") { + (void)builder->SetLengthDependentColumns(ToStringVector(value)); + } + if (key == "bucket_boundaries") { + (void)builder->SetBucketBoundaries(ToIntVector(value)); + } + if (key == "bucket_batch_sizes") { + (void)builder->SetBucketBatchSizes(ToIntVector(value)); + } + if (key == "element_length_function") { + (void)builder->SetElementLengthFunction(value.cast()); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPadInfo(pad_info); + } + if (key == "pad_to_bucket_boundary") { + (void)builder->SetPadToBucketBoundary(ToBool(value)); + } + if (key == "drop_remainder") { + (void)builder->SetDropRemainder(ToBool(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { std::shared_ptr builder = std::make_shared(); // Right now barrier should only take num_rows_per_buffer = 1 diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index f856b3b2ca..493c092b1f 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBucketBatch, kBarrier, kCache, kRepeat, @@ -121,6 +122,8 @@ class DEPipeline { Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 01728df86a..a5a0fc895d 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BUCKETBATCH", OpName::kBucketBatch) .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 63265a2225..ed57421030 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT pipeline_op.cc barrier_op.cc batch_op.cc + bucket_batch_by_length_op.cc device_queue_op.cc map_op.cc project_op.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h index 6dc7a337b6..28df5e7e81 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h @@ -193,6 +193,22 @@ class BatchOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "BatchOp"; } + // batch the rows in src table then put it to dest table + // @param const std::unique_ptr *src - table that has the rows for batching + // @param const std::unique_ptr *dest - dest_table to hold batched rows + // @param int32_t size - batch_size + // @param const std::unordered_map& 
column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size); + + // @param table + // @param const PadInfo &pad_info pad info + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, + const std::unordered_map &column_name_id_map); + private: // Worker thread for doing the memcpy of batch // @param int32_t param workerId @@ -203,16 +219,6 @@ class BatchOp : public ParallelOp { // @return Status - The error code return Status MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, std::unique_ptr *db); - - // batch the rows in src table then put it to dest table - // @param const std::unique_ptr *src - table that has the rows for batching - // @param const std::unique_ptr *dest - dest_table to hold batched rows - // @param int32_t size - batch_size - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, - dsize_t batch_size); - // Function that calls pyfunc to perform map on batch // @param (std::pair, batch_stats> *table_pair - contains un-batched tensor // @return Status - The error code return @@ -229,13 +235,6 @@ class BatchOp : public ParallelOp { std::set *pad_cols, std::vector> *pad_vals, std::vector> *pad_shapes); - // @param table - // @param const PadInfo &pad_info pad info - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, - const std::unordered_map &column_name_id_map); - // the number of thread pulling from the mOutConnector of the Op below // @return int32_t, 1 int32_t num_consumers() const override { return 1; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc new file mode 100644 index 0000000000..1f7edf5af5 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" + +#include +#include +#include +#include +#include + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "dataset/core/pybind_support.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/tensor.h" +#include "dataset/core/tensor_shape.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/opt/pass.h" +#include "dataset/util/status.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns, + std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes) + : builder_length_dependent_columns_(length_dependent_columns), + builder_bucket_boundaries_(bucket_boundaries), + builder_bucket_batch_sizes_(bucket_batch_sizes), + builder_pad_info_({}), + builder_pad_to_bucket_boundary_(false), + builder_drop_remainder_(false) { + std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); + builder_op_connector_size_ = config_manager->op_connector_size(); +} + +Status BucketBatchByLengthOp::Builder::SanityCheck() { + std::string error_message; + + if (builder_length_dependent_columns_.empty()) { + error_message += "At least 1 column must be specified for element length calculation.\n"; + } + + if (builder_bucket_boundaries_.empty()) { + error_message += "At least 1 bucket boundary must be specified.\n"; + } + + if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { + error_message += "bucket_batch_sizes must contain one more element than bucket_boundaries.\n"; + } + + CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + // insert a lower boundary of 0 for the first bucket + builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); + + *new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>( + builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, + builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, + builder_op_connector_size_); + + return Status::OK(); +} + +BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, + std::vector<int32_t> bucket_boundaries, + std::vector<int32_t> bucket_batch_sizes, + py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, + int32_t op_connector_size) + : PipelineOp(op_connector_size), + length_dependent_columns_(length_dependent_columns), + bucket_boundaries_(bucket_boundaries), + bucket_batch_sizes_(bucket_batch_sizes), + element_length_function_(element_length_function), + pad_info_(pad_info), + pad_to_bucket_boundary_(pad_to_bucket_boundary), + drop_remainder_(drop_remainder), + batch_count_(0) { + for (int i = 0; i < bucket_batch_sizes_.size(); i++) { + buckets_.push_back(std::make_unique<TensorQTable>()); + } +} + +Status BucketBatchByLengthOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } + +Status BucketBatchByLengthOp::operator()() { + TaskManager::FindMe()->Post(); + + TensorRow current_row; + child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); + 
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + RETURN_IF_NOT_OK(AssignColMapFromChild()); + while (!child_iterator_->eof_handled()) { + while (!current_row.empty()) { + int32_t element_length; + RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); + + int bucket_index = bucket_boundaries_.size() - 1; + while (element_length < bucket_boundaries_[bucket_index]) { + bucket_index--; + } + + buckets_[bucket_index]->push_back(current_row); + + if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { + RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); + } + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + } + + // got EOE; pad and batch the remaining rows in each bucket + if (!drop_remainder_) { + for (int i = 0; i < bucket_boundaries_.size(); i++) { + if (!buckets_[i]->empty()) { + RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); + } + } + } + + // need to send EOE manually since we set state to idle in EoeReceived() + std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { + // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of + // the single column specified in length_dependent_columns_ + if (element_length_function_) { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + size_t number_of_arguments = length_dependent_columns_.size(); + py::tuple input_arguments(number_of_arguments); + for (size_t i = 0; i < number_of_arguments; i++) { + py::array argument_value; + int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; + RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); + input_arguments[i] = argument_value; + } + + py::object length = element_length_function_(*input_arguments); + *out_element_length = length.cast<int32_t>(); + if (*out_element_length < 0) { + return Status(StatusCode::kPyFuncException, "Element length function should return a non-negative integer."); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Could not cast output of element length function to int32_t."); + } + } else { + *out_element_length = element[0]->shape()[0]; + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { + std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index]; + + PadInfo pad_info_copy = pad_info_; + if (pad_to_bucket_boundary_) { + for (auto &pair : pad_info_copy) { + std::vector<dsize_t> pad_shape = pair.second.first.AsVector(); + + for (size_t i = 0; i < pad_shape.size(); i++) { + if (pad_shape[i] == TensorShape::kDimUnknown) { + if (bucket_index + 1 >= bucket_boundaries_.size()) { + std::string error_message = "Requested to pad to bucket boundary, but an element falls in the last bucket"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); + } + + pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; + } + } + + pair.second.first = TensorShape(pad_shape); + } + } + + // PadColumns will change the 
data in bucket + RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); + + std::unique_ptr batched_bucket = std::make_unique(); + RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); + (*bucket)->clear(); + + std::unique_ptr batched_buffer = std::make_unique(batch_count_, DataBuffer::kDeBFlagNone); + batched_buffer->set_tensor_table(std::move(batched_bucket)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); + + batch_count_++; + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Reset() { + batch_count_ = 0; + + for (int i = 0; i < buckets_.size(); i++) { + buckets_[i] = std::make_unique(); + } + + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h new file mode 100644 index 0000000000..64ed523d59 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -0,0 +1,153 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ + +#include +#include +#include +#include +#include + +#include "dataset/core/config_manager.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/batch_op.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +class BucketBatchByLengthOp : public PipelineOp { + public: + class Builder { + public: + Builder(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes); + + ~Builder() = default; + + Builder &SetLengthDependentColumns(std::vector length_dependent_columns) { + builder_length_dependent_columns_ = length_dependent_columns; + return *this; + } + + Builder &SetBucketBoundaries(std::vector bucket_boundaries) { + builder_bucket_boundaries_ = bucket_boundaries; + return *this; + } + + Builder &SetBucketBatchSizes(std::vector bucket_batch_sizes) { + builder_bucket_batch_sizes_ = bucket_batch_sizes; + return *this; + } + + Builder &SetElementLengthFunction(py::function element_length_function) { + builder_element_length_function_ = element_length_function; + return *this; + } + + Builder &SetPadInfo(PadInfo pad_info) { + builder_pad_info_ = pad_info; + return *this; + } + + Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { + builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; + return *this; + } + + Builder &SetDropRemainder(bool drop_remainder) { + builder_drop_remainder_ = drop_remainder; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + Status 
Build(std::shared_ptr *new_bucket_batch_by_length_op); + + private: + Status SanityCheck(); + + std::vector builder_length_dependent_columns_; + std::vector builder_bucket_boundaries_; + std::vector builder_bucket_batch_sizes_; + py::function builder_element_length_function_; + PadInfo builder_pad_info_; + bool builder_pad_to_bucket_boundary_; + bool builder_drop_remainder_; + int32_t builder_op_connector_size_; + }; + + BucketBatchByLengthOp(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); + + // Might need to batch remaining buckets after receiving eoe, so override this method. + // @param int32_t workerId + // @return Status - The error code returned + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sO - reference to the BucketBatchByLengthOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { + bo.Print(out, false); + return out; + } + + // Main loop of batch + // @return Status - The error code returned + Status operator()() override; + + // Function that is called by ResetOp at the end of every epoch + // @return Status - The error code returned + Status Reset() override; + + private: + Status ObtainElementLength(int32_t *out_element_length, TensorRow element); + + Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); + + std::vector length_dependent_columns_; + std::vector bucket_boundaries_; + std::vector bucket_batch_sizes_; + py::function element_length_function_; + PadInfo pad_info_; + bool pad_to_bucket_boundary_; + bool drop_remainder_; + + int32_t batch_count_; + std::unique_ptr child_iterator_; + std::vector> buckets_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index fe29738cb8..12151d4737 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ check_rename, check_numpyslicesdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ - check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ - check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_split, check_cluedataset + check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\ + check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\ + check_split, check_bucket_batch_by_length, check_cluedataset from ..core.datatypes import 
mstype_to_detype, mstypelist_to_detypelist try: @@ -165,6 +165,76 @@ class Dataset: args["num_parallel_workers"] = self.num_parallel_workers return args + @check_bucket_batch_by_length + def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, + element_length_function=None, pad_info=None, + pad_to_bucket_boundary=False, drop_remainder=False): + """ + Bucket elements according to their lengths, and pad and batch the buckets when + they are full. + + A length function is called on each row in the dataset; the row is then + bucketed based on its length and bucket_boundaries. When a bucket reaches its + corresponding size specified in bucket_batch_sizes, the entire bucket will be + padded according to pad_info, and then batched. Each batch will be full, + except possibly the last batch for each bucket. + + Args: + column_names (list of string): Columns passed to element_length_function. + bucket_boundaries (list of int): A list consisting of the upper boundaries + of the buckets. Must be strictly increasing. If there are n boundaries, + n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one + bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each + 0 < i < n-1, and one bucket for [bucket_boundaries[n-1], inf). + bucket_batch_sizes (list of int): A list consisting of the batch size for + each bucket. Must contain exactly len(bucket_boundaries)+1 elements. + element_length_function (Callable, optional): A function that takes in + len(column_names) arguments and returns an int. If not provided, then + len(column_names) must be 1, and the size of the first dimension of + that column will be used as the length (default=None). + pad_info (dict, optional): Describes how to pad each column. The key is the + column name and the value is a tuple (pad_shape, pad_value). Columns not + listed are padded to the longest length in the current batch with a pad + value of 0. Any None dimension is padded to the longest in the current + batch, unless pad_to_bucket_boundary is True (default=None). + pad_to_bucket_boundary (bool, optional): If True, each None dimension in + pad_info is padded to the corresponding bucket boundary minus 1. An + error occurs if an element falls into the last (unbounded) bucket + (default=False). + drop_remainder (bool, optional): If True, drops the last, non-full batch of + each bucket (default=False). + + Examples: + >>> import mindspore.dataset as ds + >>> # data is an instance of Dataset object. + >>> + >>> # creates 3 buckets for lengths [0, 5), [5, 10) and [10, inf) with + >>> # batch sizes 5, 1 and 1 respectively. + >>> column_names = ["col1", "col2"] + >>> bucket_boundaries = [5, 10] + >>> bucket_batch_sizes = [5, 1, 1] + >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2))) + >>> + >>> # will pad col1 to shape [2, bucket_boundaries[i] - 1] where i is the + >>> # index of the bucket that is currently being batched. + >>> # will pad col2 to a shape where each dimension is the longest in all + >>> # the elements currently being batched. + >>> pad_info = {"col1": ([2, None], -1)} + >>> pad_to_bucket_boundary = True + >>> + >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries, + >>> bucket_batch_sizes, + >>> element_length_function, pad_info, + >>> pad_to_bucket_boundary) + """ + return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, + element_length_function, pad_info, + pad_to_bucket_boundary, drop_remainder) + @check_batch def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, input_columns=None, pad_info=None): @@ -1400,6 +1470,47 @@ class DatasetOp(Dataset): # No need for __init__ since it is the same as the super's init +class BucketBatchByLengthDataset(DatasetOp): + """ + The result of applying the BucketBatchByLength operator to the input dataset. 
+ """ + + def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, + element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder): + super().__init__() + + self.column_names = column_names + self.bucket_boundaries = bucket_boundaries + self.bucket_batch_sizes = bucket_batch_sizes + self.element_length_function = element_length_function + self.pad_info = pad_info + self.pad_to_bucket_boundary = pad_to_bucket_boundary + self.drop_remainder = drop_remainder + + self.input.append(input_dataset) + input_dataset.output.append(self) + self._input_indexs = input_dataset.input_indexs + + def get_args(self): + args = super().get_args() + args["length_dependent_columns"] = self.column_names + args["bucket_boundaries"] = self.bucket_boundaries + args["bucket_batch_sizes"] = self.bucket_batch_sizes + args["element_length_function"] = self.element_length_function + args["pad_info"] = self.pad_info + args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary + args["drop_remainder"] = self.drop_remainder + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + return None + class BatchDataset(DatasetOp): """ diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 11b082b0e0..d8cd53982d 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -132,6 +132,8 @@ class Iterator: op_type = OpName.MINDRECORD elif isinstance(dataset, de.BatchDataset): op_type = OpName.BATCH + elif isinstance(dataset, de.BucketBatchByLengthDataset): + op_type = OpName.BUCKETBATCH elif isinstance(dataset, de.SyncWaitDataset): op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b5ffbbdfc0..9a95463bee 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -752,6 +752,67 @@ def check_pad_info(key, val): check_type(val[1], "pad_value", (int, float, str, bytes)) +def check_bucket_batch_by_length(method): + """check the input arguments of bucket_batch_by_length.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] + check_param_type(nreq_param_list, param_dict, list) + + # check column_names: must be list of string. 
+ column_names = param_dict.get("column_names") + all_string = all(isinstance(item, str) for item in column_names) + if not all_string: + raise TypeError("column_names should be a list of str.") + + element_length_function = param_dict.get("element_length_function") + if element_length_function is None and len(column_names) != 1: + raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") + + # check bucket_boundaries: must be list of int, positive and strictly increasing + bucket_boundaries = param_dict.get('bucket_boundaries') + + if not bucket_boundaries: + raise ValueError("bucket_boundaries cannot be empty.") + + all_int = all(isinstance(item, int) for item in bucket_boundaries) + if not all_int: + raise TypeError("bucket_boundaries should be a list of int.") + + all_non_negative = all(item >= 0 for item in bucket_boundaries) + if not all_non_negative: + raise ValueError("bucket_boundaries cannot contain any negative numbers.") + + for i in range(len(bucket_boundaries) - 1): + if not bucket_boundaries[i + 1] > bucket_boundaries[i]: + raise ValueError("bucket_boundaries should be strictly increasing.") + + # check bucket_batch_sizes: must be list of int and positive + bucket_batch_sizes = param_dict.get('bucket_batch_sizes') + if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: + raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") + + all_int = all(isinstance(item, int) for item in bucket_batch_sizes) + if not all_int: + raise TypeError("bucket_batch_sizes should be a list of int.") + + all_non_negative = all(item >= 0 for item in bucket_batch_sizes) + if not all_non_negative: + raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") + + if param_dict.get('pad_info') is not None: + check_type(param_dict["pad_info"], "pad_info", dict) + for k, v in param_dict.get('pad_info').items(): + check_pad_info(k, v) + + return method(*args, **kwargs) + + return new_method + + def check_batch(method): """check the input arguments of batch.""" diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py new file mode 100644 index 0000000000..bca30723e9 --- /dev/null +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -0,0 +1,373 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import pytest +import numpy as np +import mindspore.dataset as ds + +# generates 1 column [0], [0, 1], ..., [0, ..., n-1] +def generate_sequential(n): + for i in range(n): + yield (np.array([j for j in range(i + 1)]),) + + +# generates 1 column [0], [1], ..., [n-1] +def generate_sequential_same_shape(n): + for i in range(n): + yield (np.array([i]),) + + +# combines generate_sequential_same_shape and generate_sequential +def generate_2_columns(n): + for i in range(n): + yield (np.array([i]), np.array([j for j in range(i + 1)])) + + +def test_bucket_batch_invalid_input(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + invalid_column_names = [1, 2, 3] + + bucket_boundaries = [1, 2, 3] + empty_bucket_boundaries = [] + invalid_bucket_boundaries = ["1", "2", "3"] + negative_bucket_boundaries = [1, 2, -3] + decreasing_bucket_boundaries = [3, 2, 1] + non_increasing_bucket_boundaries = [1, 2, 2] + + bucket_batch_sizes = [1, 1, 1, 1] + invalid_bucket_batch_sizes = ["1", "2", "3", "4"] + negative_bucket_batch_sizes = [1, 2, 3, -4] + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) + assert "column_names should be a list of str" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries cannot be empty" in str(info.value) + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be a list of int" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be strictly increasing" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, non_increasing_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be strictly increasing" in str(info.value) + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, invalid_bucket_batch_sizes) + assert "bucket_batch_sizes should be a list of int" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) + assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) + assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) + + +def test_bucket_batch_multi_bucket_no_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [3, 3, 2, 2] + element_length_function = (lambda x: x[0] % 4) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + 
bucket_batch_sizes, element_length_function) + + expected_output = [[[2], [6]], + [[3], [7]], + [[0], [4], [8]], + [[1], [5], [9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_multi_bucket_with_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [2, 3, 3, 2] + element_length_function = (lambda x: len(x) % 4) + pad_info = {"col1": ([10], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 1, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0]], + [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0]], + [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_single_bucket_no_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [1, 1, 5, 1] + element_length_function = (lambda x: 2) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function) + + expected_output = [[[0], [1], [2], [3], [4]], + [[5], [6], [7], [8], [9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_single_bucket_with_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [1, 1, 1, 3] + element_length_function = (lambda x: 7) + pad_info = {"col1": ([12], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_pad_to_bucket_boundary(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [3, 6, 15] + bucket_batch_sizes = [2, 3, 4, 1] + element_length_function = len + pad_info = {"col1": ([None], 0)} + pad_to_bucket_boundary = True + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info, pad_to_bucket_boundary) + + expected_output = [[[0, 0], + [0, 1]], + [[0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 
3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_default_pad(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(15)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [5, 8, 17] + bucket_batch_sizes = [2, 1, 4, 1] + element_length_function = len + pad_info = {"col1": ([None], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0], + [0, 1]], + [[0, 1, 2, 0], + [0, 1, 2, 3]], + [[0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5]], + [[0, 1, 2, 3, 4, 5, 6]], + [[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] + + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_drop_remainder(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(27)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2] + bucket_batch_sizes = [2, 3, 5] + element_length_function = (lambda x: x[0] % 3) + pad_info = None + pad_to_bucket_boundary = False + drop_remainder = True + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info, pad_to_bucket_boundary, drop_remainder) + + expected_output = [[[0], [3]], + [[1], [4], [7]], + [[6], [9]], + [[2], [5], [8], [11], [14]], + [[12], [15]], + [[10], [13], [16]], + [[18], [21]], + [[19], [22], [25]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_default_length_function(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [6, 12] + bucket_batch_sizes = [5, 4, 1] + element_length_function = None + pad_info = {} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_multi_column(): + dataset = ds.GeneratorDataset((lambda: generate_2_columns(10)), ["same_shape", "variable_shape"]) + + column_names = ["same_shape"] + bucket_boundaries = [6, 12] + bucket_batch_sizes = [5, 5, 1] + element_length_function = None + pad_info = {} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + same_shape_expected_output = [[[0], [1], [2], [3], [4]], + [[5], [6], [7], [8], [9]]] + + variable_shape_expected_output = [[[0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0], 
+ [0, 1, 2, 3, 4, 5, 6, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] + + same_shape_output = [] + variable_shape_output = [] + for data in dataset.create_dict_iterator(): + same_shape_output.append(data["same_shape"].tolist()) + variable_shape_output.append(data["variable_shape"].tolist()) + + assert same_shape_output == same_shape_expected_output + assert variable_shape_output == variable_shape_expected_output + + +if __name__ == '__main__': + test_bucket_batch_invalid_input() + test_bucket_batch_multi_bucket_no_padding() + test_bucket_batch_multi_bucket_with_padding() + test_bucket_batch_single_bucket_no_padding() + test_bucket_batch_single_bucket_with_padding() + test_bucket_batch_pad_to_bucket_boundary() + test_bucket_batch_default_pad() + test_bucket_batch_drop_remainder() + test_bucket_batch_default_length_function() + test_bucket_batch_multi_column()
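For readers who want to try the new API end to end, the following minimal sketch (not part of the patch itself) is modeled on test_bucket_batch_pad_to_bucket_boundary above; the generator, column name, and bucket settings are illustrative choices, not fixed by the patch.

import numpy as np
import mindspore.dataset as ds

# Rows of increasing length: [0], [0, 1], ..., [0, ..., 8].
def gen_sequential(n):
    for i in range(n):
        yield (np.array(list(range(i + 1))),)

dataset = ds.GeneratorDataset((lambda: gen_sequential(9)), ["col1"])

# Boundaries [3, 6, 15] create four buckets: [0, 3), [3, 6), [6, 15) and [15, inf),
# batched in groups of 2, 3, 4 and 1 respectively. With pad_to_bucket_boundary=True,
# the None dimension in pad_info is padded to the bucket boundary minus 1; rows that
# fall into the last (unbounded) bucket would raise an error, so all rows here are
# kept shorter than 15.
dataset = dataset.bucket_batch_by_length(["col1"], [3, 6, 15], [2, 3, 4, 1],
                                         len, {"col1": ([None], 0)}, True)

for batch in dataset.create_dict_iterator():
    # Expected batch shapes, matching the unit test: (2, 2), (3, 5), (4, 14).
    print(batch["col1"].shape)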