Browse Source

Add support for GetOutputTypes and GetOutputShapes

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
tags/v1.1.0
alex-yuyue 5 years ago
parent
commit
2906659673
5 changed files with 116 additions and 29 deletions
  1. +34
    -4
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +38
    -5
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  3. +7
    -19
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h
  4. +8
    -1
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  5. +29
    -0
      tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc

+ 34
- 4
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
if (!tree_getters_->isInitialized()) {
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
}
rc = tree_getters_->GetDatasetSize(&dataset_size);
return rc.IsError() ? -1 : dataset_size;
}

std::vector<DataType> Dataset::GetOutputTypes() {
std::vector<DataType> types;
Status s;
if (!tree_getters_->isInitialized()) {
s = tree_getters_->Init(shared_from_this());
if (s.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return types;
}
}
tree_getters_->GetOutputTypes(&types);
return types;
}

std::vector<TensorShape> Dataset::GetOutputShapes() {
std::vector<TensorShape> shapes;
Status s;
if (!tree_getters_->isInitialized()) {
s = tree_getters_->Init(shared_from_this());
if (s.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return shapes;
}
}
tree_getters_->GetOutputShapes(&shapes);
return shapes;
}

// Constructor to initialize the cache
Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }



+ 38
- 5
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
}
#endif

TreeGetters::TreeGetters() {
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) {
tree_adapter_ = std::make_unique<TreeAdapter>();
dataset_size_ = -1;
}

Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); }
Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
if (!s.IsError()) {
init_flag_ = true;
}
return s;
}

bool TreeGetters::isInitialized() { return init_flag_; }

Status TreeGetters::GetRow(TensorRow *row) {
if (row_flag_ == false) {
RETURN_IF_NOT_OK(tree_adapter_->GetNext(row));
row_flag_ = true;
}
return Status::OK();
}

Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
if (dataset_size_ == -1) {
@@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
dataset_size_ = *dataset_size;
TensorRow row;
if (*dataset_size == -1) {
RETURN_IF_NOT_OK(GetRow(&row_));
int64_t num_rows = 0;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
TensorRow row = row_;
while (row.size() != 0) {
num_rows++;
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
@@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
*dataset_size = dataset_size_;
return Status::OK();
}

Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
RETURN_IF_NOT_OK(GetRow(&row_));
for (auto ts : row_) {
DataType dt = ts->type();
types->push_back(dt);
}
return Status::OK();
}

Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
RETURN_IF_NOT_OK(GetRow(&row_));
for (auto ts : row_) {
TensorShape t = ts->shape();
shapes->push_back(t);
}
return Status::OK();
}
} // namespace mindspore::dataset

+ 7
- 19
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h View File

@@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer {
TreeGetters();
Status Init(std::shared_ptr<api::Dataset> d) override;
Status GetDatasetSize(int64_t *size);
Status GetBatchSize(int32_t *batch_size) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status GetRepeatCount(int32_t *repeat_count) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status GetNumClasses(int32_t *num_classes) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status GetOutputShapes(std::vector<TensorShape> *shapes) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status GetOutputTypes(std::vector<DataType> *types) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status GetOutputNames(std::vector<std::string> *names) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}

Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
bool isInitialized();
std::string Name() override { return "TreeGetters"; }
Status GetRow(TensorRow *r);

private:
int64_t dataset_size_;
TensorRow row_;
bool init_flag_; // indicate whether the tree has initialized
bool row_flag_; // indicate whether the first row has been stored in row_
};

} // namespace mindspore::dataset


+ 8
- 1
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -27,7 +27,6 @@
#include <vector>
#include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h"
#include "minddata/dataset/core/constants.h"

#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/include/iterator.h"
@@ -576,6 +575,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return status code
int64_t GetDatasetSize();

/// \brief Gets the output type
/// \return status code
std::vector<DataType> GetOutputTypes();

/// \brief Gets the output shape
/// \return status code
std::vector<TensorShape> GetOutputShapes();

/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object


+ 29
- 0
tests/ut/cpp/dataset/c_api_dataset_cifar_test.cc View File

@@ -34,6 +34,8 @@

using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::DataType;
using mindspore::dataset::TensorShape;

class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
@@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
EXPECT_EQ(ds->GetDatasetSize(), 10000);
}

TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";

// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all");
EXPECT_NE(ds, nullptr);

EXPECT_EQ(ds->GetDatasetSize(), 10000);
std::vector<DataType> types = ds->GetOutputTypes();
std::vector<TensorShape> shapes = ds->GetOutputShapes();
EXPECT_EQ(types.size(), 2);
EXPECT_EQ(types[0].ToString(), "uint8");
EXPECT_EQ(types[1].ToString(), "uint32");
EXPECT_EQ(shapes.size(), 2);
EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
EXPECT_EQ(shapes[1].ToString(), "<>");

EXPECT_EQ(ds->GetDatasetSize(), 10000);
EXPECT_EQ(ds->GetOutputTypes(), types);
EXPECT_EQ(ds->GetOutputShapes(), shapes);
EXPECT_EQ(ds->GetDatasetSize(), 10000);
EXPECT_EQ(ds->GetOutputTypes(), types);
EXPECT_EQ(ds->GetOutputShapes(), shapes);
EXPECT_EQ(ds->GetDatasetSize(), 10000);
}

TEST_F(MindDataTestPipeline, TestCifar100Dataset) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset.";



Loading…
Cancel
Save