Merge pull request !1984 from Peilin/BucketBatchByLengthOp
| @@ -19,62 +19,65 @@ | |||
| #include <map> | |||
| #include "common/utils.h" | |||
| #include "dataset/kernels/py_func_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "dataset/engine/datasetops/source/clue_op.h" | |||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "dataset/engine/datasetops/source/text_file_op.h" | |||
| #include "dataset/engine/datasetops/source/clue_op.h" | |||
| #include "dataset/engine/datasetops/filter_op.h" | |||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||
| #include "dataset/kernels/py_func_op.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| #include "mindrecord/include/shard_category.h" | |||
| #include "mindrecord/include/shard_distributed_sample.h" | |||
| #include "mindrecord/include/shard_sample.h" | |||
| #include "mindrecord/include/shard_shuffle.h" | |||
| #include "dataset/util/random.h" | |||
| #include "dataset/util/status.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "pybind11/stl.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *); | |||
| static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp}, | |||
| {kShuffle, &DEPipeline::ParseShuffleOp}, | |||
| {kMindrecord, &DEPipeline::ParseMindRecordOp}, | |||
| {kMap, &DEPipeline::ParseMapOp}, | |||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||
| {kZip, &DEPipeline::ParseZipOp}, | |||
| {kConcat, &DEPipeline::ParseConcatOp}, | |||
| {kRename, &DEPipeline::ParseRenameOp}, | |||
| {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, | |||
| {kGenerator, &DEPipeline::ParseGeneratorOp}, | |||
| {kTfReader, &DEPipeline::ParseTFReaderOp}, | |||
| {kProject, &DEPipeline::ParseProjectOp}, | |||
| {kTake, &DEPipeline::ParseTakeOp}, | |||
| {kImageFolder, &DEPipeline::ParseImageFolderOp}, | |||
| {kMnist, &DEPipeline::ParseMnistOp}, | |||
| {kManifest, &DEPipeline::ParseManifestOp}, | |||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||
| {kCoco, &DEPipeline::ParseCocoOp}, | |||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||
| {kRandomData, &DEPipeline::ParseRandomDataOp}, | |||
| {kTextFile, &DEPipeline::ParseTextFileOp}, | |||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||
| {kClue, &DEPipeline::ParseClueOp}}; | |||
| static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = { | |||
| {kStorage, &DEPipeline::ParseStorageOp}, | |||
| {kShuffle, &DEPipeline::ParseShuffleOp}, | |||
| {kMindrecord, &DEPipeline::ParseMindRecordOp}, | |||
| {kMap, &DEPipeline::ParseMapOp}, | |||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||
| {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, | |||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||
| {kZip, &DEPipeline::ParseZipOp}, | |||
| {kConcat, &DEPipeline::ParseConcatOp}, | |||
| {kRename, &DEPipeline::ParseRenameOp}, | |||
| {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, | |||
| {kGenerator, &DEPipeline::ParseGeneratorOp}, | |||
| {kTfReader, &DEPipeline::ParseTFReaderOp}, | |||
| {kProject, &DEPipeline::ParseProjectOp}, | |||
| {kTake, &DEPipeline::ParseTakeOp}, | |||
| {kImageFolder, &DEPipeline::ParseImageFolderOp}, | |||
| {kMnist, &DEPipeline::ParseMnistOp}, | |||
| {kManifest, &DEPipeline::ParseManifestOp}, | |||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||
| {kCoco, &DEPipeline::ParseCocoOp}, | |||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||
| {kRandomData, &DEPipeline::ParseRandomDataOp}, | |||
| {kTextFile, &DEPipeline::ParseTextFileOp}, | |||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||
| {kClue, &DEPipeline::ParseClueOp}}; | |||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | |||
| try { | |||
| @@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| return Status::OK(); | |||
| } | |||
| Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", | |||
| "bucket_batch_sizes"}; | |||
| for (auto name : mandatory_arguments) { | |||
| if (args[name.c_str()].is_none()) { | |||
| std::string err_msg = "Error: " + name + " is not set."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>( | |||
| ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), | |||
| ToIntVector(args[mandatory_arguments[2].c_str()])); | |||
| for (auto arg : args) { | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "length_dependent_columns") { | |||
| (void)builder->SetLengthDependentColumns(ToStringVector(value)); | |||
| } | |||
| if (key == "bucket_boundaries") { | |||
| (void)builder->SetBucketBoundaries(ToIntVector(value)); | |||
| } | |||
| if (key == "bucket_batch_sizes") { | |||
| (void)builder->SetBucketBatchSizes(ToIntVector(value)); | |||
| } | |||
| if (key == "element_length_function") { | |||
| (void)builder->SetElementLengthFunction(value.cast<py::function>()); | |||
| } | |||
| if (key == "pad_info") { | |||
| PadInfo pad_info; | |||
| RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); | |||
| (void)builder->SetPadInfo(pad_info); | |||
| } | |||
| if (key == "pad_to_bucket_boundary") { | |||
| (void)builder->SetPadToBucketBoundary(ToBool(value)); | |||
| } | |||
| if (key == "drop_remainder") { | |||
| (void)builder->SetDropRemainder(ToBool(value)); | |||
| } | |||
| } | |||
| } | |||
| std::shared_ptr<BucketBatchByLengthOp> op; | |||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||
| *ptr = op; | |||
| return Status::OK(); | |||
| } | |||
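ParseBucketBatchByLengthOp above reads its settings out of a py::dict assembled on the Python side (see BucketBatchByLengthDataset.get_args() later in this diff). As a rough, hedged sketch only, the snippet below shows the shape of that dictionary with illustrative values and mirrors the mandatory-argument check; the key names are taken from this diff, everything else is hypothetical.

```python
# Illustrative values only; the key names match the ones the C++ parser looks up.
args = {
    "length_dependent_columns": ["col1"],   # mandatory
    "bucket_boundaries": [5, 10],           # mandatory, strictly increasing
    "bucket_batch_sizes": [5, 3, 1],        # mandatory, len(bucket_boundaries) + 1 entries
    "element_length_function": None,        # optional Python callable
    "pad_info": {"col1": ([None], 0)},      # optional, column -> (pad shape, pad value)
    "pad_to_bucket_boundary": False,        # optional
    "drop_remainder": False,                # optional
}

# Mirrors the mandatory-argument loop at the top of ParseBucketBatchByLengthOp.
for name in ("length_dependent_columns", "bucket_boundaries", "bucket_batch_sizes"):
    if args.get(name) is None:
        raise ValueError("Error: " + name + " is not set.")
```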
| Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||
| std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>(); | |||
| // Right now barrier should only take num_rows_per_buffer = 1 | |||
| @@ -40,6 +40,7 @@ enum OpName { | |||
| kShuffle, | |||
| kMindrecord, | |||
| kBatch, | |||
| kBucketBatch, | |||
| kBarrier, | |||
| kCache, | |||
| kRepeat, | |||
| @@ -121,6 +122,8 @@ class DEPipeline { | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| @@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||
| .value("STORAGE", OpName::kStorage) | |||
| .value("SHUFFLE", OpName::kShuffle) | |||
| .value("BATCH", OpName::kBatch) | |||
| .value("BUCKETBATCH", OpName::kBucketBatch) | |||
| .value("BARRIER", OpName::kBarrier) | |||
| .value("MINDRECORD", OpName::kMindrecord) | |||
| .value("CACHE", OpName::kCache) | |||
| @@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT | |||
| pipeline_op.cc | |||
| barrier_op.cc | |||
| batch_op.cc | |||
| bucket_batch_by_length_op.cc | |||
| device_queue_op.cc | |||
| map_op.cc | |||
| project_op.cc | |||
| @@ -193,6 +193,22 @@ class BatchOp : public ParallelOp { | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "BatchOp"; } | |||
| // Batch the rows in the src table and put them into the dest table | |||
| // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching | |||
| // @param const std::unique_ptr<TensorQTable> *dest - dest table to hold the batched rows | |||
| // @param dsize_t batch_size - the batch size | |||
| // @return Status - The error code return | |||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | |||
| dsize_t batch_size); | |||
| // @param std::unique_ptr<TensorQTable> *table - table of rows to pad, modified in place | |||
| // @param const PadInfo &pad_info pad info | |||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||
| // @return Status - The error code return | |||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | |||
| private: | |||
| // Worker thread for doing the memcpy of batch | |||
| // @param int32_t param workerId | |||
| @@ -203,16 +219,6 @@ class BatchOp : public ParallelOp { | |||
| // @return Status - The error code return | |||
| Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | |||
| std::unique_ptr<DataBuffer> *db); | |||
| // batch the rows in src table then put it to dest table | |||
| // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching | |||
| // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | |||
| // @param int32_t size - batch_size | |||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||
| // @return Status - The error code return | |||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | |||
| dsize_t batch_size); | |||
| // Function that calls pyfunc to perform map on batch | |||
| // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor | |||
| // @return Status - The error code return | |||
| @@ -229,13 +235,6 @@ class BatchOp : public ParallelOp { | |||
| std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | |||
| std::vector<std::vector<dsize_t>> *pad_shapes); | |||
| // @param table | |||
| // @param const PadInfo &pad_info pad info | |||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||
| // @return Status - The error code return | |||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | |||
| // the number of thread pulling from the mOutConnector of the Op below | |||
| // @return int32_t, 1 | |||
| int32_t num_consumers() const override { return 1; } | |||
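The batch_op.h hunk above moves BatchRows and PadColumns from the private section into public statics so the new BucketBatchByLengthOp can reuse them. As a hedged illustration of the padding semantics only (not the C++ implementation), the NumPy sketch below pads every 1-D array in a column to a common length, where a None dimension means "pad to the longest element in the batch", matching the pad_info behaviour described in the Python docstring later in this diff.

```python
import numpy as np

def pad_column(arrays, pad_shape, pad_value):
    """Pad 1-D arrays to a common length and stack them into one batch.

    pad_shape is a one-element list; None means "longest array in the batch".
    This is a sketch of the semantics, not BatchOp::PadColumns itself.
    """
    target = pad_shape[0]
    if target is None:
        target = max(len(a) for a in arrays)
    padded = []
    for a in arrays:
        out = np.full(target, pad_value, dtype=a.dtype)
        out[:len(a)] = a
        padded.append(out)
    return np.stack(padded)

# Pad to the longest element in the batch, filling with 0.
print(pad_column([np.array([0]), np.array([0, 1, 2])], [None], 0))
# [[0 0 0]
#  [0 1 2]]
```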
| @@ -0,0 +1,242 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "pybind11/numpy.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "dataset/core/pybind_support.h" | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/core/tensor_shape.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/parallel_op.h" | |||
| #include "dataset/engine/opt/pass.h" | |||
| #include "dataset/util/status.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns, | |||
| std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes) | |||
| : builder_length_dependent_columns_(length_dependent_columns), | |||
| builder_bucket_boundaries_(bucket_boundaries), | |||
| builder_bucket_batch_sizes_(bucket_batch_sizes), | |||
| builder_pad_info_({}), | |||
| builder_pad_to_bucket_boundary_(false), | |||
| builder_drop_remainder_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| } | |||
| Status BucketBatchByLengthOp::Builder::SanityCheck() { | |||
| std::string error_message; | |||
| if (builder_length_dependent_columns_.empty()) { | |||
| error_message += "At least 1 column must be specified for element length calculation.\n"; | |||
| } | |||
| if (builder_bucket_boundaries_.empty()) { | |||
| error_message += "At least 1 bucket boundary must be specified.\n"; | |||
| } | |||
| if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { | |||
| error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); | |||
| return Status::OK(); | |||
| } | |||
| Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| // insert 0 for the first bucket | |||
| builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); | |||
| *new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>( | |||
| builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, | |||
| builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, | |||
| builder_op_connector_size_); | |||
| return Status::OK(); | |||
| } | |||
| BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, | |||
| std::vector<int32_t> bucket_boundaries, | |||
| std::vector<int32_t> bucket_batch_sizes, | |||
| py::function element_length_function, PadInfo pad_info, | |||
| bool pad_to_bucket_boundary, bool drop_remainder, | |||
| int32_t op_connector_size) | |||
| : PipelineOp(op_connector_size), | |||
| length_dependent_columns_(length_dependent_columns), | |||
| bucket_boundaries_(bucket_boundaries), | |||
| bucket_batch_sizes_(bucket_batch_sizes), | |||
| element_length_function_(element_length_function), | |||
| pad_info_(pad_info), | |||
| pad_to_bucket_boundary_(pad_to_bucket_boundary), | |||
| drop_remainder_(drop_remainder), | |||
| batch_count_(0) { | |||
| for (int i = 0; i < bucket_batch_sizes_.size(); i++) { | |||
| buckets_.push_back(std::make_unique<TensorQTable>()); | |||
| } | |||
| } | |||
| Status BucketBatchByLengthOp::EoeReceived(int32_t) { | |||
| state_ = OpState::kDeOpIdle; | |||
| return Status::OK(); | |||
| } | |||
| void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } | |||
| Status BucketBatchByLengthOp::operator()() { | |||
| TaskManager::FindMe()->Post(); | |||
| TensorRow current_row; | |||
| child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); | |||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||
| RETURN_IF_NOT_OK(AssignColMapFromChild()); | |||
| while (!child_iterator_->eof_handled()) { | |||
| while (!current_row.empty()) { | |||
| int32_t element_length; | |||
| RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); | |||
| int bucket_index = bucket_boundaries_.size() - 1; | |||
| while (element_length < bucket_boundaries_[bucket_index]) { | |||
| bucket_index--; | |||
| } | |||
| buckets_[bucket_index]->push_back(current_row); | |||
| if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { | |||
| RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); | |||
| } | |||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||
| } | |||
| // got EOE, do what we need to do with remainders in each bucket | |||
| if (!drop_remainder_) { | |||
| for (int i = 0; i < bucket_boundaries_.size(); i++) { | |||
| if (!buckets_[i]->empty()) { | |||
| RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); | |||
| } | |||
| } | |||
| } | |||
| // need to send EOE manually since we set state to idle in EoeReceived() | |||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
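The main loop above assigns each row to a bucket by scanning bucket_boundaries_ from the last entry downwards, after Builder::Build() has prepended the implicit 0 boundary. A minimal Python sketch of that index search, assuming user-supplied boundaries [5, 10]:

```python
def assign_bucket(element_length, user_boundaries):
    # Builder::Build() prepends 0, so the internal boundaries become [0, 5, 10]
    # and there are len(user_boundaries) + 1 buckets in total.
    boundaries = [0] + list(user_boundaries)
    bucket_index = len(boundaries) - 1
    while element_length < boundaries[bucket_index]:
        bucket_index -= 1
    return bucket_index

# Lengths 0-4 land in bucket 0, 5-9 in bucket 1, 10 and above in bucket 2.
print([assign_bucket(n, [5, 10]) for n in (0, 4, 5, 9, 10, 42)])  # [0, 0, 1, 1, 2, 2]
```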
| Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { | |||
| // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of | |||
| // the single column specified in length_dependent_columns_ | |||
| if (element_length_function_) { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||
| } | |||
| try { | |||
| size_t number_of_arguments = length_dependent_columns_.size(); | |||
| py::tuple input_arguments(number_of_arguments); | |||
| for (size_t i = 0; i < number_of_arguments; i++) { | |||
| py::array argument_value; | |||
| int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; | |||
| RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); | |||
| input_arguments[i] = argument_value; | |||
| } | |||
| py::object length = element_length_function_(*input_arguments); | |||
| *out_element_length = length.cast<int32_t>(); | |||
| if (*out_element_length < 0) { | |||
| return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer."); | |||
| } | |||
| } catch (const py::error_already_set &e) { | |||
| return Status(StatusCode::kPyFuncException, e.what()); | |||
| } catch (const py::cast_error &e) { | |||
| return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t."); | |||
| } | |||
| } else { | |||
| *out_element_length = element[0]->shape()[0]; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
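ObtainElementLength either calls the user's element_length_function with one NumPy array per length-dependent column, or, when no function is given, falls back to the size of the first dimension of the single length-dependent column, as the Python docstring later in this diff describes. A minimal, hedged sketch of both paths; the helper name and the dict-based row layout are illustrative only.

```python
import numpy as np

def obtain_element_length(element, length_dependent_columns, element_length_function=None):
    """Sketch of the two paths above; element maps column name -> np.ndarray for one row."""
    if element_length_function is not None:
        args = [element[name] for name in length_dependent_columns]
        length = int(element_length_function(*args))
        if length < 0:
            raise ValueError("Element length function should return a non-negative integer.")
        return length
    # Default path: first dimension of the single length-dependent column.
    return element[length_dependent_columns[0]].shape[0]

row = {"col1": np.array([0, 1, 2]), "col2": np.array([7])}
print(obtain_element_length(row, ["col1"]))                          # 3
print(obtain_element_length(row, ["col1", "col2"],
                            lambda c1, c2: max(len(c1), len(c2))))   # 3
```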
| Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { | |||
| std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index]; | |||
| PadInfo pad_info_copy = pad_info_; | |||
| if (pad_to_bucket_boundary_) { | |||
| for (auto &pair : pad_info_copy) { | |||
| std::vector<dsize_t> pad_shape = pair.second.first.AsVector(); | |||
| for (size_t i = 0; i < pad_shape.size(); i++) { | |||
| if (pad_shape[i] == TensorShape::kDimUnknown) { | |||
| if (bucket_index + 1 >= bucket_boundaries_.size()) { | |||
| std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); | |||
| } | |||
| pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; | |||
| } | |||
| } | |||
| pair.second.first = TensorShape(pad_shape); | |||
| } | |||
| } | |||
| // PadColumns will change the data in bucket | |||
| RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); | |||
| std::unique_ptr<TensorQTable> batched_bucket = std::make_unique<TensorQTable>(); | |||
| RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); | |||
| (*bucket)->clear(); | |||
| std::unique_ptr<DataBuffer> batched_buffer = std::make_unique<DataBuffer>(batch_count_, DataBuffer::kDeBFlagNone); | |||
| batched_buffer->set_tensor_table(std::move(batched_bucket)); | |||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); | |||
| batch_count_++; | |||
| return Status::OK(); | |||
| } | |||
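When pad_to_bucket_boundary_ is set, PadAndBatchBucket rewrites every unknown dimension in the pad shape to the bucket's upper boundary minus 1, and rejects elements that landed in the unbounded last bucket. A small hedged sketch of that shape rewrite; the helper name is hypothetical.

```python
def resolve_pad_shape(pad_shape, bucket_index, boundaries):
    """Replace None dims with the bucket's upper boundary minus 1.

    boundaries already includes the implicit leading 0, as in Builder::Build() above.
    """
    if bucket_index + 1 >= len(boundaries):
        raise ValueError("Requested to pad to bucket boundary, "
                         "but an element falls into the last bucket.")
    upper = boundaries[bucket_index + 1]
    return [upper - 1 if dim is None else dim for dim in pad_shape]

# With user boundaries [5, 10] (internally [0, 5, 10]):
print(resolve_pad_shape([2, None], 0, [0, 5, 10]))  # [2, 4]
print(resolve_pad_shape([2, None], 1, [0, 5, 10]))  # [2, 9]
```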
| Status BucketBatchByLengthOp::Reset() { | |||
| batch_count_ = 0; | |||
| for (int i = 0; i < buckets_.size(); i++) { | |||
| buckets_[i] = std::make_unique<TensorQTable>(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,153 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/dataset_iterator.h" | |||
| #include "dataset/engine/datasetops/batch_op.h" | |||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DataBuffer; | |||
| class BucketBatchByLengthOp : public PipelineOp { | |||
| public: | |||
| class Builder { | |||
| public: | |||
| Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries, | |||
| std::vector<int32_t> bucket_batch_sizes); | |||
| ~Builder() = default; | |||
| Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) { | |||
| builder_length_dependent_columns_ = length_dependent_columns; | |||
| return *this; | |||
| } | |||
| Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) { | |||
| builder_bucket_boundaries_ = bucket_boundaries; | |||
| return *this; | |||
| } | |||
| Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) { | |||
| builder_bucket_batch_sizes_ = bucket_batch_sizes; | |||
| return *this; | |||
| } | |||
| Builder &SetElementLengthFunction(py::function element_length_function) { | |||
| builder_element_length_function_ = element_length_function; | |||
| return *this; | |||
| } | |||
| Builder &SetPadInfo(PadInfo pad_info) { | |||
| builder_pad_info_ = pad_info; | |||
| return *this; | |||
| } | |||
| Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { | |||
| builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; | |||
| return *this; | |||
| } | |||
| Builder &SetDropRemainder(bool drop_remainder) { | |||
| builder_drop_remainder_ = drop_remainder; | |||
| return *this; | |||
| } | |||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||
| builder_op_connector_size_ = op_connector_size; | |||
| return *this; | |||
| } | |||
| Status Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op); | |||
| private: | |||
| Status SanityCheck(); | |||
| std::vector<std::string> builder_length_dependent_columns_; | |||
| std::vector<int32_t> builder_bucket_boundaries_; | |||
| std::vector<int32_t> builder_bucket_batch_sizes_; | |||
| py::function builder_element_length_function_; | |||
| PadInfo builder_pad_info_; | |||
| bool builder_pad_to_bucket_boundary_; | |||
| bool builder_drop_remainder_; | |||
| int32_t builder_op_connector_size_; | |||
| }; | |||
| BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries, | |||
| std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, | |||
| bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); | |||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | |||
| // @param int32_t workerId | |||
| // @return Status - The error code returned | |||
| Status EoeReceived(int32_t) override; | |||
| // A print method typically used for debugging | |||
| // @param out - The output stream to write output to | |||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // << Stream output operator overload | |||
| // @notes This allows you to write the debug print info using stream operators | |||
| // @param out - reference to the output stream being overloaded | |||
| // @param bo - reference to the BucketBatchByLengthOp to display | |||
| // @return - the output stream must be returned | |||
| friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { | |||
| bo.Print(out, false); | |||
| return out; | |||
| } | |||
| // Main loop of batch | |||
| // @return Status - The error code returned | |||
| Status operator()() override; | |||
| // Function that is called by ResetOp at the end of every epoch | |||
| // @return Status - The error code returned | |||
| Status Reset() override; | |||
| private: | |||
| Status ObtainElementLength(int32_t *out_element_length, TensorRow element); | |||
| Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); | |||
| std::vector<std::string> length_dependent_columns_; | |||
| std::vector<int32_t> bucket_boundaries_; | |||
| std::vector<int32_t> bucket_batch_sizes_; | |||
| py::function element_length_function_; | |||
| PadInfo pad_info_; | |||
| bool pad_to_bucket_boundary_; | |||
| bool drop_remainder_; | |||
| int32_t batch_count_; | |||
| std::unique_ptr<ChildIterator> child_iterator_; | |||
| std::vector<std::unique_ptr<TensorQTable>> buckets_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||
| @@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator | |||
| from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | |||
| check_rename, check_numpyslicesdataset, \ | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_split, check_cluedataset | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\ | |||
| check_split, check_bucket_batch_by_length, check_cluedataset | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| try: | |||
| @@ -165,6 +165,76 @@ class Dataset: | |||
| args["num_parallel_workers"] = self.num_parallel_workers | |||
| return args | |||
| @check_bucket_batch_by_length | |||
| def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, | |||
| element_length_function=None, pad_info=None, | |||
| pad_to_bucket_boundary=False, drop_remainder=False): | |||
| """ | |||
| Bucket elements according to their lengths, and pad and batch the buckets when | |||
| they are full. | |||
| A length function is called on each row in the dataset; the row is then | |||
| bucketed based on its length and bucket_boundaries. When a bucket reaches its | |||
| corresponding size specified in bucket_batch_sizes, the entire bucket will be | |||
| padded according to pad_info, and then batched. Each batch will be full, | |||
| except possibly the last batch for each bucket. | |||
| Args: | |||
| column_names (list of string): Columns passed to element_length_function. | |||
| bucket_boundaries (list of int): A list consisting of the upper boundaries | |||
| of the buckets. Must be strictly increasing. If there are n boundaries, | |||
| n+1 buckets are created: One bucket for [0, bucket_boundaries[0]), one | |||
| bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each | |||
| 0<i<n, and one bucket for [bucket_boundaries[n-1], inf). | |||
| bucket_batch_sizes (list of int): A list consisting of the batch sizes for | |||
| each buclet. Must contain len(bucket_boundaries)+1 elements. | |||
| element_length_function (Callable, optional): A function that takes in | |||
| len(column_names) arguments and returns an int. If no value is | |||
| provided, then len(column_names) must be 1, and the size of the first | |||
| dimension of that column will be taken as the length (default=None). | |||
| pad_info (dict, optional): Represents how to pad each column. The key | |||
| corresponds to the column name, the value must be a tuple of 2 elements. | |||
| The first element corresponds to the shape to pad to, and the second | |||
| element corresponds to the value to pad with. If a column is not | |||
| specified, then that column will be padded to the longest in the current | |||
| batch, and 0 will be used as the padding value. Any None dimensions will | |||
| be padded to the longest in the current batch, unless | |||
| pad_to_bucket_boundary is True. If no padding is wanted, set pad_info | |||
| to None (default=None). | |||
| pad_to_bucket_boundary (bool, optional): If True, will pad each None | |||
| dimension in pad_info to the bucket_boundary minus 1. If there are any | |||
| elements that fall into the last bucket, an error will occur | |||
| (default=False). | |||
| drop_remainder (bool, optional): If True, will drop the last batch for each | |||
| bucket if it is not a full batch (default=False). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> # data is an instance of Dataset object. | |||
| >>> | |||
| >>> # creates a dataset where every 100 rows is combined into a batch | |||
| >>> # and drops the last incomplete batch if there is one. | |||
| >>> column_names = ["col1", "col2"] | |||
| >>> buket_boundaries = [5, 10] | |||
| >>> bucket_batch_sizes = [5, 1, 1] | |||
| >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2))) | |||
| >>> | |||
| >>> # will pad col1 to shape [2, bucket_boundaries[i]] where i is the | |||
| >>> # index of the bucket that is currently being batched. | |||
| >>> # will pad col2 to a shape where each dimension is the longest in all | |||
| >>> # the elements currently being batched. | |||
| >>> pad_info = {"col1", ([2, None], -1)} | |||
| >>> pad_to_bucket_boundary = True | |||
| >>> | |||
| >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| >>> bucket_batch_sizes, | |||
| >>> element_length_function, pad_info), | |||
| >>> pad_to_bucket_boundary) | |||
| """ | |||
| return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, | |||
| element_length_function, pad_info, | |||
| pad_to_bucket_boundary, drop_remainder) | |||
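For a runnable end-to-end sketch beyond the docstring example, the snippet below is modelled on test_bucket_batch_pad_to_bucket_boundary added near the bottom of this diff: it buckets variable-length rows from a generator and pads each bucket to its boundary. The printed shapes are what that test's expected output implies, quoted as a guide rather than a guarantee.

```python
import numpy as np
import mindspore.dataset as ds

# Rows [0], [0, 1], ..., [0, ..., 8], the same generator used by the new tests.
def generate_sequential(n):
    for i in range(n):
        yield (np.array(list(range(i + 1))),)

dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"])
dataset = dataset.bucket_batch_by_length(column_names=["col1"],
                                         bucket_boundaries=[3, 6, 15],
                                         bucket_batch_sizes=[2, 3, 4, 1],
                                         element_length_function=len,
                                         pad_info={"col1": ([None], 0)},
                                         pad_to_bucket_boundary=True)

for batch in dataset.create_dict_iterator():
    print(batch["col1"].shape)  # (2, 2), then (3, 5), then (4, 14)
```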
| @check_batch | |||
| def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, | |||
| input_columns=None, pad_info=None): | |||
| @@ -1400,6 +1470,47 @@ class DatasetOp(Dataset): | |||
| # No need for __init__ since it is the same as the super's init | |||
| class BucketBatchByLengthDataset(DatasetOp): | |||
| """ | |||
| The result of applying BucketBatchByLength operator to the input dataset. | |||
| """ | |||
| def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, | |||
| element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder): | |||
| super().__init__() | |||
| self.column_names = column_names | |||
| self.bucket_boundaries = bucket_boundaries | |||
| self.bucket_batch_sizes = bucket_batch_sizes | |||
| self.element_length_function = element_length_function | |||
| self.pad_info = pad_info | |||
| self.pad_to_bucket_boundary = pad_to_bucket_boundary | |||
| self.drop_remainder = drop_remainder | |||
| self.input.append(input_dataset) | |||
| input_dataset.output.append(self) | |||
| self._input_indexs = input_dataset.input_indexs | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["length_dependent_columns"] = self.column_names | |||
| args["bucket_boundaries"] = self.bucket_boundaries | |||
| args["bucket_batch_sizes"] = self.bucket_batch_sizes | |||
| args["element_length_function"] = self.element_length_function | |||
| args["pad_info"] = self.pad_info | |||
| args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary | |||
| args["drop_remainder"] = self.drop_remainder | |||
| return args | |||
| def get_dataset_size(self): | |||
| """ | |||
| Get the number of batches in an epoch. | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| return None | |||
| class BatchDataset(DatasetOp): | |||
| """ | |||
| @@ -132,6 +132,8 @@ class Iterator: | |||
| op_type = OpName.MINDRECORD | |||
| elif isinstance(dataset, de.BatchDataset): | |||
| op_type = OpName.BATCH | |||
| elif isinstance(dataset, de.BucketBatchByLengthDataset): | |||
| op_type = OpName.BUCKETBATCH | |||
| elif isinstance(dataset, de.SyncWaitDataset): | |||
| op_type = OpName.BARRIER | |||
| elif isinstance(dataset, de.ZipDataset): | |||
| @@ -752,6 +752,67 @@ def check_pad_info(key, val): | |||
| check_type(val[1], "pad_value", (int, float, str, bytes)) | |||
| def check_bucket_batch_by_length(method): | |||
| """check the input arguments of bucket_batch_by_length.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] | |||
| check_param_type(nreq_param_list, param_dict, list) | |||
| # check column_names: must be list of string. | |||
| column_names = param_dict.get("column_names") | |||
| all_string = all(isinstance(item, str) for item in column_names) | |||
| if not all_string: | |||
| raise TypeError("column_names should be a list of str.") | |||
| element_length_function = param_dict.get("element_length_function") | |||
| if element_length_function is None and len(column_names) != 1: | |||
| raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") | |||
| # check bucket_boundaries: must be a list of non-negative int, strictly increasing | |||
| bucket_boundaries = param_dict.get('bucket_boundaries') | |||
| if not bucket_boundaries: | |||
| raise ValueError("bucket_boundaries cannot be empty.") | |||
| all_int = all(isinstance(item, int) for item in bucket_boundaries) | |||
| if not all_int: | |||
| raise TypeError("bucket_boundaries should be a list of int.") | |||
| all_non_negative = all(item >= 0 for item in bucket_boundaries) | |||
| if not all_non_negative: | |||
| raise ValueError("bucket_boundaries cannot contain any negative numbers.") | |||
| for i in range(len(bucket_boundaries) - 1): | |||
| if not bucket_boundaries[i + 1] > bucket_boundaries[i]: | |||
| raise ValueError("bucket_boundaries should be strictly increasing.") | |||
| # check bucket_batch_sizes: must be a list of non-negative int | |||
| bucket_batch_sizes = param_dict.get('bucket_batch_sizes') | |||
| if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: | |||
| raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") | |||
| all_int = all(isinstance(item, int) for item in bucket_batch_sizes) | |||
| if not all_int: | |||
| raise TypeError("bucket_batch_sizes should be a list of int.") | |||
| all_non_negative = all(item >= 0 for item in bucket_batch_sizes) | |||
| if not all_non_negative: | |||
| raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") | |||
| if param_dict.get('pad_info') is not None: | |||
| check_type(param_dict["pad_info"], "pad_info", dict) | |||
| for k, v in param_dict.get('pad_info').items(): | |||
| check_pad_info(k, v) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_batch(method): | |||
| """check the input arguments of batch.""" | |||
| @@ -0,0 +1,373 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| # generates 1 column [0], [0, 1], ..., [0, ..., n-1] | |||
| def generate_sequential(n): | |||
| for i in range(n): | |||
| yield (np.array([j for j in range(i + 1)]),) | |||
| # generates 1 column [0], [1], ..., [n-1] | |||
| def generate_sequential_same_shape(n): | |||
| for i in range(n): | |||
| yield (np.array([i]),) | |||
| # combines generate_sequential_same_shape and generate_sequential | |||
| def generate_2_columns(n): | |||
| for i in range(n): | |||
| yield (np.array([i]), np.array([j for j in range(i + 1)])) | |||
| def test_bucket_batch_invalid_input(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||
| column_names = ["col1"] | |||
| invalid_column_names = [1, 2, 3] | |||
| bucket_boundaries = [1, 2, 3] | |||
| empty_bucket_boundaries = [] | |||
| invalid_bucket_boundaries = ["1", "2", "3"] | |||
| negative_bucket_boundaries = [1, 2, -3] | |||
| decreasing_bucket_boundaries = [3, 2, 1] | |||
| non_increasing_bucket_boundaries = [1, 2, 2] | |||
| bucket_batch_sizes = [1, 1, 1, 1] | |||
| invalid_bucket_batch_sizes = ["1", "2", "3", "4"] | |||
| negative_bucket_batch_sizes = [1, 2, 3, -4] | |||
| with pytest.raises(TypeError) as info: | |||
| _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) | |||
| assert "column_names should be a list of str" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) | |||
| assert "bucket_boundaries cannot be empty" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) | |||
| assert "bucket_boundaries should be a list of int" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) | |||
| assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) | |||
| assert "bucket_boundaries should be strictly increasing" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, non_increasing_bucket_boundaries, bucket_batch_sizes) | |||
| assert "bucket_boundaries should be strictly increasing" in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, invalid_bucket_batch_sizes) | |||
| assert "bucket_batch_sizes should be a list of int" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) | |||
| assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) | |||
| assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) | |||
| def test_bucket_batch_multi_bucket_no_padding(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2, 3] | |||
| bucket_batch_sizes = [3, 3, 2, 2] | |||
| element_length_function = (lambda x: x[0] % 4) | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function) | |||
| expected_output = [[[2], [6]], | |||
| [[3], [7]], | |||
| [[0], [4], [8]], | |||
| [[1], [5], [9]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_multi_bucket_with_padding(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2, 3] | |||
| bucket_batch_sizes = [2, 3, 3, 2] | |||
| element_length_function = (lambda x: len(x) % 4) | |||
| pad_info = {"col1": ([10], 0)} | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info) | |||
| expected_output = [[[0, 1, 2, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0]], | |||
| [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]], | |||
| [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0]], | |||
| [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_single_bucket_no_padding(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2, 3] | |||
| bucket_batch_sizes = [1, 1, 5, 1] | |||
| element_length_function = (lambda x: 2) | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function) | |||
| expected_output = [[[0], [1], [2], [3], [4]], | |||
| [[5], [6], [7], [8], [9]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_single_bucket_with_padding(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2, 3] | |||
| bucket_batch_sizes = [1, 1, 1, 3] | |||
| element_length_function = (lambda x: 7) | |||
| pad_info = {"col1": ([12], 0)} | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info) | |||
| expected_output = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]], | |||
| [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0]], | |||
| [[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_pad_to_bucket_boundary(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [3, 6, 15] | |||
| bucket_batch_sizes = [2, 3, 4, 1] | |||
| element_length_function = len | |||
| pad_info = {"col1": ([None], 0)} | |||
| pad_to_bucket_boundary = True | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info, pad_to_bucket_boundary) | |||
| expected_output = [[[0, 0], | |||
| [0, 1]], | |||
| [[0, 1, 2, 0, 0], | |||
| [0, 1, 2, 3, 0], | |||
| [0, 1, 2, 3, 4]], | |||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_default_pad(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(15)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [5, 8, 17] | |||
| bucket_batch_sizes = [2, 1, 4, 1] | |||
| element_length_function = len | |||
| pad_info = {"col1": ([None], 0)} | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info) | |||
| expected_output = [[[0, 0], | |||
| [0, 1]], | |||
| [[0, 1, 2, 0], | |||
| [0, 1, 2, 3]], | |||
| [[0, 1, 2, 3, 4]], | |||
| [[0, 1, 2, 3, 4, 5]], | |||
| [[0, 1, 2, 3, 4, 5, 6]], | |||
| [[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], | |||
| [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_drop_remainder(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(27)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [1, 2] | |||
| bucket_batch_sizes = [2, 3, 5] | |||
| element_length_function = (lambda x: x[0] % 3) | |||
| pad_info = None | |||
| pad_to_bucket_boundary = False | |||
| drop_remainder = True | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info, pad_to_bucket_boundary, drop_remainder) | |||
| expected_output = [[[0], [3]], | |||
| [[1], [4], [7]], | |||
| [[6], [9]], | |||
| [[2], [5], [8], [11], [14]], | |||
| [[12], [15]], | |||
| [[10], [13], [16]], | |||
| [[18], [21]], | |||
| [[19], [22], [25]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_default_length_function(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||
| column_names = ["col1"] | |||
| bucket_boundaries = [6, 12] | |||
| bucket_batch_sizes = [5, 4, 1] | |||
| element_length_function = None | |||
| pad_info = {} | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info) | |||
| expected_output = [[[0, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 1, 2, 0, 0], | |||
| [0, 1, 2, 3, 0], | |||
| [0, 1, 2, 3, 4]], | |||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| assert output == expected_output | |||
| def test_bucket_batch_multi_column(): | |||
| dataset = ds.GeneratorDataset((lambda: generate_2_columns(10)), ["same_shape", "variable_shape"]) | |||
| column_names = ["same_shape"] | |||
| bucket_boundaries = [6, 12] | |||
| bucket_batch_sizes = [5, 5, 1] | |||
| element_length_function = None | |||
| pad_info = {} | |||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, | |||
| pad_info) | |||
| same_shape_expected_output = [[[0], [1], [2], [3], [4]], | |||
| [[5], [6], [7], [8], [9]]] | |||
| variable_shape_expected_output = [[[0, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 1, 2, 0, 0], | |||
| [0, 1, 2, 3, 0], | |||
| [0, 1, 2, 3, 4]], | |||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] | |||
| same_shape_output = [] | |||
| variable_shape_output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| same_shape_output.append(data["same_shape"].tolist()) | |||
| variable_shape_output.append(data["variable_shape"].tolist()) | |||
| assert same_shape_output == same_shape_expected_output | |||
| assert variable_shape_output == variable_shape_expected_output | |||
| if __name__ == '__main__': | |||
| test_bucket_batch_invalid_input() | |||
| test_bucket_batch_multi_bucket_no_padding() | |||
| test_bucket_batch_multi_bucket_with_padding() | |||
| test_bucket_batch_single_bucket_no_padding() | |||
| test_bucket_batch_single_bucket_with_padding() | |||
| test_bucket_batch_pad_to_bucket_boundary() | |||
| test_bucket_batch_default_pad() | |||
| test_bucket_batch_drop_remainder() | |||
| test_bucket_batch_default_length_function() | |||
| test_bucket_batch_multi_column() | |||