Browse Source

fix get_shape/type and a segfault in getter_pass

add test case
tags/v1.1.0
Zirui Wu 5 years ago
parent
commit
23487fae09
6 changed files with 49 additions and 17 deletions
  1. +23
    -14
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  2. +6
    -2
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  3. +2
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
  4. +6
    -1
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc
  5. +8
    -0
      tests/ut/python/dataset/test_dataset_numpy_slices.py
  6. +4
    -0
      tests/ut/python/dataset/test_datasets_get_dataset_size.py

+ 23
- 14
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -30,8 +30,8 @@
#include "minddata/mindrecord/include/shard_writer.h"
#endif

namespace mindspore::dataset {
namespace mindspore {
namespace dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }

@@ -440,7 +440,9 @@ Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &
}
#endif

TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
// TreeGetters ctor: dataset size starts unknown (-1), the tree is not yet
// initialized, and the first row has not yet been fetched/cached.
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) {
  tree_adapter_ = std::make_unique<TreeAdapter>();
}

Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
root_ = std::move(d);
@@ -473,20 +475,14 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
}

Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));

std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
[](const TensorPtr &t) { return t->type(); });
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
*types = first_row_type_;
return Status::OK();
}

Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));

std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
[](const TensorPtr &t) { return t->shape(); });
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
*shapes = first_row_shape_;
return Status::OK();
}

@@ -555,6 +551,18 @@ Status TreeGetters::InternalInit() {
if (!s.IsError()) init_flag_ = true;
return s;
}
// Fetch the first row from the tree exactly once and cache its per-column
// data types and shapes into first_row_type_ / first_row_shape_. Subsequent
// calls are no-ops once first_row_obtained_ is set, so GetOutputTypes and
// GetOutputShapes can share one traversal. Note: an empty first row (e.g. an
// empty dataset) still counts as "obtained" and yields empty vectors.
Status TreeGetters::GetFirstRowShapeAndType() {
  // Already cached — NOTE(review): RETURN_OK_IF_TRUE presumably returns
  // Status::OK() when the condition holds; confirm against the macro def.
  RETURN_OK_IF_TRUE(first_row_obtained_);
  // Compile/launch the tree with the shape-and-type getter pass applied.
  RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
  TensorRow first_row;
  RETURN_IF_NOT_OK(GetRow(&first_row));
  // Cache one DataType and one TensorShape per column of the first row.
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
                 [](const TensorPtr &t) { return t->type(); });
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_),
                 [](const TensorPtr &t) { return t->shape(); });
  first_row_obtained_ = true;
  return Status::OK();
}
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }

Status BuildVocabConsumer::Start() {
@@ -565,4 +573,5 @@ Status BuildVocabConsumer::Start() {
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
return Status::OK();
}
} // namespace mindspore::dataset
} // namespace dataset
} // namespace mindspore

+ 6
- 2
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -189,10 +189,14 @@ class TreeGetters : public TreeConsumer {
virtual Status GetRow(TensorRow *r);

private:
Status GetFirstRowShapeAndType();

std::shared_ptr<DatasetNode> root_;
int64_t dataset_size_;
TensorRow first_row_;
bool init_flag_; // indicate whether the tree has initialized
std::vector<DataType> first_row_type_;
std::vector<TensorShape> first_row_shape_;
bool first_row_obtained_; // whether first row (which could be empty) is obtained by TreeGetter
bool init_flag_;  // indicates whether the tree has been initialized

Status InternalInit(int8_t type);
Status InternalInit();


+ 2
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc View File

@@ -151,6 +151,8 @@ Status DatasetOp::Remove() {
if (!child_.empty()) {
// If we have a parent, then assign child's parent to point to our parent.
if (!parent_.empty()) {
CHECK_FAIL_RETURN_UNEXPECTED(parent_[0]->Children().size() == 1,
"Removing a node whose parent has more than 1 child is not supported.");
child_[0]->parent_[0] = parent_[0];
} else {
// We don't have a parent, so we are the root node being removed.


+ 6
- 1
mindspore/ccsrc/minddata/dataset/engine/opt/pre/getter_pass.cc View File

@@ -56,7 +56,12 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) {

// nested private class variables can be directly accessed by its outer class
for (auto node : pass_.nodes_to_remove_) {
RETURN_IF_NOT_OK(node->Remove());
DatasetOp *parent;
node->Parent(&parent, 0);
// only remove a node that is the only child of its parent
if (parent != nullptr && parent->Children().size() == 1) {
RETURN_IF_NOT_OK(node->Remove());
}
}

// clear the callback for selected ops (map when its GetOutputType/Shape)


+ 8
- 0
tests/ut/python/dataset/test_dataset_numpy_slices.py View File

@@ -239,6 +239,13 @@ def test_numpy_slices_invalid_empty_data_column():
assert "Argument data cannot be empty" in str(err.value)


def test_numpy_slice_empty_output_shape():
    """
    Feature: NumpySlicesDataset output_shapes on an empty pipeline.
    Description: batch with drop_remainder=True and a batch_size larger than
        the number of rows drops the partial batch, leaving an empty dataset.
    Expectation: output_shapes() returns [] instead of failing/segfaulting.
    """
    logger.info("running test_numpy_slice_empty_output_shape")
    dataset = de.NumpySlicesDataset([[[1, 2], [3, 4]]], column_names=["col1"])
    # batch_size (3) exceeds the row count, so drop_remainder empties the dataset
    dataset = dataset.batch(batch_size=3, drop_remainder=True)
    assert dataset.output_shapes() == []


if __name__ == "__main__":
test_numpy_slices_list_1()
test_numpy_slices_list_2()
@@ -259,3 +266,4 @@ if __name__ == "__main__":
test_numpy_slices_invalid_column_names_string()
test_numpy_slices_invalid_empty_column_names()
test_numpy_slices_invalid_empty_data_column()
test_numpy_slice_empty_output_shape()

+ 4
- 0
tests/ut/python/dataset/test_datasets_get_dataset_size.py View File

@@ -238,6 +238,10 @@ def test_pipeline_get_dataset_size():
dataset = dataset.repeat(count=2)
assert dataset.get_dataset_size() == 8

tf1 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
tf2 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, shuffle=True)
assert tf2.concat(tf1).get_dataset_size() == 24


if __name__ == '__main__':
test_imagenet_rawdata_dataset_size()


Loading…
Cancel
Save