Signed-off-by: alex-yuyue <yue.yu1@huawei.com>tags/v1.1.0
| @@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() { | |||||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| rc = tree_getters_->Init(ds); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; | |||||
| return -1; | |||||
| if (!tree_getters_->isInitialized()) { | |||||
| rc = tree_getters_->Init(ds); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; | |||||
| return -1; | |||||
| } | |||||
| } | } | ||||
| rc = tree_getters_->GetDatasetSize(&dataset_size); | rc = tree_getters_->GetDatasetSize(&dataset_size); | ||||
| return rc.IsError() ? -1 : dataset_size; | return rc.IsError() ? -1 : dataset_size; | ||||
| } | } | ||||
| std::vector<DataType> Dataset::GetOutputTypes() { | |||||
| std::vector<DataType> types; | |||||
| Status s; | |||||
| if (!tree_getters_->isInitialized()) { | |||||
| s = tree_getters_->Init(shared_from_this()); | |||||
| if (s.IsError()) { | |||||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||||
| return types; | |||||
| } | |||||
| } | |||||
| tree_getters_->GetOutputTypes(&types); | |||||
| return types; | |||||
| } | |||||
| std::vector<TensorShape> Dataset::GetOutputShapes() { | |||||
| std::vector<TensorShape> shapes; | |||||
| Status s; | |||||
| if (!tree_getters_->isInitialized()) { | |||||
| s = tree_getters_->Init(shared_from_this()); | |||||
| if (s.IsError()) { | |||||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||||
| return shapes; | |||||
| } | |||||
| } | |||||
| tree_getters_->GetOutputShapes(&shapes); | |||||
| return shapes; | |||||
| } | |||||
| // Constructor to initialize the cache | // Constructor to initialize the cache | ||||
| Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; } | Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; } | ||||
| @@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & | |||||
| } | } | ||||
| #endif | #endif | ||||
| TreeGetters::TreeGetters() { | |||||
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { | |||||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | tree_adapter_ = std::make_unique<TreeAdapter>(); | ||||
| dataset_size_ = -1; | |||||
| } | } | ||||
| Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); } | |||||
| Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { | |||||
| Status s = tree_adapter_->BuildAndPrepare(std::move(d)); | |||||
| if (!s.IsError()) { | |||||
| init_flag_ = true; | |||||
| } | |||||
| return s; | |||||
| } | |||||
| bool TreeGetters::isInitialized() { return init_flag_; } | |||||
| Status TreeGetters::GetRow(TensorRow *row) { | |||||
| if (row_flag_ == false) { | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(row)); | |||||
| row_flag_ = true; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | ||||
| if (dataset_size_ == -1) { | if (dataset_size_ == -1) { | ||||
| @@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | ||||
| RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | ||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| TensorRow row; | |||||
| if (*dataset_size == -1) { | if (*dataset_size == -1) { | ||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||||
| TensorRow row = row_; | |||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| num_rows++; | num_rows++; | ||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | ||||
| @@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||||
| *dataset_size = dataset_size_; | *dataset_size = dataset_size_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { | |||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| for (auto ts : row_) { | |||||
| DataType dt = ts->type(); | |||||
| types->push_back(dt); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { | |||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| for (auto ts : row_) { | |||||
| TensorShape t = ts->shape(); | |||||
| shapes->push_back(t); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| @@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer { | |||||
| TreeGetters(); | TreeGetters(); | ||||
| Status Init(std::shared_ptr<api::Dataset> d) override; | Status Init(std::shared_ptr<api::Dataset> d) override; | ||||
| Status GetDatasetSize(int64_t *size); | Status GetDatasetSize(int64_t *size); | ||||
| Status GetBatchSize(int32_t *batch_size) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetRepeatCount(int32_t *repeat_count) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetNumClasses(int32_t *num_classes) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputTypes(std::vector<DataType> *types) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputNames(std::vector<std::string> *names) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| Status GetOutputTypes(std::vector<DataType> *types); | |||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | |||||
| bool isInitialized(); | |||||
| std::string Name() override { return "TreeGetters"; } | std::string Name() override { return "TreeGetters"; } | ||||
| Status GetRow(TensorRow *r); | |||||
| private: | private: | ||||
| int64_t dataset_size_; | int64_t dataset_size_; | ||||
| TensorRow row_; | |||||
| bool init_flag_; // indicate whether the tree has initialized | |||||
| bool row_flag_; // indicate whether the first row has been stored in row_ | |||||
| }; | }; | ||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| @@ -27,7 +27,6 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | ||||
| #include "minddata/dataset/core/constants.h" | #include "minddata/dataset/core/constants.h" | ||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | #include "minddata/dataset/engine/consumers/tree_consumer.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/include/iterator.h" | #include "minddata/dataset/include/iterator.h" | ||||
| @@ -576,6 +575,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \return status code | /// \return status code | ||||
| int64_t GetDatasetSize(); | int64_t GetDatasetSize(); | ||||
| /// \brief Gets the output type | |||||
| /// \return status code | |||||
| std::vector<DataType> GetOutputTypes(); | |||||
| /// \brief Gets the output shape | |||||
| /// \return status code | |||||
| std::vector<TensorShape> GetOutputShapes(); | |||||
| /// \brief Setter function for runtime number of workers | /// \brief Setter function for runtime number of workers | ||||
| /// \param[in] num_workers The number of threads in this operator | /// \param[in] num_workers The number of threads in this operator | ||||
| /// \return Shared pointer to the original object | /// \return Shared pointer to the original object | ||||
| @@ -34,6 +34,8 @@ | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::DataType; | |||||
| using mindspore::dataset::TensorShape; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) { | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | EXPECT_EQ(ds->GetDatasetSize(), 10000); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestCifar10MixGetter) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter."; | |||||
| // Create a Cifar10 Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all"); | |||||
| EXPECT_NE(ds, nullptr); | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||||
| std::vector<DataType> types = ds->GetOutputTypes(); | |||||
| std::vector<TensorShape> shapes = ds->GetOutputShapes(); | |||||
| EXPECT_EQ(types.size(), 2); | |||||
| EXPECT_EQ(types[0].ToString(), "uint8"); | |||||
| EXPECT_EQ(types[1].ToString(), "uint32"); | |||||
| EXPECT_EQ(shapes.size(), 2); | |||||
| EXPECT_EQ(shapes[0].ToString(), "<32,32,3>"); | |||||
| EXPECT_EQ(shapes[1].ToString(), "<>"); | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||||
| EXPECT_EQ(ds->GetOutputTypes(), types); | |||||
| EXPECT_EQ(ds->GetOutputShapes(), shapes); | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||||
| EXPECT_EQ(ds->GetOutputTypes(), types); | |||||
| EXPECT_EQ(ds->GetOutputShapes(), shapes); | |||||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestCifar100Dataset) { | TEST_F(MindDataTestPipeline, TestCifar100Dataset) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset."; | ||||