diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index a596d339ec..54c998a92d 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -19,62 +19,65 @@ #include #include "common/utils.h" -#include "dataset/kernels/py_func_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" #include "dataset/core/tensor.h" #include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/cifar_op.h" +#include "dataset/engine/datasetops/source/clue_op.h" +#include "dataset/engine/datasetops/source/coco_op.h" +#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "dataset/engine/datasetops/source/manifest_op.h" +#include "dataset/engine/datasetops/source/mnist_op.h" #include "dataset/engine/datasetops/source/random_data_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/filter_op.h" +#include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/kernels/py_func_op.h" +#include "dataset/util/random.h" +#include "dataset/util/status.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_distributed_sample.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "utils/log_adapter.h" #include "pybind11/stl.h" +#include "utils/log_adapter.h" namespace mindspore { namespace dataset { using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *); -static std::unordered_map g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp}, - {kShuffle, &DEPipeline::ParseShuffleOp}, - {kMindrecord, &DEPipeline::ParseMindRecordOp}, - {kMap, &DEPipeline::ParseMapOp}, - {kFilter, &DEPipeline::ParseFilterOp}, - {kBatch, &DEPipeline::ParseBatchOp}, - {kBarrier, &DEPipeline::ParseBarrierOp}, - {kRepeat, &DEPipeline::ParseRepeatOp}, - {kSkip, &DEPipeline::ParseSkipOp}, - {kZip, &DEPipeline::ParseZipOp}, - {kConcat, &DEPipeline::ParseConcatOp}, - {kRename, &DEPipeline::ParseRenameOp}, - {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, - {kGenerator, &DEPipeline::ParseGeneratorOp}, - {kTfReader, &DEPipeline::ParseTFReaderOp}, - {kProject, &DEPipeline::ParseProjectOp}, - {kTake, &DEPipeline::ParseTakeOp}, - {kImageFolder, &DEPipeline::ParseImageFolderOp}, - {kMnist, &DEPipeline::ParseMnistOp}, - {kManifest, &DEPipeline::ParseManifestOp}, - {kVoc, &DEPipeline::ParseVOCOp}, - {kCoco, &DEPipeline::ParseCocoOp}, - {kCifar10, &DEPipeline::ParseCifar10Op}, - {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}, - {kRandomData, &DEPipeline::ParseRandomDataOp}, - {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, - {kClue, &DEPipeline::ParseClueOp}}; +static std::unordered_map g_parse_op_func_ = { + {kStorage, &DEPipeline::ParseStorageOp}, + {kShuffle, &DEPipeline::ParseShuffleOp}, + {kMindrecord, 
&DEPipeline::ParseMindRecordOp}, + {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, + {kBatch, &DEPipeline::ParseBatchOp}, + {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, + {kRepeat, &DEPipeline::ParseRepeatOp}, + {kSkip, &DEPipeline::ParseSkipOp}, + {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, + {kRename, &DEPipeline::ParseRenameOp}, + {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, + {kGenerator, &DEPipeline::ParseGeneratorOp}, + {kTfReader, &DEPipeline::ParseTFReaderOp}, + {kProject, &DEPipeline::ParseProjectOp}, + {kTake, &DEPipeline::ParseTakeOp}, + {kImageFolder, &DEPipeline::ParseImageFolderOp}, + {kMnist, &DEPipeline::ParseMnistOp}, + {kManifest, &DEPipeline::ParseManifestOp}, + {kVoc, &DEPipeline::ParseVOCOp}, + {kCoco, &DEPipeline::ParseCocoOp}, + {kCifar10, &DEPipeline::ParseCifar10Op}, + {kCifar100, &DEPipeline::ParseCifar100Op}, + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kRandomData, &DEPipeline::ParseRandomDataOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}, + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr return Status::OK(); } +Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr) { + std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", + "bucket_batch_sizes"}; + for (auto name : mandatory_arguments) { + if (args[name.c_str()].is_none()) { + std::string err_msg = "Error: " + name + " is not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + std::shared_ptr builder = std::make_shared( + ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), + ToIntVector(args[mandatory_arguments[2].c_str()])); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "length_dependent_columns") { + (void)builder->SetLengthDependentColumns(ToStringVector(value)); + } + if (key == "bucket_boundaries") { + (void)builder->SetBucketBoundaries(ToIntVector(value)); + } + if (key == "bucket_batch_sizes") { + (void)builder->SetBucketBatchSizes(ToIntVector(value)); + } + if (key == "element_length_function") { + (void)builder->SetElementLengthFunction(value.cast()); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPadInfo(pad_info); + } + if (key == "pad_to_bucket_boundary") { + (void)builder->SetPadToBucketBoundary(ToBool(value)); + } + if (key == "drop_remainder") { + (void)builder->SetDropRemainder(ToBool(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { std::shared_ptr builder = std::make_shared(); // Right now barrier should only take num_rows_per_buffer = 1 diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index f856b3b2ca..493c092b1f 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBucketBatch, kBarrier, kCache, kRepeat, @@ -121,6 +122,8 @@ class DEPipeline { Status ParseBatchOp(const py::dict &args, 
std::shared_ptr *ptr); + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 01728df86a..a5a0fc895d 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BUCKETBATCH", OpName::kBucketBatch) .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 63265a2225..ed57421030 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT pipeline_op.cc barrier_op.cc batch_op.cc + bucket_batch_by_length_op.cc device_queue_op.cc map_op.cc project_op.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h index 6dc7a337b6..28df5e7e81 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h @@ -193,6 +193,22 @@ class BatchOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "BatchOp"; } + // batch the rows in src table then put it to dest table + // @param const std::unique_ptr *src - table that has the rows for batching + // @param const std::unique_ptr *dest - dest_table to hold batched rows + // @param int32_t size - batch_size + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size); + + // @param table + // @param const PadInfo &pad_info pad info + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, + const std::unordered_map &column_name_id_map); + private: // Worker thread for doing the memcpy of batch // @param int32_t param workerId @@ -203,16 +219,6 @@ class BatchOp : public ParallelOp { // @return Status - The error code return Status MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, std::unique_ptr *db); - - // batch the rows in src table then put it to dest table - // @param const std::unique_ptr *src - table that has the rows for batching - // @param const std::unique_ptr *dest - dest_table to hold batched rows - // @param int32_t size - batch_size - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, - dsize_t batch_size); - // Function that calls pyfunc to perform map on batch // @param (std::pair, batch_stats> *table_pair - contains un-batched tensor // @return Status - The error code return @@ -229,13 +235,6 @@ class BatchOp : public ParallelOp { std::set *pad_cols, std::vector> *pad_vals, std::vector> 
*pad_shapes); - // @param table - // @param const PadInfo &pad_info pad info - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, - const std::unordered_map &column_name_id_map); - // the number of thread pulling from the mOutConnector of the Op below // @return int32_t, 1 int32_t num_consumers() const override { return 1; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc new file mode 100644 index 0000000000..1f7edf5af5 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" + +#include +#include +#include +#include +#include + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "dataset/core/pybind_support.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/tensor.h" +#include "dataset/core/tensor_shape.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/opt/pass.h" +#include "dataset/util/status.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +BucketBatchByLengthOp::Builder::Builder(std::vector length_dependent_columns, + std::vector bucket_boundaries, std::vector bucket_batch_sizes) + : builder_length_dependent_columns_(length_dependent_columns), + builder_bucket_boundaries_(bucket_boundaries), + builder_bucket_batch_sizes_(bucket_batch_sizes), + builder_pad_info_({}), + builder_pad_to_bucket_boundary_(false), + builder_drop_remainder_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_op_connector_size_ = config_manager->op_connector_size(); +} + +Status BucketBatchByLengthOp::Builder::SanityCheck() { + std::string error_message; + + if (builder_length_dependent_columns_.empty()) { + error_message += "At least 1 column must be specified for element length calculation.\n"; + } + + if (builder_bucket_boundaries_.empty()) { + error_message += "At least 1 bucket boundary must be specified.\n"; + } + + if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { + error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; + } + + CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr *new_bucket_batch_by_length_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + // insert 0 for the first bucket + builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); + + *new_bucket_batch_by_length_op = std::make_shared( + 
builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, + builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, + builder_op_connector_size_); + + return Status::OK(); +} + +BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector length_dependent_columns, + std::vector bucket_boundaries, + std::vector bucket_batch_sizes, + py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, + int32_t op_connector_size) + : PipelineOp(op_connector_size), + length_dependent_columns_(length_dependent_columns), + bucket_boundaries_(bucket_boundaries), + bucket_batch_sizes_(bucket_batch_sizes), + element_length_function_(element_length_function), + pad_info_(pad_info), + pad_to_bucket_boundary_(pad_to_bucket_boundary), + drop_remainder_(drop_remainder), + batch_count_(0) { + for (int i = 0; i < bucket_batch_sizes_.size(); i++) { + buckets_.push_back(std::make_unique()); + } +} + +Status BucketBatchByLengthOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } + +Status BucketBatchByLengthOp::operator()() { + TaskManager::FindMe()->Post(); + + TensorRow current_row; + child_iterator_ = std::make_unique(this, 0, 0); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + RETURN_IF_NOT_OK(AssignColMapFromChild()); + while (!child_iterator_->eof_handled()) { + while (!current_row.empty()) { + int32_t element_length; + RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); + + int bucket_index = bucket_boundaries_.size() - 1; + while (element_length < bucket_boundaries_[bucket_index]) { + bucket_index--; + } + + buckets_[bucket_index]->push_back(current_row); + + if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { + RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); + } + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + } + + // got EOE, do what we need to do with remainders in each bucket + if (!drop_remainder_) { + for (int i = 0; i < bucket_boundaries_.size(); i++) { + if (!buckets_[i]->empty()) { + RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); + } + } + } + + // need to send EOE manually since we set state to idle in EoeRecieved() + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { + // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of + // the single column specified in length_dependent_columns_ + if (element_length_function_) { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + size_t number_of_arguments = length_dependent_columns_.size(); + py::tuple input_arguments(number_of_arguments); + for (size_t i = 0; i < number_of_arguments; i++) { + py::array argument_value; + int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; + RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); + input_arguments[i] = argument_value; + } + + 
py::object length = element_length_function_(*input_arguments); + *out_element_length = length.cast(); + if (*out_element_length < 0) { + return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer."); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t."); + } + } else { + *out_element_length = element[0]->shape()[0]; + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { + std::unique_ptr *bucket = &buckets_[bucket_index]; + + PadInfo pad_info_copy = pad_info_; + if (pad_to_bucket_boundary_) { + for (auto &pair : pad_info_copy) { + std::vector pad_shape = pair.second.first.AsVector(); + + for (size_t i = 0; i < pad_shape.size(); i++) { + if (pad_shape[i] == TensorShape::kDimUnknown) { + if (bucket_index + 1 >= bucket_boundaries_.size()) { + std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); + } + + pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; + } + } + + pair.second.first = TensorShape(pad_shape); + } + } + + // PadColumns will change the data in bucket + RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); + + std::unique_ptr batched_bucket = std::make_unique(); + RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); + (*bucket)->clear(); + + std::unique_ptr batched_buffer = std::make_unique(batch_count_, DataBuffer::kDeBFlagNone); + batched_buffer->set_tensor_table(std::move(batched_bucket)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); + + batch_count_++; + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Reset() { + batch_count_ = 0; + + for (int i = 0; i < buckets_.size(); i++) { + buckets_[i] = std::make_unique(); + } + + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h new file mode 100644 index 0000000000..64ed523d59 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -0,0 +1,153 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ + +#include +#include +#include +#include +#include + +#include "dataset/core/config_manager.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/batch_op.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +class BucketBatchByLengthOp : public PipelineOp { + public: + class Builder { + public: + Builder(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes); + + ~Builder() = default; + + Builder &SetLengthDependentColumns(std::vector length_dependent_columns) { + builder_length_dependent_columns_ = length_dependent_columns; + return *this; + } + + Builder &SetBucketBoundaries(std::vector bucket_boundaries) { + builder_bucket_boundaries_ = bucket_boundaries; + return *this; + } + + Builder &SetBucketBatchSizes(std::vector bucket_batch_sizes) { + builder_bucket_batch_sizes_ = bucket_batch_sizes; + return *this; + } + + Builder &SetElementLengthFunction(py::function element_length_function) { + builder_element_length_function_ = element_length_function; + return *this; + } + + Builder &SetPadInfo(PadInfo pad_info) { + builder_pad_info_ = pad_info; + return *this; + } + + Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { + builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; + return *this; + } + + Builder &SetDropRemainder(bool drop_remainder) { + builder_drop_remainder_ = drop_remainder; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + Status Build(std::shared_ptr *new_bucket_batch_by_length_op); + + private: + Status SanityCheck(); + + std::vector builder_length_dependent_columns_; + std::vector builder_bucket_boundaries_; + std::vector builder_bucket_batch_sizes_; + py::function builder_element_length_function_; + PadInfo builder_pad_info_; + bool builder_pad_to_bucket_boundary_; + bool builder_drop_remainder_; + int32_t builder_op_connector_size_; + }; + + BucketBatchByLengthOp(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); + + // Might need to batch remaining buckets after receiving eoe, so override this method. 
+  // @param int32_t workerId
+  // @return Status - The error code returned
+  Status EoeReceived(int32_t) override;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param bo - reference to the BucketBatchByLengthOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) {
+    bo.Print(out, false);
+    return out;
+  }
+
+  // Main loop of bucket batching
+  // @return Status - The error code returned
+  Status operator()() override;
+
+  // Function that is called by ResetOp at the end of every epoch
+  // @return Status - The error code returned
+  Status Reset() override;
+
+ private:
+  Status ObtainElementLength(int32_t *out_element_length, TensorRow element);
+
+  Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);
+
+  std::vector<std::string> length_dependent_columns_;
+  std::vector<int32_t> bucket_boundaries_;
+  std::vector<int32_t> bucket_batch_sizes_;
+  py::function element_length_function_;
+  PadInfo pad_info_;
+  bool pad_to_bucket_boundary_;
+  bool drop_remainder_;
+
+  int32_t batch_count_;
+  std::unique_ptr<ChildIterator> child_iterator_;
+  std::vector<std::unique_ptr<TensorQTable>> buckets_;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index fe29738cb8..12151d4737 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
-    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
-    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_split, check_cluedataset
+    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\
+    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\
+    check_split, check_bucket_batch_by_length, check_cluedataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 
 try:
@@ -165,6 +165,76 @@ class Dataset:
         args["num_parallel_workers"] = self.num_parallel_workers
         return args
 
+    @check_bucket_batch_by_length
+    def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes,
+                               element_length_function=None, pad_info=None,
+                               pad_to_bucket_boundary=False, drop_remainder=False):
+        """
+        Bucket elements according to their lengths, and pad and batch the buckets when
+        they are full.
+
+        A length function is called on each row in the dataset. The row is then
+        bucketed based on its length and bucket_boundaries. When a bucket reaches its
+        corresponding size specified in bucket_batch_sizes, the entire bucket will be
+        padded according to pad_info, and then batched.
+        Each batch will be full, except possibly the last batch for each bucket.
+
+        Args:
+            column_names (list of string): Columns passed to element_length_function.
+            bucket_boundaries (list of int): A list consisting of the upper boundaries
+                of the buckets. Must be strictly increasing. If there are n boundaries,
+                n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one
+                bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
+                0 <= i < n-1, and one bucket for [bucket_boundaries[n-1], inf).
+            bucket_batch_sizes (list of int): A list consisting of the batch size for
+                each bucket. Must contain exactly len(bucket_boundaries)+1 elements.
+            element_length_function (Callable, optional): A function that takes in
+                len(column_names) arguments and returns an int. If not specified,
+                column_names must contain exactly one column, and the size of the first
+                dimension of that column is used as the length (default=None).
+            pad_info (dict, optional): Describes how to pad each column. The key is the
+                column name and the value is a tuple (pad_shape, pad_value). Columns not
+                specified are padded to the longest shape in the current batch with 0 as
+                the pad value. Any None dimension is padded to the longest in the current
+                batch, unless pad_to_bucket_boundary is True (default=None).
+            pad_to_bucket_boundary (bool, optional): If True, pad each None dimension in
+                pad_info to the boundary of the current bucket rather than to the longest
+                element in the batch (default=False).
+            drop_remainder (bool, optional): If True, drop the final, incomplete batch of
+                each bucket (default=False).
+
+        Examples:
+            >>> import mindspore.dataset as ds
+            >>> # data is an instance of Dataset object.
+            >>>
+            >>> # Bucket rows by the longer of col1 and col2: rows shorter than 5 are
+            >>> # batched 5 at a time, all longer rows are batched one at a time.
+            >>> column_names = ["col1", "col2"]
+            >>> bucket_boundaries = [5, 10]
+            >>> bucket_batch_sizes = [5, 1, 1]
+            >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
+            >>>
+            >>> # will pad col1 to shape [2, bucket_boundaries[i]] where i is the
+            >>> # index of the bucket that is currently being batched.
+            >>> # will pad col2 to a shape where each dimension is the longest in all
+            >>> # the elements currently being batched.
+            >>> pad_info = {"col1": ([2, None], -1)}
+            >>> pad_to_bucket_boundary = True
+            >>>
+            >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries,
+            >>>                                    bucket_batch_sizes,
+            >>>                                    element_length_function, pad_info,
+            >>>                                    pad_to_bucket_boundary)
+        """
+        return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
+                                          element_length_function, pad_info,
+                                          pad_to_bucket_boundary, drop_remainder)
+
     @check_batch
     def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
               input_columns=None, pad_info=None):
@@ -1400,6 +1470,47 @@ class DatasetOp(Dataset):
     # No need for __init__ since it is the same as the super's init


+class BucketBatchByLengthDataset(DatasetOp):
+    """
+    The result of applying BucketBatchByLength operator to the input dataset.
+    """
+
+    def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes,
+                 element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder):
+        super().__init__()
+
+        self.column_names = column_names
+        self.bucket_boundaries = bucket_boundaries
+        self.bucket_batch_sizes = bucket_batch_sizes
+        self.element_length_function = element_length_function
+        self.pad_info = pad_info
+        self.pad_to_bucket_boundary = pad_to_bucket_boundary
+        self.drop_remainder = drop_remainder
+
+        self.input.append(input_dataset)
+        input_dataset.output.append(self)
+        self._input_indexs = input_dataset.input_indexs
+
+    def get_args(self):
+        args = super().get_args()
+        args["length_dependent_columns"] = self.column_names
+        args["bucket_boundaries"] = self.bucket_boundaries
+        args["bucket_batch_sizes"] = self.bucket_batch_sizes
+        args["element_length_function"] = self.element_length_function
+        args["pad_info"] = self.pad_info
+        args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary
+        args["drop_remainder"] = self.drop_remainder
+        return args
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Return:
+            Number, number of batches.
+ """ + return None + class BatchDataset(DatasetOp): """ diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 11b082b0e0..d8cd53982d 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -132,6 +132,8 @@ class Iterator: op_type = OpName.MINDRECORD elif isinstance(dataset, de.BatchDataset): op_type = OpName.BATCH + elif isinstance(dataset, de.BucketBatchByLengthDataset): + op_type = OpName.BUCKETBATCH elif isinstance(dataset, de.SyncWaitDataset): op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b5ffbbdfc0..9a95463bee 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -752,6 +752,67 @@ def check_pad_info(key, val): check_type(val[1], "pad_value", (int, float, str, bytes)) +def check_bucket_batch_by_length(method): + """check the input arguments of bucket_batch_by_length.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] + check_param_type(nreq_param_list, param_dict, list) + + # check column_names: must be list of string. + column_names = param_dict.get("column_names") + all_string = all(isinstance(item, str) for item in column_names) + if not all_string: + raise TypeError("column_names should be a list of str.") + + element_length_function = param_dict.get("element_length_function") + if element_length_function is None and len(column_names) != 1: + raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") + + # check bucket_boundaries: must be list of int, positive and strictly increasing + bucket_boundaries = param_dict.get('bucket_boundaries') + + if not bucket_boundaries: + raise ValueError("bucket_boundaries cannot be empty.") + + all_int = all(isinstance(item, int) for item in bucket_boundaries) + if not all_int: + raise TypeError("bucket_boundaries should be a list of int.") + + all_non_negative = all(item >= 0 for item in bucket_boundaries) + if not all_non_negative: + raise ValueError("bucket_boundaries cannot contain any negative numbers.") + + for i in range(len(bucket_boundaries) - 1): + if not bucket_boundaries[i + 1] > bucket_boundaries[i]: + raise ValueError("bucket_boundaries should be strictly increasing.") + + # check bucket_batch_sizes: must be list of int and positive + bucket_batch_sizes = param_dict.get('bucket_batch_sizes') + if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: + raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") + + all_int = all(isinstance(item, int) for item in bucket_batch_sizes) + if not all_int: + raise TypeError("bucket_batch_sizes should be a list of int.") + + all_non_negative = all(item >= 0 for item in bucket_batch_sizes) + if not all_non_negative: + raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") + + if param_dict.get('pad_info') is not None: + check_type(param_dict["pad_info"], "pad_info", dict) + for k, v in param_dict.get('pad_info').items(): + check_pad_info(k, v) + + return method(*args, **kwargs) + + return new_method + + def check_batch(method): """check the input arguments of batch.""" diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py new file mode 
100644 index 0000000000..bca30723e9 --- /dev/null +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -0,0 +1,373 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import pytest +import numpy as np +import mindspore.dataset as ds + +# generates 1 column [0], [0, 1], ..., [0, ..., n-1] +def generate_sequential(n): + for i in range(n): + yield (np.array([j for j in range(i + 1)]),) + + +# generates 1 column [0], [1], ..., [n-1] +def generate_sequential_same_shape(n): + for i in range(n): + yield (np.array([i]),) + + +# combines generate_sequential_same_shape and generate_sequential +def generate_2_columns(n): + for i in range(n): + yield (np.array([i]), np.array([j for j in range(i + 1)])) + + +def test_bucket_batch_invalid_input(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + invalid_column_names = [1, 2, 3] + + bucket_boundaries = [1, 2, 3] + empty_bucket_boundaries = [] + invalid_bucket_boundaries = ["1", "2", "3"] + negative_bucket_boundaries = [1, 2, -3] + decreasing_bucket_boundaries = [3, 2, 1] + non_increasing_bucket_boundaries = [1, 2, 2] + + bucket_batch_sizes = [1, 1, 1, 1] + invalid_bucket_batch_sizes = ["1", "2", "3", "4"] + negative_bucket_batch_sizes = [1, 2, 3, -4] + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) + assert "column_names should be a list of str" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries cannot be empty" in str(info.value) + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be a list of int" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be strictly increasing" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, non_increasing_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries should be strictly increasing" in str(info.value) + + with pytest.raises(TypeError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, invalid_bucket_batch_sizes) + assert "bucket_batch_sizes should be a list of int" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, 
negative_bucket_batch_sizes) + assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) + assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) + + +def test_bucket_batch_multi_bucket_no_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [3, 3, 2, 2] + element_length_function = (lambda x: x[0] % 4) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function) + + expected_output = [[[2], [6]], + [[3], [7]], + [[0], [4], [8]], + [[1], [5], [9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_multi_bucket_with_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [2, 3, 3, 2] + element_length_function = (lambda x: len(x) % 4) + pad_info = {"col1": ([10], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 1, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0]], + [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0]], + [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_single_bucket_no_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [1, 1, 5, 1] + element_length_function = (lambda x: 2) + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function) + + expected_output = [[[0], [1], [2], [3], [4]], + [[5], [6], [7], [8], [9]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_single_bucket_with_padding(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2, 3] + bucket_batch_sizes = [1, 1, 1, 3] + element_length_function = (lambda x: 7) + pad_info = {"col1": ([12], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def 
test_bucket_batch_pad_to_bucket_boundary(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [3, 6, 15] + bucket_batch_sizes = [2, 3, 4, 1] + element_length_function = len + pad_info = {"col1": ([None], 0)} + pad_to_bucket_boundary = True + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info, pad_to_bucket_boundary) + + expected_output = [[[0, 0], + [0, 1]], + [[0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_default_pad(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(15)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [5, 8, 17] + bucket_batch_sizes = [2, 1, 4, 1] + element_length_function = len + pad_info = {"col1": ([None], 0)} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0], + [0, 1]], + [[0, 1, 2, 0], + [0, 1, 2, 3]], + [[0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5]], + [[0, 1, 2, 3, 4, 5, 6]], + [[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] + + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_drop_remainder(): + dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(27)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [1, 2] + bucket_batch_sizes = [2, 3, 5] + element_length_function = (lambda x: x[0] % 3) + pad_info = None + pad_to_bucket_boundary = False + drop_remainder = True + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info, pad_to_bucket_boundary, drop_remainder) + + expected_output = [[[0], [3]], + [[1], [4], [7]], + [[6], [9]], + [[2], [5], [8], [11], [14]], + [[12], [15]], + [[10], [13], [16]], + [[18], [21]], + [[19], [22], [25]]] + + output = [] + for data in dataset.create_dict_iterator(): + output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_default_length_function(): + dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) + + column_names = ["col1"] + bucket_boundaries = [6, 12] + bucket_batch_sizes = [5, 4, 1] + element_length_function = None + pad_info = {} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + expected_output = [[[0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8]]] + + output = [] + for data in dataset.create_dict_iterator(): + 
output.append(data["col1"].tolist()) + + assert output == expected_output + + +def test_bucket_batch_multi_column(): + dataset = ds.GeneratorDataset((lambda: generate_2_columns(10)), ["same_shape", "variable_shape"]) + + column_names = ["same_shape"] + bucket_boundaries = [6, 12] + bucket_batch_sizes = [5, 5, 1] + element_length_function = None + pad_info = {} + + dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, + bucket_batch_sizes, element_length_function, + pad_info) + + same_shape_expected_output = [[[0], [1], [2], [3], [4]], + [[5], [6], [7], [8], [9]]] + + variable_shape_expected_output = [[[0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 2, 0, 0], + [0, 1, 2, 3, 0], + [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 0, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 0, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0], + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] + + same_shape_output = [] + variable_shape_output = [] + for data in dataset.create_dict_iterator(): + same_shape_output.append(data["same_shape"].tolist()) + variable_shape_output.append(data["variable_shape"].tolist()) + + assert same_shape_output == same_shape_expected_output + assert variable_shape_output == variable_shape_expected_output + + +if __name__ == '__main__': + test_bucket_batch_invalid_input() + test_bucket_batch_multi_bucket_no_padding() + test_bucket_batch_multi_bucket_with_padding() + test_bucket_batch_single_bucket_no_padding() + test_bucket_batch_single_bucket_with_padding() + test_bucket_batch_pad_to_bucket_boundary() + test_bucket_batch_default_pad() + test_bucket_batch_drop_remainder() + test_bucket_batch_default_length_function() + test_bucket_batch_multi_column()
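The test file above exercises the operator end to end. As a reading aid only (not part of the patch, and every name in it is illustrative), the short pure-Python sketch below models the bucket-assignment and flushing behaviour that BucketBatchByLengthOp::operator() implements; padding and the MindSpore tensor types are deliberately left out.

# Pure-Python model of the bucketing logic in BucketBatchByLengthOp (illustration only).
from bisect import bisect_right


def bucket_batch_by_length_model(rows, bucket_boundaries, bucket_batch_sizes,
                                 element_length_fn, drop_remainder=False):
    """Yield batches (lists of rows) in the same order as the C++ operator."""
    # The builder prepends 0 to the boundaries, so bucket i holds lengths in
    # [boundaries[i], boundaries[i + 1]) and the last bucket is open-ended.
    boundaries = [0] + list(bucket_boundaries)
    buckets = [[] for _ in bucket_batch_sizes]

    for row in rows:
        length = element_length_fn(row)
        # Largest i with boundaries[i] <= length; the descending scan in operator().
        index = bisect_right(boundaries, length) - 1
        buckets[index].append(row)
        if len(buckets[index]) == bucket_batch_sizes[index]:
            yield buckets[index]
            buckets[index] = []

    # On EOE, partially filled buckets are flushed unless drop_remainder is set.
    if not drop_remainder:
        for bucket in buckets:
            if bucket:
                yield bucket


# Reproduces the grouping order checked by test_bucket_batch_multi_bucket_no_padding.
batches = list(bucket_batch_by_length_model(range(10), [1, 2, 3], [3, 3, 2, 2],
                                            lambda x: x % 4))
assert batches == [[2, 6], [3, 7], [0, 4, 8], [1, 5, 9]]

The model omits padding: PadAndBatchBucket copies pad_info and, when pad_to_bucket_boundary is true, replaces every unknown (None) dimension with bucket_boundaries_[bucket_index + 1] - 1, the largest length that can fall into that bucket (padding to the boundary is rejected for the open-ended last bucket). That is why test_bucket_batch_pad_to_bucket_boundary, with boundaries [3, 6, 15], yields batches padded to widths 2, 5 and 14.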