Merge pull request !7734 from Alex Yuyue/IR_dataset_input
@@ -257,32 +257,79 @@ int64_t Dataset::GetDatasetSize() {
 std::vector<DataType> Dataset::GetOutputTypes() {
   std::vector<DataType> types;
-  Status s;
+  Status rc;
+  std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
+  rc = runtime_context->Init();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
+    types.clear();
+    return types;
+  }
   if (!tree_getters_->isInitialized()) {
-    s = tree_getters_->Init(shared_from_this());
-    if (s.IsError()) {
-      MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
+    rc = tree_getters_->Init(shared_from_this());
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
+      types.clear();
       return types;
     }
   }
-  tree_getters_->GetOutputTypes(&types);
+  rc = tree_getters_->GetOutputTypes(&types);
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
+    types.clear();
+    return types;
+  }
   return types;
 }

 std::vector<TensorShape> Dataset::GetOutputShapes() {
   std::vector<TensorShape> shapes;
-  Status s;
+  Status rc;
+  std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
+  rc = runtime_context->Init();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
+    shapes.clear();
+    return shapes;
+  }
   if (!tree_getters_->isInitialized()) {
-    s = tree_getters_->Init(shared_from_this());
-    if (s.IsError()) {
-      MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
+    rc = tree_getters_->Init(shared_from_this());
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
+      shapes.clear();
       return shapes;
     }
   }
-  tree_getters_->GetOutputShapes(&shapes);
+  rc = tree_getters_->GetOutputShapes(&shapes);
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
+    shapes.clear();
+    return shapes;
+  }
   return shapes;
 }

+int64_t Dataset::GetNumClasses() {
+  int64_t num_classes;
+  auto ds = shared_from_this();
+  Status rc;
+  std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
+  rc = runtime_context->Init();
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
+    return -1;
+  }
+  if (!tree_getters_->isInitialized()) {
+    rc = tree_getters_->Init(ds);
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
+      return -1;
+    }
+  }
+  rc = tree_getters_->GetNumClasses(&num_classes);
+  return rc.IsError() ? -1 : num_classes;
+}
+
 // Constructor to initialize the cache
 Dataset::Dataset(const std::shared_ptr<DatasetCache> &dataset_cache) : Dataset() { cache_ = dataset_cache; }
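Note on the hunk above: all three getters now follow the same convention of converting any internal Status error into a sentinel the caller can test (an empty vector, or -1 for the scalar getters). Below is a minimal, self-contained sketch of that convention; the types and names are hypothetical stand-ins, not MindSpore API.

// Minimal sketch of the convention shared by these getters: every fallible
// step returns a Status, and a failure is collapsed into a sentinel
// (an empty vector here, -1 for the scalar getters). Hypothetical names only.
#include <iostream>
#include <string>
#include <vector>

struct Status {
  bool ok = true;
  bool IsError() const { return !ok; }
};

// Stand-in for tree_getters_->GetOutputTypes(&types).
Status FillTypes(std::vector<std::string> *out) {
  out->push_back("uint8");
  out->push_back("uint32");
  return Status{};
}

// Mirrors Dataset::GetOutputTypes(): check rc, clear any partial result, return empty.
std::vector<std::string> TypesOrEmpty() {
  std::vector<std::string> types;
  Status rc = FillTypes(&types);
  if (rc.IsError()) {
    types.clear();  // never hand back a half-filled vector
    return types;
  }
  return types;
}

int main() {
  std::cout << TypesOrEmpty().size() << " output types" << std::endl;  // prints 2
  return 0;
}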
@@ -656,6 +703,7 @@ Status Dataset::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
   }
   return Status::OK();
 }

 int64_t Dataset::GetBatchSize() {
   int64_t batch_size;
   auto ds = shared_from_this();
@@ -666,14 +714,17 @@ int64_t Dataset::GetBatchSize() {
     MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
     return -1;
   }
-  rc = tree_getters_->Init(ds);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
-    return -1;
+  if (!tree_getters_->isInitialized()) {
+    rc = tree_getters_->Init(ds);
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
+      return -1;
+    }
   }
   rc = tree_getters_->GetBatchSize(&batch_size);
   return rc.IsError() ? -1 : batch_size;
 }

 int64_t Dataset::GetRepeatCount() {
   int64_t repeat_count;
   auto ds = shared_from_this();
@@ -684,10 +735,12 @@ int64_t Dataset::GetRepeatCount() {
     MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
     return -1;
   }
-  rc = tree_getters_->Init(ds);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
-    return -1;
+  if (!tree_getters_->isInitialized()) {
+    rc = tree_getters_->Init(ds);
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
+      return -1;
+    }
   }
   rc = tree_getters_->GetRepeatCount(&repeat_count);
   return rc.IsError() ? 0 : repeat_count;
@@ -444,10 +444,18 @@ Status TreeGetters::GetBatchSize(int64_t *batch_size) {
   CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size.");
   return Status::OK();
 }

 Status TreeGetters::GetRepeatCount(int64_t *repeat_count) {
   std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
   CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
   *repeat_count = root->GetTreeRepeatCount();
   return Status::OK();
 }

+Status TreeGetters::GetNumClasses(int64_t *num_classes) {
+  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
+  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
+  RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
+  return Status::OK();
+}
 }  // namespace mindspore::dataset
@@ -164,6 +164,7 @@ class TreeGetters : public TreeConsumer {
   Status GetOutputShapes(std::vector<TensorShape> *shapes);
   Status GetBatchSize(int64_t *batch_size);
   Status GetRepeatCount(int64_t *repeat_count);
+  Status GetNumClasses(int64_t *num_classes);
   bool isInitialized();
   std::string Name() override { return "TreeGetters"; }
   Status GetRow(TensorRow *r);
@@ -51,7 +51,8 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
       op_current_repeats_(0),
       op_current_epochs_(0),
       out_connector_(nullptr),
-      dataset_size_(-1) {
+      dataset_size_(-1),
+      num_classes_(-1) {
   // The operator starts out with an invalid operator id. The only way to
   // get it out of invalid state is to assign the operator to an execution tree.
 }
@@ -302,6 +303,19 @@ Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
   return child_[0]->GetDatasetSize(dataset_size);
 }

+// Gets the number of classes
+Status DatasetOp::GetNumClasses(int64_t *num_classes) {
+  if (num_classes_ > 0) {
+    *num_classes = num_classes_;
+    return Status::OK();
+  }
+  if (!child_.empty()) {
+    return child_[0]->GetNumClasses(num_classes);
+  } else {
+    RETURN_STATUS_UNEXPECTED("Can't get the number of classes for the current tree.");
+  }
+}
+
 // Performs handling for when an eoe message is received.
 // The base class implementation simply flows the eoe message to output. Derived classes
 // may override if they need to perform special eoe handling.
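The base-class override above is a cached, recursive lookup: an op answers from num_classes_ if it already knows, otherwise asks its first child, and a leaf with no class notion reports an error (which Dataset::GetNumClasses() surfaces as -1, as the Album and Cifar10 tests further down expect). The standalone sketch below illustrates that traversal with hypothetical types; it is not MindSpore code.

// Standalone sketch of the lookup pattern (hypothetical types, not MindSpore classes):
// a non-leaf node answers from its cache if set, otherwise delegates to its child;
// a leaf without class information returns "no answer", shown here as std::nullopt.
#include <iostream>
#include <memory>
#include <optional>

struct Node {
  std::shared_ptr<Node> child;        // a single-child chain is enough for the sketch
  long long cached_num_classes = -1;  // mirrors num_classes_ = -1

  virtual ~Node() = default;

  virtual std::optional<long long> NumClasses() {
    if (cached_num_classes > 0) return cached_num_classes;  // cached answer
    if (child) return child->NumClasses();                  // delegate down the tree
    return std::nullopt;                                    // leaf with no class info
  }
};

struct ImageFolderLikeLeaf : Node {
  std::optional<long long> NumClasses() override {
    if (cached_num_classes > 0) return cached_num_classes;
    cached_num_classes = 4;  // pretend we scanned 4 class folders
    return cached_num_classes;
  }
};

int main() {
  // Batch-like -> Repeat-like -> ImageFolder-like leaf: the question travels to the leaf.
  auto leaf = std::make_shared<ImageFolderLikeLeaf>();
  auto repeat = std::make_shared<Node>();
  repeat->child = leaf;
  auto batch = std::make_shared<Node>();
  batch->child = repeat;
  std::cout << batch->NumClasses().value_or(-1) << std::endl;  // 4

  // A leaf with no class notion (e.g. Album or Cifar10 in the tests) -> -1.
  auto plain = std::make_shared<Node>();
  std::cout << plain->NumClasses().value_or(-1) << std::endl;  // -1
  return 0;
}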
@@ -191,6 +191,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Status - The status code return
   virtual int64_t GetTreeRepeatCount();

+  /// \brief Gets the number of classes
+  /// \return Status - The status code return
+  virtual Status GetNumClasses(int64_t *num_classes);
+
   /// \brief Performs handling for when an eoe message is received.
   /// The base class implementation simply flows the eoe message to output. Derived classes
   /// may override if they need to perform special eoe handling.
@@ -419,6 +423,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   std::mutex column_name_map_mutex_;  // For protecting shared access to the column map
   CallbackManager callback_manager_;  // Manages callbacks associated with a DatasetOp
   int64_t dataset_size_;              // Size of the dataset
+  int64_t num_classes_;               // Number of classes

  private:
   /// Sets the operator id.
@@ -468,5 +468,17 @@ Status ImageFolderOp::GetDatasetSize(int64_t *dataset_size) {
   dataset_size_ = *dataset_size;
   return Status::OK();
 }
+
+// Get number of classes
+Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
+  if (num_classes_ > 0) {
+    *num_classes = num_classes_;
+    return Status::OK();
+  }
+  int64_t num_rows = num_rows_;
+  RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, extensions_, &num_rows, num_classes));
+  num_classes_ = *num_classes;
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
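For ImageFolder, the count comes from CountRowsAndClasses and is then cached in num_classes_, so repeated calls (as the getter test exercises) do not re-scan the directory. As an illustration only of the idea behind that helper, the sketch below counts "classes" as the immediate subdirectories of a dataset root, which is the usual ImageFolder convention; it is not the MindSpore implementation, and the path is hypothetical.

// Illustrative sketch: one immediate subfolder == one class.
#include <cstdint>
#include <filesystem>
#include <iostream>

namespace fs = std::filesystem;

int64_t CountClassFolders(const fs::path &dataset_dir) {
  if (!fs::is_directory(dataset_dir)) return -1;  // mirror the -1 sentinel
  int64_t num_classes = 0;
  for (const auto &entry : fs::directory_iterator(dataset_dir)) {
    if (entry.is_directory()) ++num_classes;  // count class folders only
  }
  return num_classes;
}

int main() {
  // Hypothetical path; the ImageFolder getter test in this PR expects 4 classes.
  std::cout << CountClassFolders("./data/dataset/testPK/data") << std::endl;
  return 0;
}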
@@ -222,6 +222,11 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   /// \return Status of the function
   Status GetDatasetSize(int64_t *dataset_size) override;

+  /// \brief Base-class override for GetNumClasses
+  /// \param[out] num_classes the number of classes
+  /// \return Status of the function
+  Status GetNumClasses(int64_t *num_classes) override;
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <fstream>
 #include <iomanip>
+#include <set>
 #include <nlohmann/json.hpp>

 #include "utils/ms_utils.h"
@@ -297,6 +298,7 @@ Status ManifestOp::ParseManifestFile() {
     RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Manifest file: " + file_);
   }
   std::string line;
+  std::set<std::string> classes;
   while (getline(file_handle, line)) {
     try {
       nlohmann::json js = nlohmann::json::parse(line);
@@ -317,6 +319,7 @@ Status ManifestOp::ParseManifestFile() {
       for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) {
         nlohmann::json annotation = it.value();
         std::string label_name = annotation.value("name", "");
+        classes.insert(label_name);
         if (label_name == "") {
           file_handle.close();
           RETURN_STATUS_UNEXPECTED("Invalid data, label name is not found in Manifest file: " + image_file_path);
@@ -336,6 +339,7 @@ Status ManifestOp::ParseManifestFile() {
       RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse manifest file: " + line);
     }
   }
+  num_classes_ = classes.size();
   file_handle.close();

   return Status::OK();
@@ -471,5 +475,18 @@ Status ManifestOp::GetDatasetSize(int64_t *dataset_size) {
   return Status::OK();
 }
+
+// Get number of classes
+Status ManifestOp::GetNumClasses(int64_t *num_classes) {
+  if (num_classes_ > 0) {
+    *num_classes = num_classes_;
+    return Status::OK();
+  }
+  std::shared_ptr<ManifestOp> op;
+  RETURN_IF_NOT_OK(Builder().SetManifestFile(file_).SetClassIndex(class_index_).SetUsage(usage_).Build(&op));
+  RETURN_IF_NOT_OK(op->ParseManifestFile());
+  *num_classes = op->num_classes_;
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
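ManifestOp derives its class count while parsing: every annotation "name" goes into a std::set, and the set size becomes num_classes_, so a label that appears on many manifest lines is counted once. The standalone sketch below mirrors that counting idea over the cpp2.json file added at the end of this diff; it uses nlohmann::json as the op does, but it is not the MindSpore code (no usage filtering, hypothetical file path).

// Sketch: parse one JSON object per line, collect every annotation "name"
// into a std::set, and report the number of distinct labels.
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  std::ifstream manifest("./data/dataset/testManifestData/cpp2.json");  // hypothetical path
  std::set<std::string> classes;
  std::string line;
  while (std::getline(manifest, line)) {
    nlohmann::json js = nlohmann::json::parse(line);
    for (const auto &annotation : js["annotation"]) {
      classes.insert(annotation.value("name", ""));  // duplicates collapse in the set
    }
  }
  // For the cpp2.json added in this PR ("dog", "cat", "flower"), this prints 3.
  std::cout << classes.size() << " distinct classes" << std::endl;
  return 0;
}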
@@ -188,6 +188,11 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   /// \return Status of the function
   Status GetDatasetSize(int64_t *dataset_size) override;

+  /// \brief Base-class override for GetNumClasses
+  /// \param[out] num_classes the number of classes
+  /// \return Status of the function
+  Status GetNumClasses(int64_t *num_classes) override;
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
@@ -589,15 +589,15 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   }

   /// \brief Gets the dataset size
-  /// \return int64_t
+  /// \return dataset size. If failed, return -1
   int64_t GetDatasetSize();

   /// \brief Gets the output type
-  /// \return vector of DataType
+  /// \return a vector of DataType. If failed, return an empty vector
   std::vector<DataType> GetOutputTypes();

   /// \brief Gets the output shape
-  /// \return vector of TensorShapes
+  /// \return a vector of TensorShape. If failed, return an empty vector
   std::vector<TensorShape> GetOutputShapes();

   /// \brief Gets the batch size
@@ -608,6 +608,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \return int64_t
   int64_t GetRepeatCount();

+  /// \brief Gets the number of classes
+  /// \return number of classes. If failed, return -1
+  int64_t GetNumClasses();
+
   /// \brief Setter function for runtime number of workers
   /// \param[in] num_workers The number of threads in this operator
   /// \return Shared pointer to the original object
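With the header change above, the public Dataset API exposes a small set of metadata getters. The usage sketch below is modeled on the new getter tests that follow; the include path, the ImageFolder factory call, its namespace, and the dataset path are assumptions rather than text taken from this diff.

// Usage sketch (assumed header location, namespace, factory signature, and path).
#include <iostream>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset;  // namespace assumed from the code above

int main() {
  std::shared_ptr<Dataset> ds = ImageFolder("./data/dataset/testPK/data");  // hypothetical path
  if (ds == nullptr) return 1;

  std::cout << "rows:    " << ds->GetDatasetSize() << std::endl;  // -1 on failure
  std::cout << "classes: " << ds->GetNumClasses() << std::endl;   // -1 if unknown or unsupported
  std::cout << "batch:   " << ds->GetBatchSize() << std::endl;    // 1 before any Batch op
  std::cout << "repeats: " << ds->GetRepeatCount() << std::endl;  // 1 before any Repeat op

  // The type/shape getters report failure with an empty vector instead of -1.
  for (const auto &t : ds->GetOutputTypes()) std::cout << t.ToString() << std::endl;
  for (const auto &s : ds->GetOutputShapes()) std::cout << s.ToString() << std::endl;
  return 0;
}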
@@ -69,6 +69,26 @@ TEST_F(MindDataTestPipeline, TestAlbumBasic) {
   iter->Stop();
 }

+TEST_F(MindDataTestPipeline, TestAlbumGetters) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumGetters.";
+
+  std::string folder_path = datasets_root_path_ + "/testAlbum/images";
+  std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
+  std::vector<std::string> column_names = {"image", "label", "id"};
+
+  // Create an Album Dataset
+  std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
+  EXPECT_NE(ds, nullptr);
+
+  int64_t dataset_size = ds->GetDatasetSize();
+  EXPECT_EQ(dataset_size, 7);
+
+  int64_t num_classes = ds->GetNumClasses();
+  EXPECT_EQ(num_classes, -1);
+
+  int64_t batch_size = ds->GetBatchSize();
+  EXPECT_EQ(batch_size, 1);
+
+  int64_t repeat_count = ds->GetRepeatCount();
+  EXPECT_EQ(repeat_count, 1);
+}
+
 TEST_F(MindDataTestPipeline, TestAlbumDecode) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestAlbumDecode.";
   std::string folder_path = datasets_root_path_ + "/testAlbum/images";
@@ -86,7 +86,7 @@ TEST_F(MindDataTestPipeline, TestCifar10GetDatasetSize) {
   EXPECT_EQ(ds->GetDatasetSize(), 10000);
 }

-TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
+TEST_F(MindDataTestPipeline, TestCifar10Getters) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10MixGetter.";

   // Create a Cifar10 Dataset
@@ -97,19 +97,28 @@ TEST_F(MindDataTestPipeline, TestCifar10MixGetter) {
   EXPECT_EQ(ds->GetDatasetSize(), 10000);
   std::vector<DataType> types = ds->GetOutputTypes();
   std::vector<TensorShape> shapes = ds->GetOutputShapes();
+  int64_t num_classes = ds->GetNumClasses();
   EXPECT_EQ(types.size(), 2);
   EXPECT_EQ(types[0].ToString(), "uint8");
   EXPECT_EQ(types[1].ToString(), "uint32");
   EXPECT_EQ(shapes.size(), 2);
   EXPECT_EQ(shapes[0].ToString(), "<32,32,3>");
   EXPECT_EQ(shapes[1].ToString(), "<>");
+  EXPECT_EQ(num_classes, -1);
+  EXPECT_EQ(ds->GetBatchSize(), 1);
+  EXPECT_EQ(ds->GetRepeatCount(), 1);

   EXPECT_EQ(ds->GetDatasetSize(), 10000);
   EXPECT_EQ(ds->GetOutputTypes(), types);
   EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetNumClasses(), -1);

   EXPECT_EQ(ds->GetDatasetSize(), 10000);
   EXPECT_EQ(ds->GetOutputTypes(), types);
   EXPECT_EQ(ds->GetOutputShapes(), shapes);
+  EXPECT_EQ(ds->GetBatchSize(), 1);
+  EXPECT_EQ(ds->GetRepeatCount(), 1);
+  EXPECT_EQ(ds->GetNumClasses(), -1);

   EXPECT_EQ(ds->GetDatasetSize(), 10000);
 }
@@ -67,15 +67,22 @@ TEST_F(MindDataTestPipeline, TestManifestBasic) {
   iter->Stop();
 }

-TEST_F(MindDataTestPipeline, TestManifestGetDatasetSize) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetDatasetSize.";
+TEST_F(MindDataTestPipeline, TestManifestGetters) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestManifestGetters.";

-  std::string file_path = datasets_root_path_ + "/testManifestData/cpp.json";
+  std::string file_path1 = datasets_root_path_ + "/testManifestData/cpp.json";
+  std::string file_path2 = datasets_root_path_ + "/testManifestData/cpp2.json";

   // Create a Manifest Dataset
-  std::shared_ptr<Dataset> ds = Manifest(file_path);
-  EXPECT_NE(ds, nullptr);
+  std::shared_ptr<Dataset> ds1 = Manifest(file_path1);
+  std::shared_ptr<Dataset> ds2 = Manifest(file_path2);

-  EXPECT_EQ(ds->GetDatasetSize(), 2);
+  EXPECT_NE(ds1, nullptr);
+  EXPECT_EQ(ds1->GetDatasetSize(), 2);
+  EXPECT_EQ(ds1->GetNumClasses(), 2);
+
+  EXPECT_NE(ds2, nullptr);
+  EXPECT_EQ(ds2->GetDatasetSize(), 4);
+  EXPECT_EQ(ds2->GetNumClasses(), 3);
 }

 TEST_F(MindDataTestPipeline, TestManifestDecode) {
@@ -221,7 +221,7 @@ TEST_F(MindDataTestPipeline, TestImageFolderFailWithWrongExtension) {
   iter->Stop();
 }

-TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
+TEST_F(MindDataTestPipeline, TestImageFolderGetters) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestImageFolderGetDatasetSize.";

   // Create an ImageFolder Dataset
@@ -230,6 +230,10 @@ TEST_F(MindDataTestPipeline, TestImageFolderGetDatasetSize) {
   EXPECT_NE(ds, nullptr);

   EXPECT_EQ(ds->GetDatasetSize(), 44);
+  EXPECT_EQ(ds->GetNumClasses(), 4);
+  EXPECT_EQ(ds->GetNumClasses(), 4);
+  EXPECT_EQ(ds->GetDatasetSize(), 44);
+  EXPECT_EQ(ds->GetDatasetSize(), 44);
 }

 TEST_F(MindDataTestPipeline, TestImageFolderFailWithNullSampler) {
@@ -0,0 +1,6 @@
+{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "dog","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
+{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
+{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
+{"source":"./data/dataset/testManifestData/train/1.JPEG", "usage":"TRAIN","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"},{"type": "modelarts/image_classification","name": "flower","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
+{"source":"./data/dataset/testManifestData/eval/1.JPEG", "usage":"EVAL","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "cat","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}
+{"source":"./data/dataset/testManifestData/eval/2.JPEG", "usage":"EVAL","id":"0162005993f8065ef47eefb59d1e4970","annotation": [{"type": "modelarts/image_classification","name": "dog","property": {"color":"white","kind":"Persian cat"},"hard":"true","hard-coefficient":0.8,"annotated-by":"human","creation-time":"2019-01-23 11:30:30"}],"inference-loc":"/path/to/inference-output"}