
!5929 BucketBatchByLength column issue

Merge pull request !5929 from MahdiRahmaniHanzaki/bucket_batch_by_length_fix
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
commit a778868a5a
2 changed files with 79 additions and 2 deletions:
  1. mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc (+11, -2)
  2. tests/ut/python/dataset/test_bucket_batch_by_length.py (+68, -0)

mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc (+11, -2)

@@ -155,8 +155,17 @@ Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, T
   // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of
   // the single column specified in length_dependent_columns_
   if (element_length_function_) {
-    TensorRow output;
-    RETURN_IF_NOT_OK(element_length_function_->Compute(element, &output));
+    TensorRow input, output;
+    size_t number_of_arguments = length_dependent_columns_.size();
+    for (size_t i = 0; i < number_of_arguments; i++) {
+      auto map_item = column_name_id_map_.find(length_dependent_columns_[i]);
+      if (map_item == column_name_id_map_.end()) {
+        RETURN_STATUS_UNEXPECTED("BucketBatchByLength: Couldn't find the specified column in the dataset");
+      }
+      int32_t column_index = map_item->second;
+      input.push_back(element[column_index]);
+    }
+    RETURN_IF_NOT_OK(element_length_function_->Compute(input, &output));
     RETURN_IF_NOT_OK(output.at(0)->GetItemAt(out_element_length, {0}));
     if (*out_element_length < 0) {
       RETURN_STATUS_UNEXPECTED("BucketBatchByLength: element_length_function returned negative integer");


tests/ut/python/dataset/test_bucket_batch_by_length.py (+68, -0)

@@ -36,6 +36,11 @@ def generate_2_columns(n):
         yield (np.array([i]), np.array([j for j in range(i + 1)]))


+def generate_3_columns(n):
+    for i in range(n):
+        yield (np.array([i]), np.array([i + 1]), np.array([j for j in range(i + 1)]))
+
+
 def test_bucket_batch_invalid_input():
     dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
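
For reference, a quick sketch of the rows the new generator produces (values shown as plain lists):

# generate_3_columns(3) yields, row by row:
#   ([0], [1], [0])
#   ([1], [2], [0, 1])
#   ([2], [3], [0, 1, 2])
# the first two columns keep a fixed shape while the third grows with the row index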


@@ -382,6 +387,48 @@ def test_bucket_batch_multi_column():
     assert same_shape_output == same_shape_expected_output
     assert variable_shape_output == variable_shape_expected_output


+def test_bucket_batch_three_columns():
+    dataset = ds.GeneratorDataset((lambda: generate_3_columns(10)), ["same_shape", "same_shape2", "variable_shape"])
+
+    column_names = ["same_shape2"]
+    bucket_boundaries = [6, 12]
+    bucket_batch_sizes = [5, 5, 1]
+    element_length_function = (lambda x: x[0] % 3)
+    pad_info = {}
+
+    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
+                                             bucket_batch_sizes, element_length_function,
+                                             pad_info)
+
+    same_shape_expected_output = [[[0], [1], [2], [3], [4]],
+                                  [[5], [6], [7], [8], [9]]]
+    same_shape2_expected_output = [[[1], [2], [3], [4], [5]],
+                                   [[6], [7], [8], [9], [10]]]
+    variable_shape_expected_output = [[[0, 0, 0, 0, 0],
+                                       [0, 1, 0, 0, 0],
+                                       [0, 1, 2, 0, 0],
+                                       [0, 1, 2, 3, 0],
+                                       [0, 1, 2, 3, 4]],
+                                      [[0, 1, 2, 3, 4, 5, 0, 0, 0, 0],
+                                       [0, 1, 2, 3, 4, 5, 6, 0, 0, 0],
+                                       [0, 1, 2, 3, 4, 5, 6, 7, 0, 0],
+                                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
+                                       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]
+
+    same_shape_output = []
+    same_shape2_output = []
+    variable_shape_output = []
+    for data in dataset.create_dict_iterator(num_epochs=1):
+        same_shape_output.append(data["same_shape"].tolist())
+        same_shape2_output.append(data["same_shape2"].tolist())
+        variable_shape_output.append(data["variable_shape"].tolist())
+
+    assert same_shape_output == same_shape_expected_output
+    assert same_shape2_output == same_shape2_expected_output
+    assert variable_shape_output == variable_shape_expected_output
+
+
 def test_bucket_batch_get_dataset_size():
     dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
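
The expected values above follow from the bucketing arithmetic: the length function sees only same_shape2, whose value for row i is i + 1, so every computed length is (i + 1) % 3 and stays below the first boundary of 6. A quick plain-Python sanity check of that reasoning (no MindSpore needed):

lengths = [(i + 1) % 3 for i in range(10)]    # fn(same_shape2) per row: 0, 1, or 2
assert all(length < 6 for length in lengths)  # every row falls into bucket 0
# bucket 0 batches 5 rows at a time, so rows 0-4 and 5-9 come out in input order,
# and pad_info={} pads variable_shape to the longest row in each batch (5, then 10)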


@@ -402,6 +449,25 @@ def test_bucket_batch_get_dataset_size():
     assert data_size == num_rows


+def test_bucket_batch_invalid_column():
+    dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
+
+    column_names = ["invalid_column"]
+    bucket_boundaries = [1, 2, 3]
+    bucket_batch_sizes = [3, 3, 2, 2]
+    element_length_function = (lambda x: x[0] % 4)
+
+    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
+                                             bucket_batch_sizes, element_length_function)
+
+    with pytest.raises(RuntimeError) as info:
+        num_rows = 0
+        for _ in dataset.create_dict_iterator(num_epochs=1):
+            num_rows += 1
+
+    assert "BucketBatchByLength: Couldn't find the specified column in the dataset" in str(info.value)
+
+
 if __name__ == '__main__':
     test_bucket_batch_invalid_input()
     test_bucket_batch_multi_bucket_no_padding()
@@ -413,4 +479,6 @@ if __name__ == '__main__':
     test_bucket_batch_drop_remainder()
     test_bucket_batch_default_length_function()
     test_bucket_batch_multi_column()
+    test_bucket_batch_three_columns()
     test_bucket_batch_get_dataset_size()
+    test_bucket_batch_invalid_column()
