
!1984 Implementation of BucketBatchByLengthOp

Merge pull request !1984 from Peilin/BucketBatchByLengthOp
tags/v0.5.0-beta
mindspore-ci-bot committed 5 years ago
commit 6d6f367529
11 changed files with 1060 additions and 61 deletions
  1. +94 -41    mindspore/ccsrc/dataset/api/de_pipeline.cc
  2. +3 -0      mindspore/ccsrc/dataset/api/de_pipeline.h
  3. +1 -0      mindspore/ccsrc/dataset/api/python_bindings.cc
  4. +1 -0      mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
  5. +16 -17    mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
  6. +242 -0    mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc
  7. +153 -0    mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h
  8. +114 -3    mindspore/dataset/engine/datasets.py
  9. +2 -0      mindspore/dataset/engine/iterators.py
  10. +61 -0    mindspore/dataset/engine/validators.py
  11. +373 -0   tests/ut/python/dataset/test_bucket_batch_by_length.py

+94 -41  mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -19,62 +19,65 @@
#include <map>

#include "common/utils.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);

static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCoco, &DEPipeline::ParseCocoOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCoco, &DEPipeline::ParseCocoOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};

DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@@ -672,6 +675,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}

Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries",
"bucket_batch_sizes"};
for (auto name : mandatory_arguments) {
if (args[name.c_str()].is_none()) {
std::string err_msg = "Error: " + name + " is not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}

std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>(
ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]),
ToIntVector(args[mandatory_arguments[2].c_str()]));

for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "length_dependent_columns") {
(void)builder->SetLengthDependentColumns(ToStringVector(value));
}
if (key == "bucket_boundaries") {
(void)builder->SetBucketBoundaries(ToIntVector(value));
}
if (key == "bucket_batch_sizes") {
(void)builder->SetBucketBatchSizes(ToIntVector(value));
}
if (key == "element_length_function") {
(void)builder->SetElementLengthFunction(value.cast<py::function>());
}
if (key == "pad_info") {
PadInfo pad_info;
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
(void)builder->SetPadInfo(pad_info);
}
if (key == "pad_to_bucket_boundary") {
(void)builder->SetPadToBucketBoundary(ToBool(value));
}
if (key == "drop_remainder") {
(void)builder->SetDropRemainder(ToBool(value));
}
}
}

std::shared_ptr<BucketBatchByLengthOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
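
For orientation, the py::dict that ParseBucketBatchByLengthOp receives is assembled on the Python side by BucketBatchByLengthDataset.get_args() (added later in this change). A rough sketch of its shape, with illustrative values only:

# Illustrative shape of the args dict consumed by ParseBucketBatchByLengthOp;
# the keys mirror BucketBatchByLengthDataset.get_args() below. Values are made up.
args = {
    "length_dependent_columns": ["col1"],
    "bucket_boundaries": [5, 10],
    "bucket_batch_sizes": [5, 1, 1],
    "element_length_function": None,   # or a Python callable
    "pad_info": {"col1": ([2, None], -1)},
    "pad_to_bucket_boundary": False,
    "drop_remainder": False,
}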

Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
// Right now barrier should only take num_rows_per_buffer = 1


+3 -0  mindspore/ccsrc/dataset/api/de_pipeline.h

@@ -40,6 +40,7 @@ enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
@@ -121,6 +122,8 @@ class DEPipeline {

Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);


+1 -0  mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -616,6 +616,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)


+1 -0  mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt

@@ -8,6 +8,7 @@ add_library(engine-datasetops OBJECT
pipeline_op.cc
barrier_op.cc
batch_op.cc
bucket_batch_by_length_op.cc
device_queue_op.cc
map_op.cc
project_op.cc


+16 -17  mindspore/ccsrc/dataset/engine/datasetops/batch_op.h

@@ -193,6 +193,22 @@ class BatchOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "BatchOp"; }

// batch the rows in src table then put it to dest table
// @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
// @param const std::unique_ptr<TensorQTable> *dest - dest table to hold batched rows
// @param dsize_t batch_size - batch size
// @return Status - The error code returned
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size);

// @param std::unique_ptr<TensorQTable> *table - table whose columns will be padded in place
// @param const PadInfo &pad_info - pad info
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code returned
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);

private:
// Worker thread for doing the memcpy of batch
// @param int32_t workerId
@@ -203,16 +219,6 @@ class BatchOp : public ParallelOp {
// @return Status - The error code returned
Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
std::unique_ptr<DataBuffer> *db);

// batch the rows in src table then put it to dest table
// @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
// @param int32_t size - batch_size
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size);

// Function that calls pyfunc to perform map on batch
// @param std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
// @return Status - The error code returned
@@ -229,13 +235,6 @@ class BatchOp : public ParallelOp {
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
std::vector<std::vector<dsize_t>> *pad_shapes);

// @param table
// @param const PadInfo &pad_info pad info
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);

// the number of thread pulling from the mOutConnector of the Op below
// @return int32_t, 1
int32_t num_consumers() const override { return 1; }


+242 -0  mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc

@@ -0,0 +1,242 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "dataset/core/pybind_support.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/status.h"

namespace py = pybind11;
namespace mindspore {
namespace dataset {
BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns,
std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes)
: builder_length_dependent_columns_(length_dependent_columns),
builder_bucket_boundaries_(bucket_boundaries),
builder_bucket_batch_sizes_(bucket_batch_sizes),
builder_pad_info_({}),
builder_pad_to_bucket_boundary_(false),
builder_drop_remainder_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_op_connector_size_ = config_manager->op_connector_size();
}

Status BucketBatchByLengthOp::Builder::SanityCheck() {
std::string error_message;

if (builder_length_dependent_columns_.empty()) {
error_message += "At least 1 column must be specified for element length calculation.\n";
}

if (builder_bucket_boundaries_.empty()) {
error_message += "At least 1 bucket boundary must be specified.\n";
}

if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) {
error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n";
}

CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message);

return Status::OK();
}

Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) {
RETURN_IF_NOT_OK(SanityCheck());

// insert 0 for the first bucket
builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0);

*new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>(
builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_,
builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_,
builder_op_connector_size_);

return Status::OK();
}

BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns,
std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes,
py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder,
int32_t op_connector_size)
: PipelineOp(op_connector_size),
length_dependent_columns_(length_dependent_columns),
bucket_boundaries_(bucket_boundaries),
bucket_batch_sizes_(bucket_batch_sizes),
element_length_function_(element_length_function),
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder),
batch_count_(0) {
for (int i = 0; i < bucket_batch_sizes_.size(); i++) {
buckets_.push_back(std::make_unique<TensorQTable>());
}
}

Status BucketBatchByLengthOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}

void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; }

Status BucketBatchByLengthOp::operator()() {
TaskManager::FindMe()->Post();

TensorRow current_row;
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
RETURN_IF_NOT_OK(AssignColMapFromChild());
while (!child_iterator_->eof_handled()) {
while (!current_row.empty()) {
int32_t element_length;
RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row));

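// select the last bucket whose lower boundary is <= element_length
// (bucket_boundaries_ includes the implicit leading 0)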
int bucket_index = bucket_boundaries_.size() - 1;
while (element_length < bucket_boundaries_[bucket_index]) {
bucket_index--;
}

buckets_[bucket_index]->push_back(current_row);

if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) {
RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index]));
}

RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
}

// got EOE, do what we need to do with remainders in each bucket
if (!drop_remainder_) {
for (int i = 0; i < bucket_boundaries_.size(); i++) {
if (!buckets_[i]->empty()) {
RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size()));
}
}
}

// need to send EOE manually since we set state to idle in EoeReceived()
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
}

return Status::OK();
}

Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) {
// call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of
// the single column specified in length_dependent_columns_
if (element_length_function_) {
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
try {
size_t number_of_arguments = length_dependent_columns_.size();
py::tuple input_arguments(number_of_arguments);
for (size_t i = 0; i < number_of_arguments; i++) {
py::array argument_value;
int32_t column_index = column_name_id_map_[length_dependent_columns_[i]];
RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value));
input_arguments[i] = argument_value;
}

py::object length = element_length_function_(*input_arguments);
*out_element_length = length.cast<int32_t>();
if (*out_element_length < 0) {
return Status(StatusCode::kPyFuncException, "Element length function should return a non-negative integer.");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
} catch (const py::cast_error &e) {
return Status(StatusCode::kPyFuncException, "Could not cast output of element length function to int32_t.");
}
} else {
*out_element_length = element[0]->shape()[0];
}

return Status::OK();
}

Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) {
std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index];

PadInfo pad_info_copy = pad_info_;
if (pad_to_bucket_boundary_) {
for (auto &pair : pad_info_copy) {
std::vector<dsize_t> pad_shape = pair.second.first.AsVector();

for (size_t i = 0; i < pad_shape.size(); i++) {
if (pad_shape[i] == TensorShape::kDimUnknown) {
if (bucket_index + 1 >= bucket_boundaries_.size()) {
std::string error_message = "Requested to pad to bucket boundary, but element falls in the last bucket";
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message);
}

pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1;
}
}

pair.second.first = TensorShape(pad_shape);
}
}

// PadColumns will change the data in bucket
RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_));

std::unique_ptr<TensorQTable> batched_bucket = std::make_unique<TensorQTable>();
RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size));
(*bucket)->clear();

std::unique_ptr<DataBuffer> batched_buffer = std::make_unique<DataBuffer>(batch_count_, DataBuffer::kDeBFlagNone);
batched_buffer->set_tensor_table(std::move(batched_bucket));
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer)));

batch_count_++;

return Status::OK();
}

Status BucketBatchByLengthOp::Reset() {
batch_count_ = 0;

for (int i = 0; i < buckets_.size(); i++) {
buckets_[i] = std::make_unique<TensorQTable>();
}

return Status::OK();
}

} // namespace dataset
} // namespace mindspore
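
As a reading aid, the bucket-selection loop in operator() and the boundary-padding rule in PadAndBatchBucket() boil down to the following minimal Python sketch. It assumes, as in Builder::Build() above, that the implicit 0 has already been prepended to the boundary list:

def assign_bucket(element_length, boundaries):
    # boundaries = [0, b1, ..., bn]; walk down from the last bucket until
    # the element's length reaches the bucket's lower boundary
    index = len(boundaries) - 1
    while element_length < boundaries[index]:
        index -= 1
    return index

def pad_target(bucket_index, boundaries):
    # with pad_to_bucket_boundary, unknown (None) dimensions are padded to
    # the next boundary minus one; the last bucket has no upper boundary
    if bucket_index + 1 >= len(boundaries):
        raise ValueError("element falls in the last bucket")
    return boundaries[bucket_index + 1] - 1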

+153 -0  mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h

@@ -0,0 +1,153 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_

#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "dataset/core/config_manager.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class DataBuffer;

class BucketBatchByLengthOp : public PipelineOp {
public:
class Builder {
public:
Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes);

~Builder() = default;

Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) {
builder_length_dependent_columns_ = length_dependent_columns;
return *this;
}

Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) {
builder_bucket_boundaries_ = bucket_boundaries;
return *this;
}

Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) {
builder_bucket_batch_sizes_ = bucket_batch_sizes;
return *this;
}

Builder &SetElementLengthFunction(py::function element_length_function) {
builder_element_length_function_ = element_length_function;
return *this;
}

Builder &SetPadInfo(PadInfo pad_info) {
builder_pad_info_ = pad_info;
return *this;
}

Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) {
builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary;
return *this;
}

Builder &SetDropRemainder(bool drop_remainder) {
builder_drop_remainder_ = drop_remainder;
return *this;
}

Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}

Status Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op);

private:
Status SanityCheck();

std::vector<std::string> builder_length_dependent_columns_;
std::vector<int32_t> builder_bucket_boundaries_;
std::vector<int32_t> builder_bucket_batch_sizes_;
py::function builder_element_length_function_;
PadInfo builder_pad_info_;
bool builder_pad_to_bucket_boundary_;
bool builder_drop_remainder_;
int32_t builder_op_connector_size_;
};

BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);

// Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId
// @return Status - The error code returned
Status EoeReceived(int32_t) override;

// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;

// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded
// @param bo - reference to the BucketBatchByLengthOp to display
// @return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) {
bo.Print(out, false);
return out;
}

// Main loop of batch
// @return Status - The error code returned
Status operator()() override;

// Function that is called by ResetOp at the end of every epoch
// @return Status - The error code returned
Status Reset() override;

private:
Status ObtainElementLength(int32_t *out_element_length, TensorRow element);

Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);

std::vector<std::string> length_dependent_columns_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
py::function element_length_function_;
PadInfo pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;

int32_t batch_count_;
std::unique_ptr<ChildIterator> child_iterator_;
std::vector<std::unique_ptr<TensorQTable>> buckets_;
};

} // namespace dataset
} // namespace mindspore

#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_

+114 -3  mindspore/dataset/engine/datasets.py

@@ -42,9 +42,9 @@ from .iterators import DictIterator, TupleIterator
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_cluedataset
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_split, check_bucket_batch_by_length, check_cluedataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
@@ -165,6 +165,76 @@ class Dataset:
args["num_parallel_workers"] = self.num_parallel_workers
return args

@check_bucket_batch_by_length
def bucket_batch_by_length(self, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function=None, pad_info=None,
pad_to_bucket_boundary=False, drop_remainder=False):
"""
Bucket elements according to their lengths, and pad and batch the buckets when
they are full.

A length function is called on each row in the dataset; the row is then
bucketed based on its length and bucket_boundaries. When a bucket reaches its
corresponding size specified in bucket_batch_sizes, the entire bucket will be
padded according to pad_info, and then batched. Each batch will be full,
except possibly the last batch for each bucket.

Args:
column_names (list of string): Columns passed to element_length_function.
bucket_boundaries (list of int): A list consisting of the upper boundaries
of the buckets. Must be strictly increasing. If there are n boundaries,
n+1 buckets are created: one bucket for [0, bucket_boundaries[0]), one
bucket for [bucket_boundaries[i], bucket_boundaries[i+1]) for each
0 <= i < n-1, and one bucket for [bucket_boundaries[n-1], inf).
bucket_batch_sizes (list of int): A list consisting of the batch sizes for
each bucket. Must contain len(bucket_boundaries)+1 elements.
element_length_function (Callable, optional): A function that takes in
len(column_names) arguments and returns an int. If no value is
provided, then len(column_names) must be 1, and the size of the first
dimension of that column will be taken as the length (default=None).
pad_info (dict, optional): Represents how to pad each column. The key
corresponds to the column name; the value must be a tuple of 2 elements.
The first element corresponds to the shape to pad to, and the second
element corresponds to the value to pad with. If a column is not
specified, then that column will be padded to the longest in the current
batch, and 0 will be used as the padding value. Any None dimensions will
be padded to the longest in the current batch, unless
pad_to_bucket_boundary is True. If no padding is wanted, set pad_info
to None (default=None).
pad_to_bucket_boundary (bool, optional): If True, will pad each None
dimension in pad_info to the corresponding bucket boundary minus 1. If
there are any elements that fall into the last bucket, an error will occur
(default=False).
drop_remainder (bool, optional): If True, will drop the last batch for each
bucket if it is not a full batch (default=False).

Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>>
>>> # Bucket rows by the maximum length of col1 and col2, then pad and
>>> # batch each bucket once it has bucket_batch_sizes[i] elements.
>>> column_names = ["col1", "col2"]
>>> bucket_boundaries = [5, 10]
>>> bucket_batch_sizes = [5, 1, 1]
>>> element_length_function = (lambda col1, col2: max(len(col1), len(col2)))
>>>
>>> # will pad col1 to shape [2, bucket_boundaries[i]-1] where i is the
>>> # index of the bucket that is currently being batched.
>>> # will pad col2 to a shape where each dimension is the longest in all
>>> # the elements currently being batched.
>>> pad_info = {"col1": ([2, None], -1)}
>>> pad_to_bucket_boundary = True
>>>
>>> data = data.bucket_batch_by_length(column_names, bucket_boundaries,
>>>                                    bucket_batch_sizes,
>>>                                    element_length_function, pad_info,
>>>                                    pad_to_bucket_boundary)
"""
return BucketBatchByLengthDataset(self, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function, pad_info,
pad_to_bucket_boundary, drop_remainder)

@check_batch
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, pad_info=None):
@@ -1400,6 +1470,47 @@ class DatasetOp(Dataset):

# No need for __init__ since it is the same as the super's init

class BucketBatchByLengthDataset(DatasetOp):
"""
The result of applying BucketBatchByLength operator to the input dataset.
"""

def __init__(self, input_dataset, column_names, bucket_boundaries, bucket_batch_sizes,
element_length_function, pad_info, pad_to_bucket_boundary, drop_remainder):
super().__init__()

self.column_names = column_names
self.bucket_boundaries = bucket_boundaries
self.bucket_batch_sizes = bucket_batch_sizes
self.element_length_function = element_length_function
self.pad_info = pad_info
self.pad_to_bucket_boundary = pad_to_bucket_boundary
self.drop_remainder = drop_remainder

self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs

def get_args(self):
args = super().get_args()
args["length_dependent_columns"] = self.column_names
args["bucket_boundaries"] = self.bucket_boundaries
args["bucket_batch_sizes"] = self.bucket_batch_sizes
args["element_length_function"] = self.element_length_function
args["pad_info"] = self.pad_info
args["pad_to_bucket_boundary"] = self.pad_to_bucket_boundary
args["drop_remainder"] = self.drop_remainder
return args

def get_dataset_size(self):
"""
Get the number of batches in an epoch.

Return:
None, the number of batches cannot be determined before runtime.
"""
return None
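
Taken together, the new API can be exercised end to end with a short script. The following runnable sketch mirrors test_bucket_batch_pad_to_bucket_boundary from the unit tests below:

import numpy as np
import mindspore.dataset as ds

# rows of increasing length: [0], [0, 1], ..., [0, ..., 8]
def gen(n):
    for i in range(n):
        yield (np.array(list(range(i + 1))),)

dataset = ds.GeneratorDataset((lambda: gen(9)), ["col1"])
dataset = dataset.bucket_batch_by_length(["col1"],
                                         bucket_boundaries=[3, 6, 15],
                                         bucket_batch_sizes=[2, 3, 4, 1],
                                         element_length_function=len,
                                         pad_info={"col1": ([None], 0)},
                                         pad_to_bucket_boundary=True)

# yields batches of shape (2, 2), (3, 5), and (4, 14): each None dimension
# is padded to its bucket boundary minus one
for batch in dataset.create_dict_iterator():
    print(batch["col1"].shape)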


class BatchDataset(DatasetOp):
"""


+2 -0  mindspore/dataset/engine/iterators.py

@@ -132,6 +132,8 @@ class Iterator:
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.BucketBatchByLengthDataset):
op_type = OpName.BUCKETBATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):


+61 -0  mindspore/dataset/engine/validators.py

@@ -752,6 +752,67 @@ def check_pad_info(key, val):
check_type(val[1], "pad_value", (int, float, str, bytes))


def check_bucket_batch_by_length(method):
"""check the input arguments of bucket_batch_by_length."""

@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)

nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
check_param_type(nreq_param_list, param_dict, list)

# check column_names: must be list of string.
column_names = param_dict.get("column_names")
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")

element_length_function = param_dict.get("element_length_function")
if element_length_function is None and len(column_names) != 1:
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

# check bucket_boundaries: must be list of int, positive and strictly increasing
bucket_boundaries = param_dict.get('bucket_boundaries')

if not bucket_boundaries:
raise ValueError("bucket_boundaries cannot be empty.")

all_int = all(isinstance(item, int) for item in bucket_boundaries)
if not all_int:
raise TypeError("bucket_boundaries should be a list of int.")

all_non_negative = all(item >= 0 for item in bucket_boundaries)
if not all_non_negative:
raise ValueError("bucket_boundaries cannot contain any negative numbers.")

for i in range(len(bucket_boundaries) - 1):
if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
raise ValueError("bucket_boundaries should be strictly increasing.")

# check bucket_batch_sizes: must be list of int and positive
bucket_batch_sizes = param_dict.get('bucket_batch_sizes')
if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
if not all_int:
raise TypeError("bucket_batch_sizes should be a list of int.")

all_non_negative = all(item >= 0 for item in bucket_batch_sizes)
if not all_non_negative:
raise ValueError("bucket_batch_sizes cannot contain any negative numbers.")

if param_dict.get('pad_info') is not None:
check_type(param_dict["pad_info"], "pad_info", dict)
for k, v in param_dict.get('pad_info').items():
check_pad_info(k, v)

return method(*args, **kwargs)

return new_method


def check_batch(method):
"""check the input arguments of batch."""



+373 -0  tests/ut/python/dataset/test_bucket_batch_by_length.py

@@ -0,0 +1,373 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import pytest
import numpy as np
import mindspore.dataset as ds

# generates 1 column [0], [0, 1], ..., [0, ..., n-1]
def generate_sequential(n):
for i in range(n):
yield (np.array([j for j in range(i + 1)]),)


# generates 1 column [0], [1], ..., [n-1]
def generate_sequential_same_shape(n):
for i in range(n):
yield (np.array([i]),)


# combines generate_sequential_same_shape and generate_sequential
def generate_2_columns(n):
for i in range(n):
yield (np.array([i]), np.array([j for j in range(i + 1)]))


def test_bucket_batch_invalid_input():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])

column_names = ["col1"]
invalid_column_names = [1, 2, 3]

bucket_boundaries = [1, 2, 3]
empty_bucket_boundaries = []
invalid_bucket_boundaries = ["1", "2", "3"]
negative_bucket_boundaries = [1, 2, -3]
decreasing_bucket_boundaries = [3, 2, 1]
non_increasing_bucket_boundaries = [1, 2, 2]

bucket_batch_sizes = [1, 1, 1, 1]
invalid_bucket_batch_sizes = ["1", "2", "3", "4"]
negative_bucket_batch_sizes = [1, 2, 3, -4]

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)
assert "bucket_boundaries cannot be empty" in str(info.value)

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes)
assert "bucket_boundaries should be a list of int" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes)
assert "bucket_boundaries cannot contain any negative numbers" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes)
assert "bucket_boundaries should be strictly increasing" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, non_increasing_bucket_boundaries, bucket_batch_sizes)
assert "bucket_boundaries should be strictly increasing" in str(info.value)

with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, invalid_bucket_batch_sizes)
assert "bucket_batch_sizes should be a list of int" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes)
assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value)

with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)


def test_bucket_batch_multi_bucket_no_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [3, 3, 2, 2]
element_length_function = (lambda x: x[0] % 4)

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function)

expected_output = [[[2], [6]],
[[3], [7]],
[[0], [4], [8]],
[[1], [5], [9]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_multi_bucket_with_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential(10)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [2, 3, 3, 2]
element_length_function = (lambda x: len(x) % 4)
pad_info = {"col1": ([10], 0)}

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)

expected_output = [[[0, 1, 2, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0]],
[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0]],
[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_single_bucket_no_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [1, 1, 5, 1]
element_length_function = (lambda x: 2)

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function)

expected_output = [[[0], [1], [2], [3], [4]],
[[5], [6], [7], [8], [9]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_single_bucket_with_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [1, 1, 1, 3]
element_length_function = (lambda x: 7)
pad_info = {"col1": ([12], 0)}

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)

expected_output = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0]],
[[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_pad_to_bucket_boundary():
dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [3, 6, 15]
bucket_batch_sizes = [2, 3, 4, 1]
element_length_function = len
pad_info = {"col1": ([None], 0)}
pad_to_bucket_boundary = True

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info, pad_to_bucket_boundary)

expected_output = [[[0, 0],
[0, 1]],
[[0, 1, 2, 0, 0],
[0, 1, 2, 3, 0],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_default_pad():
dataset = ds.GeneratorDataset((lambda: generate_sequential(15)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [5, 8, 17]
bucket_batch_sizes = [2, 1, 4, 1]
element_length_function = len
pad_info = {"col1": ([None], 0)}

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)

expected_output = [[[0, 0],
[0, 1]],
[[0, 1, 2, 0],
[0, 1, 2, 3]],
[[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4, 5]],
[[0, 1, 2, 3, 4, 5, 6]],
[[0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]]


output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_drop_remainder():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(27)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [1, 2]
bucket_batch_sizes = [2, 3, 5]
element_length_function = (lambda x: x[0] % 3)
pad_info = None
pad_to_bucket_boundary = False
drop_remainder = True

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info, pad_to_bucket_boundary, drop_remainder)

expected_output = [[[0], [3]],
[[1], [4], [7]],
[[6], [9]],
[[2], [5], [8], [11], [14]],
[[12], [15]],
[[10], [13], [16]],
[[18], [21]],
[[19], [22], [25]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_default_length_function():
dataset = ds.GeneratorDataset((lambda: generate_sequential(9)), ["col1"])

column_names = ["col1"]
bucket_boundaries = [6, 12]
bucket_batch_sizes = [5, 4, 1]
element_length_function = None
pad_info = {}

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)

expected_output = [[[0, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 1, 2, 0, 0],
[0, 1, 2, 3, 0],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4, 5, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8]]]

output = []
for data in dataset.create_dict_iterator():
output.append(data["col1"].tolist())

assert output == expected_output


def test_bucket_batch_multi_column():
dataset = ds.GeneratorDataset((lambda: generate_2_columns(10)), ["same_shape", "variable_shape"])

column_names = ["same_shape"]
bucket_boundaries = [6, 12]
bucket_batch_sizes = [5, 5, 1]
element_length_function = None
pad_info = {}

dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function,
pad_info)

same_shape_expected_output = [[[0], [1], [2], [3], [4]],
[[5], [6], [7], [8], [9]]]

variable_shape_expected_output = [[[0, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 1, 2, 0, 0],
[0, 1, 2, 3, 0],
[0, 1, 2, 3, 4]],
[[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]

same_shape_output = []
variable_shape_output = []
for data in dataset.create_dict_iterator():
same_shape_output.append(data["same_shape"].tolist())
variable_shape_output.append(data["variable_shape"].tolist())

assert same_shape_output == same_shape_expected_output
assert variable_shape_output == variable_shape_expected_output


if __name__ == '__main__':
test_bucket_batch_invalid_input()
test_bucket_batch_multi_bucket_no_padding()
test_bucket_batch_multi_bucket_with_padding()
test_bucket_batch_single_bucket_no_padding()
test_bucket_batch_single_bucket_with_padding()
test_bucket_batch_pad_to_bucket_boundary()
test_bucket_batch_default_pad()
test_bucket_batch_drop_remainder()
test_bucket_batch_default_length_function()
test_bucket_batch_multi_column()
