Signed-off-by: alex-yuyue <yue.yu1@huawei.com>tags/v1.1.0
| @@ -192,15 +192,45 @@ int64_t Dataset::GetDatasetSize() { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->Init(ds); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; | |||
| return -1; | |||
| if (!tree_getters_->isInitialized()) { | |||
| rc = tree_getters_->Init(ds); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; | |||
| return -1; | |||
| } | |||
| } | |||
| rc = tree_getters_->GetDatasetSize(&dataset_size); | |||
| return rc.IsError() ? -1 : dataset_size; | |||
| } | |||
| std::vector<DataType> Dataset::GetOutputTypes() { | |||
| std::vector<DataType> types; | |||
| Status s; | |||
| if (!tree_getters_->isInitialized()) { | |||
| s = tree_getters_->Init(shared_from_this()); | |||
| if (s.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||
| return types; | |||
| } | |||
| } | |||
| tree_getters_->GetOutputTypes(&types); | |||
| return types; | |||
| } | |||
| std::vector<TensorShape> Dataset::GetOutputShapes() { | |||
| std::vector<TensorShape> shapes; | |||
| Status s; | |||
| if (!tree_getters_->isInitialized()) { | |||
| s = tree_getters_->Init(shared_from_this()); | |||
| if (s.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||
| return shapes; | |||
| } | |||
| } | |||
| tree_getters_->GetOutputShapes(&shapes); | |||
| return shapes; | |||
| } | |||
| // Constructor to initialize the cache | |||
| Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; } | |||
| @@ -351,12 +351,27 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & | |||
| } | |||
| #endif | |||
| TreeGetters::TreeGetters() { | |||
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { | |||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | |||
| dataset_size_ = -1; | |||
| } | |||
| Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); } | |||
| Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { | |||
| Status s = tree_adapter_->BuildAndPrepare(std::move(d)); | |||
| if (!s.IsError()) { | |||
| init_flag_ = true; | |||
| } | |||
| return s; | |||
| } | |||
| bool TreeGetters::isInitialized() { return init_flag_; } | |||
| Status TreeGetters::GetRow(TensorRow *row) { | |||
| if (row_flag_ == false) { | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(row)); | |||
| row_flag_ = true; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||
| if (dataset_size_ == -1) { | |||
| @@ -364,10 +379,10 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | |||
| dataset_size_ = *dataset_size; | |||
| TensorRow row; | |||
| if (*dataset_size == -1) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| int64_t num_rows = 0; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||
| TensorRow row = row_; | |||
| while (row.size() != 0) { | |||
| num_rows++; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||
| @@ -379,4 +394,22 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| for (auto ts : row_) { | |||
| DataType dt = ts->type(); | |||
| types->push_back(dt); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| for (auto ts : row_) { | |||
| TensorShape t = ts->shape(); | |||
| shapes->push_back(t); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| @@ -156,29 +156,17 @@ class TreeGetters : public TreeConsumer { | |||
| TreeGetters(); | |||
| Status Init(std::shared_ptr<api::Dataset> d) override; | |||
| Status GetDatasetSize(int64_t *size); | |||
| Status GetBatchSize(int32_t *batch_size) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetRepeatCount(int32_t *repeat_count) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetNumClasses(int32_t *num_classes) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputShapes(std::vector<TensorShape> *shapes) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputTypes(std::vector<DataType> *types) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputNames(std::vector<std::string> *names) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| Status GetOutputTypes(std::vector<DataType> *types); | |||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | |||
| bool isInitialized(); | |||
| std::string Name() override { return "TreeGetters"; } | |||
| Status GetRow(TensorRow *r); | |||
| private: | |||
| int64_t dataset_size_; | |||
| TensorRow row_; | |||
| bool init_flag_; // indicate whether the tree has initialized | |||
| bool row_flag_; // indicate whether the first row has been stored in row_ | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| @@ -27,7 +27,6 @@ | |||
| #include <vector> | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/include/iterator.h" | |||
| @@ -576,6 +575,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||
| /// \return status code | |||
| int64_t GetDatasetSize(); | |||
| /// \brief Gets the output type | |||
| /// \return status code | |||
| std::vector<DataType> GetOutputTypes(); | |||
| /// \brief Gets the output shape | |||
| /// \return status code | |||
| std::vector<TensorShape> GetOutputShapes(); | |||
| /// \brief Setter function for runtime number of workers | |||
| /// \param[in] num_workers The number of threads in this operator | |||
| /// \return Shared pointer to the original object | |||
| @@ -34,6 +34,8 @@ | |||
| using namespace mindspore::dataset::api; | |||
| using mindspore::dataset::Tensor; | |||
| using mindspore::dataset::DataType; | |||
| using mindspore::dataset::TensorShape; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| @@ -84,6 +86,33 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) { | |||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCifar10MixGetter) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter."; | |||
| // Create a Cifar10 Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all"); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||
| std::vector<DataType> types = ds->GetOutputTypes(); | |||
| std::vector<TensorShape> shapes = ds->GetOutputShapes(); | |||
| EXPECT_EQ(types.size(), 2); | |||
| EXPECT_EQ(types[0].ToString(), "uint8"); | |||
| EXPECT_EQ(types[1].ToString(), "uint32"); | |||
| EXPECT_EQ(shapes.size(), 2); | |||
| EXPECT_EQ(shapes[0].ToString(), "<32,32,3>"); | |||
| EXPECT_EQ(shapes[1].ToString(), "<>"); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||
| EXPECT_EQ(ds->GetOutputTypes(), types); | |||
| EXPECT_EQ(ds->GetOutputShapes(), shapes); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||
| EXPECT_EQ(ds->GetOutputTypes(), types); | |||
| EXPECT_EQ(ds->GetOutputShapes(), shapes); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 10000); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestCifar100Dataset) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100Dataset."; | |||