Merge pull request !1984 from Peilin/BucketBatchByLengthOp
| @@ -19,62 +19,65 @@ | |||||
| #include <map> | #include <map> | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| #include "dataset/kernels/py_func_op.h" | |||||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "dataset/core/tensor.h" | #include "dataset/core/tensor.h" | ||||
| #include "dataset/engine/dataset_iterator.h" | #include "dataset/engine/dataset_iterator.h" | ||||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||||
| #include "dataset/engine/datasetops/filter_op.h" | |||||
| #include "dataset/engine/datasetops/source/celeba_op.h" | #include "dataset/engine/datasetops/source/celeba_op.h" | ||||
| #include "dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "dataset/engine/datasetops/source/image_folder_op.h" | |||||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "dataset/engine/datasetops/source/random_data_op.h" | #include "dataset/engine/datasetops/source/random_data_op.h" | ||||
| #include "dataset/engine/datasetops/source/text_file_op.h" | #include "dataset/engine/datasetops/source/text_file_op.h" | ||||
| #include "dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "dataset/engine/datasetops/filter_op.h" | |||||
| #include "dataset/engine/datasetops/source/voc_op.h" | |||||
| #include "dataset/kernels/py_func_op.h" | |||||
| #include "dataset/util/random.h" | |||||
| #include "dataset/util/status.h" | |||||
| #include "mindrecord/include/shard_category.h" | #include "mindrecord/include/shard_category.h" | ||||
| #include "mindrecord/include/shard_distributed_sample.h" | #include "mindrecord/include/shard_distributed_sample.h" | ||||
| #include "mindrecord/include/shard_sample.h" | #include "mindrecord/include/shard_sample.h" | ||||
| #include "mindrecord/include/shard_shuffle.h" | #include "mindrecord/include/shard_shuffle.h" | ||||
| #include "dataset/util/random.h" | |||||
| #include "dataset/util/status.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "pybind11/stl.h" | #include "pybind11/stl.h" | ||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *); | using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *); | ||||
| static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp}, | |||||
| {kShuffle, &DEPipeline::ParseShuffleOp}, | |||||
| {kMindrecord, &DEPipeline::ParseMindRecordOp}, | |||||
| {kMap, &DEPipeline::ParseMapOp}, | |||||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||||
| {kZip, &DEPipeline::ParseZipOp}, | |||||
| {kConcat, &DEPipeline::ParseConcatOp}, | |||||
| {kRename, &DEPipeline::ParseRenameOp}, | |||||
| {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, | |||||
| {kGenerator, &DEPipeline::ParseGeneratorOp}, | |||||
| {kTfReader, &DEPipeline::ParseTFReaderOp}, | |||||
| {kProject, &DEPipeline::ParseProjectOp}, | |||||
| {kTake, &DEPipeline::ParseTakeOp}, | |||||
| {kImageFolder, &DEPipeline::ParseImageFolderOp}, | |||||
| {kMnist, &DEPipeline::ParseMnistOp}, | |||||
| {kManifest, &DEPipeline::ParseManifestOp}, | |||||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||||
| {kCoco, &DEPipeline::ParseCocoOp}, | |||||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||||
| {kRandomData, &DEPipeline::ParseRandomDataOp}, | |||||
| {kTextFile, &DEPipeline::ParseTextFileOp}, | |||||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||||
| {kClue, &DEPipeline::ParseClueOp}}; | |||||
| static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = { | |||||
| {kStorage, &DEPipeline::ParseStorageOp}, | |||||
| {kShuffle, &DEPipeline::ParseShuffleOp}, | |||||
| {kMindrecord, &DEPipeline::ParseMindRecordOp}, | |||||
| {kMap, &DEPipeline::ParseMapOp}, | |||||
| {kFilter, &DEPipeline::ParseFilterOp}, | |||||
| {kBatch, &DEPipeline::ParseBatchOp}, | |||||
| {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, | |||||
| {kBarrier, &DEPipeline::ParseBarrierOp}, | |||||
| {kRepeat, &DEPipeline::ParseRepeatOp}, | |||||
| {kSkip, &DEPipeline::ParseSkipOp}, | |||||
| {kZip, &DEPipeline::ParseZipOp}, | |||||
| {kConcat, &DEPipeline::ParseConcatOp}, | |||||
| {kRename, &DEPipeline::ParseRenameOp}, | |||||
| {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, | |||||
| {kGenerator, &DEPipeline::ParseGeneratorOp}, | |||||
| {kTfReader, &DEPipeline::ParseTFReaderOp}, | |||||
| {kProject, &DEPipeline::ParseProjectOp}, | |||||
| {kTake, &DEPipeline::ParseTakeOp}, | |||||
| {kImageFolder, &DEPipeline::ParseImageFolderOp}, | |||||
| {kMnist, &DEPipeline::ParseMnistOp}, | |||||
| {kManifest, &DEPipeline::ParseManifestOp}, | |||||
| {kVoc, &DEPipeline::ParseVOCOp}, | |||||
| {kCoco, &DEPipeline::ParseCocoOp}, | |||||
| {kCifar10, &DEPipeline::ParseCifar10Op}, | |||||
| {kCifar100, &DEPipeline::ParseCifar100Op}, | |||||
| {kCelebA, &DEPipeline::ParseCelebAOp}, | |||||
| {kRandomData, &DEPipeline::ParseRandomDataOp}, | |||||
| {kTextFile, &DEPipeline::ParseTextFileOp}, | |||||
| {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, | |||||
| {kClue, &DEPipeline::ParseClueOp}}; | |||||
| DEPipeline::DEPipeline() : iterator_(nullptr) { | DEPipeline::DEPipeline() : iterator_(nullptr) { | ||||
| try { | try { | ||||
| @@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | |||||
| std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", | |||||
| "bucket_batch_sizes"}; | |||||
| for (auto name : mandatory_arguments) { | |||||
| if (args[name.c_str()].is_none()) { | |||||
| std::string err_msg = "Error: " + name + " is not set."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>( | |||||
| ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), | |||||
| ToIntVector(args[mandatory_arguments[2].c_str()])); | |||||
| for (auto arg : args) { | |||||
| std::string key = py::str(arg.first); | |||||
| py::handle value = arg.second; | |||||
| if (!value.is_none()) { | |||||
| if (key == "length_dependent_columns") { | |||||
| (void)builder->SetLengthDependentColumns(ToStringVector(value)); | |||||
| } | |||||
| if (key == "bucket_boundaries") { | |||||
| (void)builder->SetBucketBoundaries(ToIntVector(value)); | |||||
| } | |||||
| if (key == "bucket_batch_sizes") { | |||||
| (void)builder->SetBucketBatchSizes(ToIntVector(value)); | |||||
| } | |||||
| if (key == "element_length_function") { | |||||
| (void)builder->SetElementLengthFunction(value.cast<py::function>()); | |||||
| } | |||||
| if (key == "pad_info") { | |||||
| PadInfo pad_info; | |||||
| RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); | |||||
| (void)builder->SetPadInfo(pad_info); | |||||
| } | |||||
| if (key == "pad_to_bucket_boundary") { | |||||
| (void)builder->SetPadToBucketBoundary(ToBool(value)); | |||||
| } | |||||
| if (key == "drop_remainder") { | |||||
| (void)builder->SetDropRemainder(ToBool(value)); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::shared_ptr<BucketBatchByLengthOp> op; | |||||
| RETURN_IF_NOT_OK(builder->Build(&op)); | |||||
| *ptr = op; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { | ||||
| std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>(); | std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>(); | ||||
| // Right now barrier should only take num_rows_per_buffer = 1 | // Right now barrier should only take num_rows_per_buffer = 1 | ||||
| @@ -40,6 +40,7 @@ enum OpName { | |||||
| kShuffle, | kShuffle, | ||||
| kMindrecord, | kMindrecord, | ||||
| kBatch, | kBatch, | ||||
| kBucketBatch, | |||||
| kBarrier, | kBarrier, | ||||
| kCache, | kCache, | ||||
| kRepeat, | kRepeat, | ||||
| @@ -121,6 +122,8 @@ class DEPipeline { | |||||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | ||||
| Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | ||||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | ||||
| @@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) { | |||||
| .value("STORAGE", OpName::kStorage) | .value("STORAGE", OpName::kStorage) | ||||
| .value("SHUFFLE", OpName::kShuffle) | .value("SHUFFLE", OpName::kShuffle) | ||||
| .value("BATCH", OpName::kBatch) | .value("BATCH", OpName::kBatch) | ||||
| .value("BUCKETBATCH", OpName::kBucketBatch) | |||||
| .value("BARRIER", OpName::kBarrier) | .value("BARRIER", OpName::kBarrier) | ||||
| .value("MINDRECORD", OpName::kMindrecord) | .value("MINDRECORD", OpName::kMindrecord) | ||||
| .value("CACHE", OpName::kCache) | .value("CACHE", OpName::kCache) | ||||
| @@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT | |||||
| pipeline_op.cc | pipeline_op.cc | ||||
| barrier_op.cc | barrier_op.cc | ||||
| batch_op.cc | batch_op.cc | ||||
| bucket_batch_by_length_op.cc | |||||
| device_queue_op.cc | device_queue_op.cc | ||||
| map_op.cc | map_op.cc | ||||
| project_op.cc | project_op.cc | ||||
| @@ -193,6 +193,22 @@ class BatchOp : public ParallelOp { | |||||
| // @return Name of the current Op | // @return Name of the current Op | ||||
| std::string Name() const override { return "BatchOp"; } | std::string Name() const override { return "BatchOp"; } | ||||
| // batch the rows in src table then put it to dest table | |||||
| // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching | |||||
| // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | |||||
| // @param int32_t size - batch_size | |||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||||
| // @return Status - The error code return | |||||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | |||||
| dsize_t batch_size); | |||||
| // @param table | |||||
| // @param const PadInfo &pad_info pad info | |||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||||
| // @return Status - The error code return | |||||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | |||||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | |||||
| private: | private: | ||||
| // Worker thread for doing the memcpy of batch | // Worker thread for doing the memcpy of batch | ||||
| // @param int32_t param workerId | // @param int32_t param workerId | ||||
| @@ -203,16 +219,6 @@ class BatchOp : public ParallelOp { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | ||||
| std::unique_ptr<DataBuffer> *db); | std::unique_ptr<DataBuffer> *db); | ||||
| // batch the rows in src table then put it to dest table | |||||
| // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching | |||||
| // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | |||||
| // @param int32_t size - batch_size | |||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||||
| // @return Status - The error code return | |||||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | |||||
| dsize_t batch_size); | |||||
| // Function that calls pyfunc to perform map on batch | // Function that calls pyfunc to perform map on batch | ||||
| // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor | // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| @@ -229,13 +235,6 @@ class BatchOp : public ParallelOp { | |||||
| std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | ||||
| std::vector<std::vector<dsize_t>> *pad_shapes); | std::vector<std::vector<dsize_t>> *pad_shapes); | ||||
| // @param table | |||||
| // @param const PadInfo &pad_info pad info | |||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||||
| // @return Status - The error code return | |||||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | |||||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | |||||
| // the number of thread pulling from the mOutConnector of the Op below | // the number of thread pulling from the mOutConnector of the Op below | ||||
| // @return int32_t, 1 | // @return int32_t, 1 | ||||
| int32_t num_consumers() const override { return 1; } | int32_t num_consumers() const override { return 1; } | ||||
| @@ -0,0 +1,242 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "dataset/engine/datasetops/bucket_batch_by_length_op.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "pybind11/numpy.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| #include "dataset/core/pybind_support.h" | |||||
| #include "dataset/core/config_manager.h" | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/core/tensor_shape.h" | |||||
| #include "dataset/engine/dataset_iterator.h" | |||||
| #include "dataset/engine/datasetops/parallel_op.h" | |||||
| #include "dataset/engine/opt/pass.h" | |||||
| #include "dataset/util/status.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns, | |||||
| std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes) | |||||
| : builder_length_dependent_columns_(length_dependent_columns), | |||||
| builder_bucket_boundaries_(bucket_boundaries), | |||||
| builder_bucket_batch_sizes_(bucket_batch_sizes), | |||||
| builder_pad_info_({}), | |||||
| builder_pad_to_bucket_boundary_(false), | |||||
| builder_drop_remainder_(false) { | |||||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||||
| } | |||||
| Status BucketBatchByLengthOp::Builder::SanityCheck() { | |||||
| std::string error_message; | |||||
| if (builder_length_dependent_columns_.empty()) { | |||||
| error_message += "At least 1 column must be specified for element length calculation.\n"; | |||||
| } | |||||
| if (builder_bucket_boundaries_.empty()) { | |||||
| error_message += "At least 1 bucket boundary must be specified.\n"; | |||||
| } | |||||
| if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { | |||||
| error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) { | |||||
| RETURN_IF_NOT_OK(SanityCheck()); | |||||
| // insert 0 for the first bucket | |||||
| builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); | |||||
| *new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>( | |||||
| builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, | |||||
| builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, | |||||
| builder_op_connector_size_); | |||||
| return Status::OK(); | |||||
| } | |||||
| BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, | |||||
| std::vector<int32_t> bucket_boundaries, | |||||
| std::vector<int32_t> bucket_batch_sizes, | |||||
| py::function element_length_function, PadInfo pad_info, | |||||
| bool pad_to_bucket_boundary, bool drop_remainder, | |||||
| int32_t op_connector_size) | |||||
| : PipelineOp(op_connector_size), | |||||
| length_dependent_columns_(length_dependent_columns), | |||||
| bucket_boundaries_(bucket_boundaries), | |||||
| bucket_batch_sizes_(bucket_batch_sizes), | |||||
| element_length_function_(element_length_function), | |||||
| pad_info_(pad_info), | |||||
| pad_to_bucket_boundary_(pad_to_bucket_boundary), | |||||
| drop_remainder_(drop_remainder), | |||||
| batch_count_(0) { | |||||
| for (int i = 0; i < bucket_batch_sizes_.size(); i++) { | |||||
| buckets_.push_back(std::make_unique<TensorQTable>()); | |||||
| } | |||||
| } | |||||
| Status BucketBatchByLengthOp::EoeReceived(int32_t) { | |||||
| state_ = OpState::kDeOpIdle; | |||||
| return Status::OK(); | |||||
| } | |||||
| void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } | |||||
| Status BucketBatchByLengthOp::operator()() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| TensorRow current_row; | |||||
| child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); | |||||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||||
| RETURN_IF_NOT_OK(AssignColMapFromChild()); | |||||
| while (!child_iterator_->eof_handled()) { | |||||
| while (!current_row.empty()) { | |||||
| int32_t element_length; | |||||
| RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); | |||||
| int bucket_index = bucket_boundaries_.size() - 1; | |||||
| while (element_length < bucket_boundaries_[bucket_index]) { | |||||
| bucket_index--; | |||||
| } | |||||
| buckets_[bucket_index]->push_back(current_row); | |||||
| if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { | |||||
| RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); | |||||
| } | |||||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||||
| } | |||||
| // got EOE, do what we need to do with remainders in each bucket | |||||
| if (!drop_remainder_) { | |||||
| for (int i = 0; i < bucket_boundaries_.size(); i++) { | |||||
| if (!buckets_[i]->empty()) { | |||||
| RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); | |||||
| } | |||||
| } | |||||
| } | |||||
// need to send EOE manually since we set state to idle in EoeReceived()
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { | |||||
| // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of | |||||
| // the single column specified in length_dependent_columns_ | |||||
| if (element_length_function_) { | |||||
| py::gil_scoped_acquire gil_acquire; | |||||
| if (Py_IsInitialized() == 0) { | |||||
| return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); | |||||
| } | |||||
| try { | |||||
| size_t number_of_arguments = length_dependent_columns_.size(); | |||||
| py::tuple input_arguments(number_of_arguments); | |||||
| for (size_t i = 0; i < number_of_arguments; i++) { | |||||
| py::array argument_value; | |||||
| int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; | |||||
| RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); | |||||
| input_arguments[i] = argument_value; | |||||
| } | |||||
| py::object length = element_length_function_(*input_arguments); | |||||
| *out_element_length = length.cast<int32_t>(); | |||||
| if (*out_element_length < 0) { | |||||
| return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer."); | |||||
| } | |||||
| } catch (const py::error_already_set &e) { | |||||
| return Status(StatusCode::kPyFuncException, e.what()); | |||||
| } catch (const py::cast_error &e) { | |||||
| return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t."); | |||||
| } | |||||
| } else { | |||||
| *out_element_length = element[0]->shape()[0]; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { | |||||
| std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index]; | |||||
| PadInfo pad_info_copy = pad_info_; | |||||
| if (pad_to_bucket_boundary_) { | |||||
| for (auto &pair : pad_info_copy) { | |||||
| std::vector<dsize_t> pad_shape = pair.second.first.AsVector(); | |||||
| for (size_t i = 0; i < pad_shape.size(); i++) { | |||||
| if (pad_shape[i] == TensorShape::kDimUnknown) { | |||||
| if (bucket_index + 1 >= bucket_boundaries_.size()) { | |||||
| std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; | |||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); | |||||
| } | |||||
| pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; | |||||
| } | |||||
| } | |||||
| pair.second.first = TensorShape(pad_shape); | |||||
| } | |||||
| } | |||||
| // PadColumns will change the data in bucket | |||||
| RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); | |||||
| std::unique_ptr<TensorQTable> batched_bucket = std::make_unique<TensorQTable>(); | |||||
| RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); | |||||
| (*bucket)->clear(); | |||||
| std::unique_ptr<DataBuffer> batched_buffer = std::make_unique<DataBuffer>(batch_count_, DataBuffer::kDeBFlagNone); | |||||
| batched_buffer->set_tensor_table(std::move(batched_bucket)); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); | |||||
| batch_count_++; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status BucketBatchByLengthOp::Reset() { | |||||
| batch_count_ = 0; | |||||
| for (int i = 0; i < buckets_.size(); i++) { | |||||
| buckets_[i] = std::make_unique<TensorQTable>(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
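As an illustration of the logic in operator(), ObtainElementLength() and PadAndBatchBucket() above, here is a minimal pure-Python sketch of how an element is assigned to a bucket and how a None pad dimension is resolved when pad_to_bucket_boundary is set. The helper names are hypothetical and the sketch is not part of the patch.

import bisect

def assign_bucket(element_length, bucket_boundaries):
    # The op prepends 0 to the user-supplied boundaries, so [5, 10] yields
    # buckets [0, 5), [5, 10) and [10, inf).
    boundaries = [0] + list(bucket_boundaries)
    # Index of the last lower boundary that is <= element_length.
    return bisect.bisect_right(boundaries, element_length) - 1

def resolve_pad_shape(pad_shape, bucket_index, bucket_boundaries):
    # With pad_to_bucket_boundary=True, every unknown (None) dimension is
    # padded to the bucket's upper boundary minus 1; the last bucket has no
    # upper boundary, so an element landing there is an error.
    if bucket_index >= len(bucket_boundaries):
        raise ValueError("Requested to pad to bucket boundary, but the element falls in the last bucket.")
    upper_boundary = bucket_boundaries[bucket_index]
    return [upper_boundary - 1 if dim is None else dim for dim in pad_shape]

assert assign_bucket(4, [5, 10]) == 0
assert assign_bucket(5, [5, 10]) == 1
assert assign_bucket(12, [5, 10]) == 2
assert resolve_pad_shape([2, None], 0, [5, 10]) == [2, 4]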
| @@ -0,0 +1,153 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||||
| #define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <queue> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "dataset/core/config_manager.h" | |||||
| #include "dataset/core/tensor.h" | |||||
| #include "dataset/engine/dataset_iterator.h" | |||||
| #include "dataset/engine/datasetops/batch_op.h" | |||||
| #include "dataset/engine/datasetops/pipeline_op.h" | |||||
| #include "dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| class DataBuffer; | |||||
| class BucketBatchByLengthOp : public PipelineOp { | |||||
| public: | |||||
| class Builder { | |||||
| public: | |||||
| Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries, | |||||
| std::vector<int32_t> bucket_batch_sizes); | |||||
| ~Builder() = default; | |||||
| Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) { | |||||
| builder_length_dependent_columns_ = length_dependent_columns; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) { | |||||
| builder_bucket_boundaries_ = bucket_boundaries; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) { | |||||
| builder_bucket_batch_sizes_ = bucket_batch_sizes; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetElementLengthFunction(py::function element_length_function) { | |||||
| builder_element_length_function_ = element_length_function; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetPadInfo(PadInfo pad_info) { | |||||
| builder_pad_info_ = pad_info; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { | |||||
| builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetDropRemainder(bool drop_remainder) { | |||||
| builder_drop_remainder_ = drop_remainder; | |||||
| return *this; | |||||
| } | |||||
| Builder &SetOpConnectorSize(int32_t op_connector_size) { | |||||
| builder_op_connector_size_ = op_connector_size; | |||||
| return *this; | |||||
| } | |||||
| Status Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op); | |||||
| private: | |||||
| Status SanityCheck(); | |||||
| std::vector<std::string> builder_length_dependent_columns_; | |||||
| std::vector<int32_t> builder_bucket_boundaries_; | |||||
| std::vector<int32_t> builder_bucket_batch_sizes_; | |||||
| py::function builder_element_length_function_; | |||||
| PadInfo builder_pad_info_; | |||||
| bool builder_pad_to_bucket_boundary_; | |||||
| bool builder_drop_remainder_; | |||||
| int32_t builder_op_connector_size_; | |||||
| }; | |||||
| BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries, | |||||
| std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, | |||||
| bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); | |||||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | |||||
| // @param int32_t workerId | |||||
| // @return Status - The error code returned | |||||
| Status EoeReceived(int32_t) override; | |||||
| // A print method typically used for debugging | |||||
| // @param out - The output stream to write output to | |||||
| // @param show_all - A bool to control if you want to show all info or just a summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| // << Stream output operator overload | |||||
| // @notes This allows you to write the debug print info using stream operators | |||||
| // @param out - reference to the output stream being overloaded | |||||
// @param bo - reference to the BucketBatchByLengthOp to display
| // @return - the output stream must be returned | |||||
| friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { | |||||
| bo.Print(out, false); | |||||
| return out; | |||||
| } | |||||
| // Main loop of batch | |||||
| // @return Status - The error code returned | |||||
| Status operator()() override; | |||||
| // Function that is called by ResetOp at the end of every epoch | |||||
| // @return Status - The error code returned | |||||
| Status Reset() override; | |||||
| private: | |||||
| Status ObtainElementLength(int32_t *out_element_length, TensorRow element); | |||||
| Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); | |||||
| std::vector<std::string> length_dependent_columns_; | |||||
| std::vector<int32_t> bucket_boundaries_; | |||||
| std::vector<int32_t> bucket_batch_sizes_; | |||||
| py::function element_length_function_; | |||||
| PadInfo pad_info_; | |||||
| bool pad_to_bucket_boundary_; | |||||
| bool drop_remainder_; | |||||
| int32_t batch_count_; | |||||
| std::unique_ptr<ChildIterator> child_iterator_; | |||||
| std::vector<std::unique_ptr<TensorQTable>> buckets_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ | |||||
| @@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator | |||||
| from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ | ||||
| check_rename, check_numpyslicesdataset, \ | check_rename, check_numpyslicesdataset, \ | ||||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | ||||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||||
| check_split, check_cluedataset | |||||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset,\ | |||||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat,\ | |||||
| check_split, check_bucket_batch_by_length, check_cluedataset | |||||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | ||||
| try: | try: | ||||
| @@ -165,6 +165,76 @@ class Dataset: | |||||
| args["num_parallel_workers"] = self.num_parallel_workers | args["num_parallel_workers"] = self.num_parallel_workers | ||||
| return args | return args | ||||
| @check_bucket_batch_by_length | |||||
| def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes, | |||||
| element_length_function=None, pad_info=None, | |||||
| pad_to_bucket_boundary=False, drop_remainder=False): | |||||
| """ | |||||
| Bucket elements according to their lengths, and pad and batch the buckets when | |||||
| they are full. | |||||
A length function is called on each row in the dataset. The row is then
bucketed based on its length and bucket_boundaries. When a bucket reaches its
corresponding size specified in bucket_batch_sizes, the entire bucket will be
padded according to pad_info, and then batched. Each batch will be full,
except possibly the last batch for each bucket.
| Args: | |||||
| column_names (list of string): Columns passed to element_length_function. | |||||
bucket_boundaries (list of int): A list consisting of the upper boundaries
of the buckets. Must be strictly increasing. If there are n boundaries,
n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one
bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
0 <= i < n-1, and one bucket for [bucket_boundaries[n-1], inf).
bucket_batch_sizes (list of int): A list consisting of the batch sizes for
each bucket. Must contain len(bucket_boundaries)+1 elements.
| element_length_function (Callable, optional): A function that takes in | |||||
| len(column_names) arguments and returns an int. If no value is | |||||
| provided, then len(column_names) must be 1, and the size of the first | |||||
| dimension of that column will be taken as the length (default=None). | |||||
| pad_info (dict, optional): Represents how to batch each column. The key | |||||
| corresponds to the column name, the value must be a tuple of 2 elements. | |||||
| The first element corresponds to the shape to pad to, and the second | |||||
| element corresponds to the value to pad with. If a column is not | |||||
| specified, then that column will be padded to the longest in the current | |||||
| batch, and 0 will be used as the padding value. Any None dimensions will | |||||
be padded to the longest in the current batch, unless
| pad_to_bucket_boundary is True. If no padding is wanted, set pad_info | |||||
| to None (default=None). | |||||
| pad_to_bucket_boundary (bool, optional): If True, will pad each None | |||||
| dimension in pad_info to the bucket_boundary minus 1. If there are any | |||||
| elements that fall into the last bucket, an error will occur | |||||
| (default=False). | |||||
| drop_remainder (bool, optional): If True, will drop the last batch for each | |||||
| bucket if it is not a full batch (default=False). | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> # data is an instance of Dataset object. | |||||
| >>> | |||||
>>> # Create a dataset where rows are bucketed by length, and each bucket
>>> # is padded and batched when it reaches its batch size.
>>> column_names = ["col1", "col2"]
>>> bucket_boundaries = [5, 10]
| >>> bucket_batch_sizes = [5, 1, 1] | |||||
| >>> element_length_function = (lambda col1, col2: max(len(col1), len(col2))) | |||||
| >>> | |||||
| >>> # will pad col1 to shape [2, bucket_boundaries[i]] where i is the | |||||
| >>> # index of the bucket that is currently being batched. | |||||
| >>> # will pad col2 to a shape where each dimension is the longest in all | |||||
| >>> # the elements currently being batched. | |||||
| >>> pad_info = {"col1", ([2, None], -1)} | |||||
| >>> pad_to_bucket_boundary = True | |||||
| >>> | |||||
| >>> data = data.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| >>> bucket_batch_sizes, | |||||
>>> element_length_function, pad_info,
| >>> pad_to_bucket_boundary) | |||||
| """ | |||||
| return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes, | |||||
| element_length_function, pad_info, | |||||
| pad_to_bucket_boundary, drop_remainder) | |||||
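A compact, runnable usage example, modeled on the tests added later in this change (the generator and column name are illustrative):

import numpy as np
import mindspore.dataset as ds

def generate_variable_length(n):
    # rows of increasing length: [0], [0, 1], ..., [0, ..., n-1]
    for i in range(n):
        yield (np.array(list(range(i + 1))),)

data = ds.GeneratorDataset((lambda: generate_variable_length(9)), ["col1"])
data = data.bucket_batch_by_length(["col1"], bucket_boundaries=[3, 6],
                                   bucket_batch_sizes=[2, 3, 4],
                                   element_length_function=len,
                                   pad_info={"col1": ([None], 0)})

for batch in data.create_dict_iterator():
    print(batch["col1"].shape)  # expected: (2, 2), then (3, 5), then (4, 9)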
| @check_batch | @check_batch | ||||
| def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, | def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None, | ||||
| input_columns=None, pad_info=None): | input_columns=None, pad_info=None): | ||||
| @@ -1400,6 +1470,47 @@ class DatasetOp(Dataset): | |||||
| # No need for __init__ since it is the same as the super's init | # No need for __init__ since it is the same as the super's init | ||||
| class BucketBatchByLengthDataset(DatasetOp): | |||||
| """ | |||||
| The result of applying BucketBatchByLength operator to the input dataset. | |||||
| """ | |||||
| def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes, | |||||
| element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder): | |||||
| super().__init__() | |||||
| self.column_names = column_names | |||||
| self.bucket_boundaries = bucket_boundaries | |||||
| self.bucket_batch_sizes = bucket_batch_sizes | |||||
| self.element_length_function = element_length_function | |||||
| self.pad_info = pad_info | |||||
| self.pad_to_bucket_boundary = pad_to_bucket_boundary | |||||
| self.drop_remainder = drop_remainder | |||||
| self.input.append(input_dataset) | |||||
| input_dataset.output.append(self) | |||||
| self._input_indexs = input_dataset.input_indexs | |||||
| def get_args(self): | |||||
| args = super().get_args() | |||||
| args["length_dependent_columns"] = self.column_names | |||||
| args["bucket_boundaries"] = self.bucket_boundaries | |||||
| args["bucket_batch_sizes"] = self.bucket_batch_sizes | |||||
| args["element_length_function"] = self.element_length_function | |||||
| args["pad_info"] = self.pad_info | |||||
| args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary | |||||
| args["drop_remainder"] = self.drop_remainder | |||||
| return args | |||||
| def get_dataset_size(self): | |||||
| """ | |||||
Get the number of batches in an epoch.
Return:
None; the number of batches depends on the length of each element and cannot be determined in advance.
| """ | |||||
| return None | |||||
| class BatchDataset(DatasetOp): | class BatchDataset(DatasetOp): | ||||
| """ | """ | ||||
| @@ -132,6 +132,8 @@ class Iterator: | |||||
| op_type = OpName.MINDRECORD | op_type = OpName.MINDRECORD | ||||
| elif isinstance(dataset, de.BatchDataset): | elif isinstance(dataset, de.BatchDataset): | ||||
| op_type = OpName.BATCH | op_type = OpName.BATCH | ||||
| elif isinstance(dataset, de.BucketBatchByLengthDataset): | |||||
| op_type = OpName.BUCKETBATCH | |||||
| elif isinstance(dataset, de.SyncWaitDataset): | elif isinstance(dataset, de.SyncWaitDataset): | ||||
| op_type = OpName.BARRIER | op_type = OpName.BARRIER | ||||
| elif isinstance(dataset, de.ZipDataset): | elif isinstance(dataset, de.ZipDataset): | ||||
| @@ -752,6 +752,67 @@ def check_pad_info(key, val): | |||||
| check_type(val[1], "pad_value", (int, float, str, bytes)) | check_type(val[1], "pad_value", (int, float, str, bytes)) | ||||
| def check_bucket_batch_by_length(method): | |||||
| """check the input arguments of bucket_batch_by_length.""" | |||||
| @wraps(method) | |||||
| def new_method(*args, **kwargs): | |||||
| param_dict = make_param_dict(method, args, kwargs) | |||||
| nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] | |||||
| check_param_type(nreq_param_list, param_dict, list) | |||||
| # check column_names: must be list of string. | |||||
| column_names = param_dict.get("column_names") | |||||
| all_string = all(isinstance(item, str) for item in column_names) | |||||
| if not all_string: | |||||
| raise TypeError("column_names should be a list of str.") | |||||
| element_length_function = param_dict.get("element_length_function") | |||||
| if element_length_function is None and len(column_names) != 1: | |||||
| raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") | |||||
# check bucket_boundaries: must be a list of non-negative int, strictly increasing
| bucket_boundaries = param_dict.get('bucket_boundaries') | |||||
| if not bucket_boundaries: | |||||
| raise ValueError("bucket_boundaries cannot be empty.") | |||||
| all_int = all(isinstance(item, int) for item in bucket_boundaries) | |||||
| if not all_int: | |||||
| raise TypeError("bucket_boundaries should be a list of int.") | |||||
| all_non_negative = all(item >= 0 for item in bucket_boundaries) | |||||
| if not all_non_negative: | |||||
| raise ValueError("bucket_boundaries cannot contain any negative numbers.") | |||||
| for i in range(len(bucket_boundaries) - 1): | |||||
| if not bucket_boundaries[i + 1] > bucket_boundaries[i]: | |||||
| raise ValueError("bucket_boundaries should be strictly increasing.") | |||||
# check bucket_batch_sizes: must be a list of non-negative int, one element longer than bucket_boundaries
| bucket_batch_sizes = param_dict.get('bucket_batch_sizes') | |||||
| if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: | |||||
| raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") | |||||
| all_int = all(isinstance(item, int) for item in bucket_batch_sizes) | |||||
| if not all_int: | |||||
| raise TypeError("bucket_batch_sizes should be a list of int.") | |||||
| all_non_negative = all(item >= 0 for item in bucket_batch_sizes) | |||||
| if not all_non_negative: | |||||
| raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") | |||||
| if param_dict.get('pad_info') is not None: | |||||
| check_type(param_dict["pad_info"], "pad_info", dict) | |||||
| for k, v in param_dict.get('pad_info').items(): | |||||
| check_pad_info(k, v) | |||||
| return method(*args, **kwargs) | |||||
| return new_method | |||||
| def check_batch(method): | def check_batch(method): | ||||
| """check the input arguments of batch.""" | """check the input arguments of batch.""" | ||||
| @@ -0,0 +1,373 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import pytest | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| # generates 1 column [0], [0, 1], ..., [0, ..., n-1] | |||||
| def generate_sequential(n): | |||||
| for i in range(n): | |||||
| yield (np.array([j for j in range(i + 1)]),) | |||||
| # generates 1 column [0], [1], ..., [n-1] | |||||
| def generate_sequential_same_shape(n): | |||||
| for i in range(n): | |||||
| yield (np.array([i]),) | |||||
| # combines generate_sequential_same_shape and generate_sequential | |||||
| def generate_2_columns(n): | |||||
| for i in range(n): | |||||
| yield (np.array([i]), np.array([j for j in range(i + 1)])) | |||||
| def test_bucket_batch_invalid_input(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| invalid_column_names = [1, 2, 3] | |||||
| bucket_boundaries = [1, 2, 3] | |||||
| empty_bucket_boundaries = [] | |||||
| invalid_bucket_boundaries = ["1", "2", "3"] | |||||
| negative_bucket_boundaries = [1, 2, -3] | |||||
| decreasing_bucket_boundaries = [3, 2, 1] | |||||
| non_increasing_bucket_boundaries = [1, 2, 2] | |||||
| bucket_batch_sizes = [1, 1, 1, 1] | |||||
| invalid_bucket_batch_sizes = ["1", "2", "3", "4"] | |||||
| negative_bucket_batch_sizes = [1, 2, 3, -4] | |||||
| with pytest.raises(TypeError) as info: | |||||
| _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) | |||||
| assert "column_names should be a list of str" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries cannot be empty" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries should be a list of int" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries should be strictly increasing" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, non_increasing_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries should be strictly increasing" in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, invalid_bucket_batch_sizes) | |||||
| assert "bucket_batch_sizes should be a list of int" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) | |||||
| assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) | |||||
| assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) | |||||
| def test_bucket_batch_multi_bucket_no_padding(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [1, 2, 3] | |||||
| bucket_batch_sizes = [3, 3, 2, 2] | |||||
| element_length_function = (lambda x: x[0] % 4) | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function) | |||||
| expected_output = [[[2], [6]], | |||||
| [[3], [7]], | |||||
| [[0], [4], [8]], | |||||
| [[1], [5], [9]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_multi_bucket_with_padding(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [1, 2, 3] | |||||
| bucket_batch_sizes = [2, 3, 3, 2] | |||||
| element_length_function = (lambda x: len(x) % 4) | |||||
| pad_info = {"col1": ([10], 0)} | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info) | |||||
| expected_output = [[[0, 1, 2, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0]], | |||||
| [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0]], | |||||
| [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0]], | |||||
| [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_single_bucket_no_padding(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [1, 2, 3] | |||||
| bucket_batch_sizes = [1, 1, 5, 1] | |||||
| element_length_function = (lambda x: 2) | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function) | |||||
| expected_output = [[[0], [1], [2], [3], [4]], | |||||
| [[5], [6], [7], [8], [9]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_single_bucket_with_padding(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [1, 2, 3] | |||||
| bucket_batch_sizes = [1, 1, 1, 3] | |||||
| element_length_function = (lambda x: 7) | |||||
| pad_info = {"col1": ([12], 0)} | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info) | |||||
| expected_output = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]], | |||||
| [[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0]], | |||||
| [[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_pad_to_bucket_boundary(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [3, 6, 15] | |||||
| bucket_batch_sizes = [2, 3, 4, 1] | |||||
| element_length_function = len | |||||
| pad_info = {"col1": ([None], 0)} | |||||
| pad_to_bucket_boundary = True | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info, pad_to_bucket_boundary) | |||||
| expected_output = [[[0, 0], | |||||
| [0, 1]], | |||||
| [[0, 1, 2, 0, 0], | |||||
| [0, 1, 2, 3, 0], | |||||
| [0, 1, 2, 3, 4]], | |||||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_default_pad(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(15)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [5, 8, 17] | |||||
| bucket_batch_sizes = [2, 1, 4, 1] | |||||
| element_length_function = len | |||||
| pad_info = {"col1": ([None], 0)} | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info) | |||||
| expected_output = [[[0, 0], | |||||
| [0, 1]], | |||||
| [[0, 1, 2, 0], | |||||
| [0, 1, 2, 3]], | |||||
| [[0, 1, 2, 3, 4]], | |||||
| [[0, 1, 2, 3, 4, 5]], | |||||
| [[0, 1, 2, 3, 4, 5, 6]], | |||||
| [[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], | |||||
| [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_drop_remainder(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(27)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [1, 2] | |||||
| bucket_batch_sizes = [2, 3, 5] | |||||
| element_length_function = (lambda x: x[0] % 3) | |||||
| pad_info = None | |||||
| pad_to_bucket_boundary = False | |||||
| drop_remainder = True | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info, pad_to_bucket_boundary, drop_remainder) | |||||
| expected_output = [[[0], [3]], | |||||
| [[1], [4], [7]], | |||||
| [[6], [9]], | |||||
| [[2], [5], [8], [11], [14]], | |||||
| [[12], [15]], | |||||
| [[10], [13], [16]], | |||||
| [[18], [21]], | |||||
| [[19], [22], [25]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_default_length_function(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"]) | |||||
| column_names = ["col1"] | |||||
| bucket_boundaries = [6, 12] | |||||
| bucket_batch_sizes = [5, 4, 1] | |||||
| element_length_function = None | |||||
| pad_info = {} | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info) | |||||
| expected_output = [[[0, 0, 0, 0, 0], | |||||
| [0, 1, 0, 0, 0], | |||||
| [0, 1, 2, 0, 0], | |||||
| [0, 1, 2, 3, 0], | |||||
| [0, 1, 2, 3, 4]], | |||||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8]]] | |||||
| output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| output.append(data["col1"].tolist()) | |||||
| assert output == expected_output | |||||
| def test_bucket_batch_multi_column(): | |||||
| dataset = ds.GeneratorDataset((lambda: generate_2_columns(10)), ["same_shape", "variable_shape"]) | |||||
| column_names = ["same_shape"] | |||||
| bucket_boundaries = [6, 12] | |||||
| bucket_batch_sizes = [5, 5, 1] | |||||
| element_length_function = None | |||||
| pad_info = {} | |||||
| dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries, | |||||
| bucket_batch_sizes, element_length_function, | |||||
| pad_info) | |||||
| same_shape_expected_output = [[[0], [1], [2], [3], [4]], | |||||
| [[5], [6], [7], [8], [9]]] | |||||
| variable_shape_expected_output = [[[0, 0, 0, 0, 0], | |||||
| [0, 1, 0, 0, 0], | |||||
| [0, 1, 2, 0, 0], | |||||
| [0, 1, 2, 3, 0], | |||||
| [0, 1, 2, 3, 4]], | |||||
| [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 0, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 0, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 0], | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]] | |||||
| same_shape_output = [] | |||||
| variable_shape_output = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| same_shape_output.append(data["same_shape"].tolist()) | |||||
| variable_shape_output.append(data["variable_shape"].tolist()) | |||||
| assert same_shape_output == same_shape_expected_output | |||||
| assert variable_shape_output == variable_shape_expected_output | |||||
| if __name__ == '__main__': | |||||
| test_bucket_batch_invalid_input() | |||||
| test_bucket_batch_multi_bucket_no_padding() | |||||
| test_bucket_batch_multi_bucket_with_padding() | |||||
| test_bucket_batch_single_bucket_no_padding() | |||||
| test_bucket_batch_single_bucket_with_padding() | |||||
| test_bucket_batch_pad_to_bucket_boundary() | |||||
| test_bucket_batch_default_pad() | |||||
| test_bucket_batch_drop_remainder() | |||||
| test_bucket_batch_default_length_function() | |||||
| test_bucket_batch_multi_column() | |||||