Merge pull request !7657 from ZiruiWu/ir_breakdown_leaf_nodestags/v1.1.0
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -4,19 +4,17 @@ add_subdirectory(source) | |||||
| set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | ||||
| batch_node.cc | batch_node.cc | ||||
| bucket_batch_by_length_node.cc | |||||
| build_vocab_node.cc | |||||
| concat_node.cc | concat_node.cc | ||||
| map_node.cc | |||||
| project_node.cc | project_node.cc | ||||
| rename_node.cc | rename_node.cc | ||||
| repeat_node.cc | repeat_node.cc | ||||
| shuffle_node.cc | shuffle_node.cc | ||||
| skip_node.cc | |||||
| take_node.cc | take_node.cc | ||||
| zip_node.cc | |||||
| ) | ) | ||||
| if (NOT ENABLE_ANDROID) | |||||
| set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | |||||
| ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES} | |||||
| bucket_batch_by_length_node.cc | |||||
| build_vocab_node.cc) | |||||
| endif () | |||||
| add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}) | add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}) | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -61,4 +61,4 @@ class BucketBatchByLengthNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -58,4 +58,4 @@ class BuildVocabNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -50,4 +50,4 @@ class ConcatNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_ | |||||
| @@ -0,0 +1,91 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||||
| #include "minddata/dataset/include/transforms.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | |||||
| const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache) | |||||
| : operations_(operations), | |||||
| input_columns_(input_columns), | |||||
| output_columns_(output_columns), | |||||
| project_columns_(project_columns), | |||||
| Dataset(std::move(cache)) { | |||||
| this->children.push_back(child); | |||||
| } | |||||
| std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| std::vector<std::shared_ptr<TensorOp>> tensor_ops; | |||||
| // Build tensorOp from tensorOperation vector | |||||
| // This is to ensure each iterator hold its own copy of the tensorOp objects. | |||||
| (void)std::transform( | |||||
| operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), | |||||
| [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); }); | |||||
| // This parameter will be removed with next rebase | |||||
| std::vector<std::string> col_orders; | |||||
| auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); | |||||
| if (!project_columns_.empty()) { | |||||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | |||||
| node_ops.push_back(project_op); | |||||
| } | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(map_op); | |||||
| return node_ops; | |||||
| } | |||||
| Status MapNode::ValidateParams() { | |||||
| if (operations_.empty()) { | |||||
| std::string err_msg = "MapNode: No operation is specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (!input_columns_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_)); | |||||
| } | |||||
| if (!output_columns_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "output_columns", output_columns_)); | |||||
| } | |||||
| if (!project_columns_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "project_columns", project_columns_)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class MapNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | |||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | |||||
| /// \brief Destructor | |||||
| ~MapNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | |||||
| std::vector<std::string> input_columns_; | |||||
| std::vector<std::string> output_columns_; | |||||
| std::vector<std::string> project_columns_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -51,4 +51,4 @@ class ProjectNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -53,4 +53,4 @@ class RenameNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -53,4 +53,4 @@ class RepeatNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -49,4 +49,4 @@ class ShuffleNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_ | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/skip_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for SkipNode | |||||
| SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) { | |||||
| this->children.push_back(child); | |||||
| } | |||||
| // Function to build the SkipOp | |||||
| std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_)); | |||||
| return node_ops; | |||||
| } | |||||
| // Function to validate the parameters for SkipNode | |||||
| Status SkipNode::ValidateParams() { | |||||
| if (skip_count_ <= -1) { | |||||
| std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class SkipNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| /// \brief Destructor | |||||
| ~SkipNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t skip_count_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_ | |||||
| @@ -2,7 +2,21 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | ||||
| album_node.cc | |||||
| celeba_node.cc | |||||
| cifar100_node.cc | |||||
| cifar10_node.cc | |||||
| clue_node.cc | |||||
| coco_node.cc | |||||
| csv_node.cc | |||||
| image_folder_node.cc | image_folder_node.cc | ||||
| manifest_node.cc | |||||
| minddata_node.cc | |||||
| mnist_node.cc | |||||
| random_node.cc | |||||
| text_file_node.cc | |||||
| tf_record_node.cc | |||||
| voc_node.cc | |||||
| ) | ) | ||||
| if (ENABLE_PYTHON) | if (ENABLE_PYTHON) | ||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/album_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for AlbumNode | |||||
| AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||||
| const std::vector<std::string> &column_names, bool decode, | |||||
| const std::shared_ptr<SamplerObj> &sampler) | |||||
| : dataset_dir_(dataset_dir), | |||||
| schema_path_(data_schema), | |||||
| column_names_(column_names), | |||||
| decode_(decode), | |||||
| sampler_(sampler) {} | |||||
| Status AlbumNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_)); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_})); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumNode", sampler_)); | |||||
| if (!column_names_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumNode", "column_names", column_names_)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Function to build AlbumNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| auto schema = std::make_unique<DataSchema>(); | |||||
| RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_)); | |||||
| // Argument that is not exposed to user in the API. | |||||
| std::set<std::string> extensions = {}; | |||||
| node_ops.push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||||
| decode_, extensions, std::move(schema), std::move(sampler_->Build()))); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class AlbumNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||||
| const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler); | |||||
| /// \brief Destructor | |||||
| ~AlbumNode() = default; | |||||
| /// \brief a base class override function to create a runtime dataset op object from this class | |||||
| /// \return shared pointer to the newly created DatasetOp | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string schema_path_; | |||||
| std::vector<std::string> column_names_; | |||||
| bool decode_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_ | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for CelebANode | |||||
| CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage, | |||||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | |||||
| const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) | |||||
| : Dataset(cache), | |||||
| dataset_dir_(dataset_dir), | |||||
| usage_(usage), | |||||
| sampler_(sampler), | |||||
| decode_(decode), | |||||
| extensions_(extensions) {} | |||||
| Status CelebANode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_)); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_)); | |||||
| RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"})); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Function to build CelebANode | |||||
| std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||||
| // label is like this:0 1 0 0 1...... | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||||
| decode_, usage_, extensions_, std::move(schema), | |||||
| std::move(sampler_->Build()))); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_ | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class CelebANode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||||
| const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache); | |||||
| /// \brief Destructor | |||||
| ~CelebANode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string usage_; | |||||
| bool decode_; | |||||
| std::set<std::string> extensions_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_ | |||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for Cifar100Node | |||||
| Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage, | |||||
| std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
// Validate the user-supplied parameters before any runtime op is built.
// Returns a syntax-error Status on the first violation encountered.
Status Cifar100Node::ValidateParams() {
  // Dataset directory must exist and be readable.
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
  // A sampler must be present (callers supply a default when the user gave none).
  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));
  // Only the three canonical CIFAR splits are accepted.
  RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"}));
  return Status::OK();
}
| // Function to build CifarOp for Cifar100 | |||||
| std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| // Do internal Schema generation. | |||||
| auto schema = std::make_unique<DataSchema>(); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); | |||||
| TensorShape scalar = TensorShape::CreateScalar(); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | |||||
| dataset_dir_, connector_que_size_, std::move(schema), | |||||
| std::move(sampler_->Build()))); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class Cifar100Node
/// \brief IR node representing the Cifar100 dataset.
class Cifar100Node : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_dir Path to the root directory containing the CIFAR-100 data files
  /// \param[in] usage Which split to read: "train", "test" or "all" (checked by ValidateParams)
  /// \param[in] sampler Sampler object used to choose samples from the dataset
  /// \param[in] cache Optional tensor cache (may be nullptr)
  Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
               std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~Cifar100Node() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;              // root directory of the dataset files
  std::string usage_;                    // "train", "test" or "all"
  std::shared_ptr<SamplerObj> sampler_;  // sampler used to pick rows
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_ | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for Cifar10Node | |||||
| Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
// Validate the user-supplied parameters before any runtime op is built.
// Returns a syntax-error Status on the first violation encountered.
Status Cifar10Node::ValidateParams() {
  // Dataset directory must exist and be readable.
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
  // A sampler must be present (callers supply a default when the user gave none).
  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));
  // Only the three canonical CIFAR splits are accepted.
  RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"}));
  return Status::OK();
}
| // Function to build CifarOp for Cifar10 | |||||
| std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| // Do internal Schema generation. | |||||
| auto schema = std::make_unique<DataSchema>(); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); | |||||
| TensorShape scalar = TensorShape::CreateScalar(); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | |||||
| dataset_dir_, connector_que_size_, std::move(schema), | |||||
| std::move(sampler_->Build()))); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class Cifar10Node
/// \brief IR node representing the Cifar10 dataset.
class Cifar10Node : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_dir Path to the root directory containing the CIFAR-10 data files
  /// \param[in] usage Which split to read: "train", "test" or "all" (checked by ValidateParams)
  /// \param[in] sampler Sampler object used to choose samples from the dataset
  /// \param[in] cache Optional tensor cache (may be nullptr)
  Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
              std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~Cifar10Node() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;              // root directory of the dataset files
  std::string usage_;                    // "train", "test" or "all"
  std::shared_ptr<SamplerObj> sampler_;  // sampler used to pick rows
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_ | |||||
| @@ -0,0 +1,218 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for CLUENode | |||||
| CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples, | |||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(clue_files), | |||||
| task_(task), | |||||
| usage_(usage), | |||||
| num_samples_(num_samples), | |||||
| shuffle_(shuffle), | |||||
| num_shards_(num_shards), | |||||
| shard_id_(shard_id) {} | |||||
// Validate the user-supplied parameters; returns a syntax-error Status on the first violation.
Status CLUENode::ValidateParams() {
  // Every input file must exist and be readable.
  RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));

  // Only the six CLUE benchmark tasks are supported.
  RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));

  RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"}));

  // A negative sample count is rejected up front.
  if (num_samples_ < 0) {
    std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUENode", num_shards_, shard_id_));

  return Status::OK();
}
| // Function to split string based on a character delimiter | |||||
| std::vector<std::string> CLUENode::split(const std::string &s, char delim) { | |||||
| std::vector<std::string> res; | |||||
| std::stringstream ss(s); | |||||
| std::string item; | |||||
| while (getline(ss, item, delim)) { | |||||
| res.push_back(item); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| // Function to build CLUENode | |||||
| std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| std::map<std::string, std::string> key_map; | |||||
| if (task_ == "AFQMC") { | |||||
| if (usage_ == "train") { | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| key_map["label"] = "label"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| key_map["label"] = "label"; | |||||
| } | |||||
| } else if (task_ == "CMNLI") { | |||||
| if (usage_ == "train") { | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| key_map["label"] = "label"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["sentence1"] = "sentence1"; | |||||
| key_map["sentence2"] = "sentence2"; | |||||
| key_map["label"] = "label"; | |||||
| } | |||||
| } else if (task_ == "CSL") { | |||||
| if (usage_ == "train") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["abst"] = "abst"; | |||||
| key_map["keyword"] = "keyword"; | |||||
| key_map["label"] = "label"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["abst"] = "abst"; | |||||
| key_map["keyword"] = "keyword"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["abst"] = "abst"; | |||||
| key_map["keyword"] = "keyword"; | |||||
| key_map["label"] = "label"; | |||||
| } | |||||
| } else if (task_ == "IFLYTEK") { | |||||
| if (usage_ == "train") { | |||||
| key_map["label"] = "label"; | |||||
| key_map["label_des"] = "label_des"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["label"] = "label"; | |||||
| key_map["label_des"] = "label_des"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| } | |||||
| } else if (task_ == "TNEWS") { | |||||
| if (usage_ == "train") { | |||||
| key_map["label"] = "label"; | |||||
| key_map["label_desc"] = "label_desc"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| key_map["keywords"] = "keywords"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["id"] = "id"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| key_map["keywords"] = "keywords"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["label"] = "label"; | |||||
| key_map["label_desc"] = "label_desc"; | |||||
| key_map["sentence"] = "sentence"; | |||||
| key_map["keywords"] = "keywords"; | |||||
| } | |||||
| } else if (task_ == "WSC") { | |||||
| if (usage_ == "train") { | |||||
| key_map["span1_index"] = "target/span1_index"; | |||||
| key_map["span2_index"] = "target/span2_index"; | |||||
| key_map["span1_text"] = "target/span1_text"; | |||||
| key_map["span2_text"] = "target/span2_text"; | |||||
| key_map["idx"] = "idx"; | |||||
| key_map["label"] = "label"; | |||||
| key_map["text"] = "text"; | |||||
| } else if (usage_ == "test") { | |||||
| key_map["span1_index"] = "target/span1_index"; | |||||
| key_map["span2_index"] = "target/span2_index"; | |||||
| key_map["span1_text"] = "target/span1_text"; | |||||
| key_map["span2_text"] = "target/span2_text"; | |||||
| key_map["idx"] = "idx"; | |||||
| key_map["text"] = "text"; | |||||
| } else if (usage_ == "eval") { | |||||
| key_map["span1_index"] = "target/span1_index"; | |||||
| key_map["span2_index"] = "target/span2_index"; | |||||
| key_map["span1_text"] = "target/span1_text"; | |||||
| key_map["span2_text"] = "target/span2_text"; | |||||
| key_map["idx"] = "idx"; | |||||
| key_map["label"] = "label"; | |||||
| key_map["text"] = "text"; | |||||
| } | |||||
| } | |||||
| ColKeyMap ck_map; | |||||
| for (auto &p : key_map) { | |||||
| ck_map.insert({p.first, split(p.second, '/')}); | |||||
| } | |||||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||||
| // Sort the dataset files in a lexicographical order | |||||
| std::vector<std::string> sorted_dataset_files = dataset_files_; | |||||
| std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); | |||||
| std::shared_ptr<ClueOp> clue_op = | |||||
| std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, | |||||
| sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); | |||||
| RETURN_EMPTY_IF_ERROR(clue_op->Init()); | |||||
| if (shuffle_ == ShuffleMode::kGlobal) { | |||||
| // Inject ShuffleOp | |||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||||
| int64_t num_rows = 0; | |||||
| // First, get the number of rows in the dataset | |||||
| RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows)); | |||||
| // Add the shuffle op after this op | |||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| node_ops.push_back(shuffle_op); | |||||
| } | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(clue_op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class CLUENode
/// \brief A Dataset derived class to represent CLUE dataset
class CLUENode : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_files List of CLUE json-lines files to read
  /// \param[in] task CLUE task name: "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC" or "CSL"
  /// \param[in] usage Split to read: "train", "test" or "eval"
  /// \param[in] num_samples Number of samples to read (validated non-negative)
  /// \param[in] shuffle Shuffle mode (global / files / none — confirm against ShuffleMode enum)
  /// \param[in] num_shards Number of shards the dataset is divided into
  /// \param[in] shard_id Which shard this consumer reads
  /// \param[in] cache Optional tensor cache (may be nullptr)
  CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
           ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~CLUENode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  /// \brief Split string based on a character delimiter
  /// \return A string vector
  std::vector<std::string> split(const std::string &s, char delim);

  std::vector<std::string> dataset_files_;  // input file list
  std::string task_;                        // CLUE task name
  std::string usage_;                       // "train", "test" or "eval"
  int64_t num_samples_;                     // number of samples to read
  ShuffleMode shuffle_;                     // shuffle mode for files/rows
  int32_t num_shards_;                      // total shard count
  int32_t shard_id_;                        // this consumer's shard index
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ | |||||
| @@ -0,0 +1,122 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for CocoNode | |||||
// Constructor for CocoNode.
// cache is taken by value and moved into the base; sampler is a const reference and must
// be copied into the member (one shared_ptr refcount bump).
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
                   const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      dataset_dir_(dataset_dir),
      annotation_file_(annotation_file),
      task_(task),
      decode_(decode),
      sampler_(sampler) {}
| Status CocoNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_)); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_)); | |||||
| Path annotation_file(annotation_file_); | |||||
| if (!annotation_file.Exists()) { | |||||
| std::string err_msg = "CocoNode: annotation_file is invalid or does not exist."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"})); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Function to build CocoNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| CocoOp::TaskType task_type; | |||||
| if (task_ == "Detection") { | |||||
| task_type = CocoOp::TaskType::Detection; | |||||
| } else if (task_ == "Stuff") { | |||||
| task_type = CocoOp::TaskType::Stuff; | |||||
| } else if (task_ == "Keypoint") { | |||||
| task_type = CocoOp::TaskType::Keypoint; | |||||
| } else if (task_ == "Panoptic") { | |||||
| task_type = CocoOp::TaskType::Panoptic; | |||||
| } | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||||
| switch (task_type) { | |||||
| case CocoOp::TaskType::Detection: | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| break; | |||||
| case CocoOp::TaskType::Stuff: | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| break; | |||||
| case CocoOp::TaskType::Keypoint: | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| break; | |||||
| case CocoOp::TaskType::Panoptic: | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR( | |||||
| schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "CocoNode::Build : Invalid task type"; | |||||
| return {}; | |||||
| } | |||||
| std::shared_ptr<CocoOp> op = | |||||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | |||||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class CocoNode
/// \brief IR node representing the COCO dataset.
class CocoNode : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_dir Path to the root directory containing the COCO images
  /// \param[in] annotation_file Path to the COCO annotation json file (must exist)
  /// \param[in] task Annotation task: "Detection", "Stuff", "Panoptic" or "Keypoint"
  /// \param[in] decode Whether to decode the images after reading
  /// \param[in] sampler Sampler object used to choose samples from the dataset
  /// \param[in] cache Optional tensor cache (may be nullptr)
  CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
           const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~CocoNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;              // root directory of the images
  std::string annotation_file_;          // path to the annotation json
  std::string task_;                     // annotation task selector
  bool decode_;                          // decode images on read when true
  std::shared_ptr<SamplerObj> sampler_;  // sampler used to pick rows
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_ | |||||
| @@ -0,0 +1,127 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for CSVNode | |||||
// Constructor for CSVNode.
// cache is taken by value and moved into the base; the remaining heavyweight arguments
// arrive by const reference and are copied into the members.
CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
                 const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
                 const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
                 int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      dataset_files_(csv_files),
      field_delim_(field_delim),
      column_defaults_(column_defaults),
      column_names_(column_names),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {}
| Status CSVNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_)); | |||||
| if (field_delim_ == '"' || field_delim_ == '\r' || field_delim_ == '\n') { | |||||
| std::string err_msg = "CSVNode: The field delimiter should not be \", \\r, \\n"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (num_samples_ < 0) { | |||||
| std::string err_msg = "CSVNode: Invalid number of samples: " + std::to_string(num_samples_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVNode", num_shards_, shard_id_)); | |||||
| if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) { | |||||
| std::string err_msg = "CSVNode: column_default should not be null."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (!column_names_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVNode", "column_names", column_names_)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Function to build CSVNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||||
| // Sort the dataset files in a lexicographical order | |||||
| std::vector<std::string> sorted_dataset_files = dataset_files_; | |||||
| std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); | |||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list; | |||||
| for (auto v : column_defaults_) { | |||||
| if (v->type == CsvType::INT) { | |||||
| column_default_list.push_back( | |||||
| std::make_shared<CsvOp::Record<int>>(CsvOp::INT, std::dynamic_pointer_cast<CsvRecord<int>>(v)->value)); | |||||
| } else if (v->type == CsvType::FLOAT) { | |||||
| column_default_list.push_back( | |||||
| std::make_shared<CsvOp::Record<float>>(CsvOp::FLOAT, std::dynamic_pointer_cast<CsvRecord<float>>(v)->value)); | |||||
| } else if (v->type == CsvType::STRING) { | |||||
| column_default_list.push_back(std::make_shared<CsvOp::Record<std::string>>( | |||||
| CsvOp::STRING, std::dynamic_pointer_cast<CsvRecord<std::string>>(v)->value)); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | |||||
| sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, | |||||
| num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr); | |||||
| RETURN_EMPTY_IF_ERROR(csv_op->Init()); | |||||
| if (shuffle_ == ShuffleMode::kGlobal) { | |||||
| // Inject ShuffleOp | |||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||||
| int64_t num_rows = 0; | |||||
| // First, get the number of rows in the dataset | |||||
| RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows)); | |||||
| // Add the shuffle op after this op | |||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| node_ops.push_back(shuffle_op); | |||||
| } | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(csv_op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,82 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { /// \brief Base class of CSV Record | |||||
/// \brief Record type tag for CSV column defaults.  uint8_t keeps the tag compact.
enum CsvType : uint8_t { INT = 0, FLOAT, STRING };

/// \brief Base class of a CSV record; carries only the runtime type tag.
class CsvBase {
 public:
  CsvBase() = default;

  /// \brief Construct a record tagged with type \p t.
  explicit CsvBase(CsvType t) : type(t) {}

  /// \brief Virtual destructor so a CsvRecord<T> can be deleted through CsvBase*.
  virtual ~CsvBase() = default;

  // Value-initialized so a default-constructed record never exposes an
  // uninitialized (indeterminate) tag.
  CsvType type{INT};
};

/// \brief CSV Record that can represent integer, float and string.
template <typename T>
class CsvRecord : public CsvBase {
 public:
  CsvRecord() = default;

  /// \brief Construct a record holding \p v, tagged with type \p t.
  CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {}

  ~CsvRecord() override = default;

  // Value-initialized for the default-constructed case.
  T value{};
};
/// \brief IR leaf node representing a dataset read from CSV files.
class CSVNode : public Dataset {
 public:
  /// \brief Constructor
  CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
          const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names,
          int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
          std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~CSVNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> dataset_files_;                 // paths of the CSV files to read
  char field_delim_;                                       // single-character field separator
  std::vector<std::shared_ptr<CsvBase>> column_defaults_;  // per-column default value/type records
  std::vector<std::string> column_names_;                  // column names; empty means read from file header
  int64_t num_samples_;                                    // number of rows to read; 0 reads all
  ShuffleMode shuffle_;                                    // kFalse / kFiles / kGlobal
  int32_t num_shards_;                                     // total shard count for distributed reads
  int32_t shard_id_;                                       // this reader's shard index
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_ | |||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
// Constructor: records the manifest file path, usage split, sampler, the
// class-name -> label-index overrides and the decode flag; the optional cache
// is moved into the Dataset base class.  Validation is deferred to
// ValidateParams().
ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &usage,
                           const std::shared_ptr<SamplerObj> &sampler,
                           const std::map<std::string, int32_t> &class_indexing, bool decode,
                           std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      dataset_file_(dataset_file),
      usage_(usage),
      decode_(decode),
      class_index_(class_indexing),
      sampler_(sampler) {}
| Status ManifestNode::ValidateParams() { | |||||
| std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'}; | |||||
| for (char c : dataset_file_) { | |||||
| auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c); | |||||
| if (p != forbidden_symbols.end()) { | |||||
| std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| Path manifest_file(dataset_file_); | |||||
| if (!manifest_file.Exists()) { | |||||
| std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_)); | |||||
| RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"})); | |||||
| return Status::OK(); | |||||
| } | |||||
// Creates the runtime ManifestOp (and optional cache op) for this node.
// Returns an empty vector if schema construction fails.
std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // Do internal Schema generation.
  // The output schema is fixed: a rank-1 uint8 "image" column and a scalar
  // uint32 "label" column.
  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  std::shared_ptr<ManifestOp> manifest_op;
  manifest_op =
    std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
                                 class_index_, std::move(schema), std::move(sampler_->Build()), usage_);

  // The cache op (if any) is appended before the manifest op itself, mirroring
  // the other leaf IR nodes in this file.
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(manifest_op);

  return node_ops;
}
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \brief IR leaf node representing a dataset described by a manifest file.
class ManifestNode : public Dataset {
 public:
  /// \brief Constructor
  ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
               const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~ManifestNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_file_;                   // path to the manifest file
  std::string usage_;                          // "train", "eval" or "inference" (see ValidateParams)
  bool decode_;                                // whether to decode images when reading
  std::map<std::string, int32_t> class_index_; // optional class-name -> label-index overrides
  std::shared_ptr<SamplerObj> sampler_;        // sampler used to select rows
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MANIFEST_NODE_H_ | |||||
| @@ -0,0 +1,165 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <stack> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
// Constructor for the explicit-file-list form: dataset_files names the
// mindrecord files directly, so search_for_pattern_ is false and the
// single-path member dataset_file_ stays empty.
MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
                           const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
    : dataset_file_(std::string()),
      dataset_files_(dataset_files),
      search_for_pattern_(false),
      columns_list_(columns_list),
      sampler_(sampler),
      padded_sample_(padded_sample),
      sample_bytes_({}),
      num_padded_(num_padded) {}
// Constructor for the single-path form: dataset_file is treated as a pattern
// used to search for matching mindrecord files, so search_for_pattern_ is true
// and the file-list member dataset_files_ stays empty.
MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
                           const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
    : dataset_file_(dataset_file),
      dataset_files_({}),
      search_for_pattern_(true),
      columns_list_(columns_list),
      sampler_(sampler),
      padded_sample_(padded_sample),
      sample_bytes_({}),
      num_padded_(num_padded) {}
| Status MindDataNode::ValidateParams() { | |||||
| if (!search_for_pattern_ && dataset_files_.size() > 4096) { | |||||
| std::string err_msg = | |||||
| "MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " + | |||||
| std::to_string(dataset_file_.size()); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| std::vector<std::string> dataset_file_vec = | |||||
| search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_; | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec)); | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", sampler_)); | |||||
| if (!columns_list_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_)); | |||||
| } | |||||
| if (padded_sample_ != nullptr) { | |||||
| if (num_padded_ < 0) { | |||||
| std::string err_msg = | |||||
| "MindDataNode: num_padded must be greater than or equal to zero, num_padded: " + std::to_string(num_padded_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (columns_list_.empty()) { | |||||
| std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| for (std::string &column : columns_list_) { | |||||
| if (padded_sample_.find(column) == padded_sample_.end()) { | |||||
| std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample"; | |||||
| MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (num_padded_ > 0) { | |||||
| if (padded_sample_ == nullptr) { | |||||
| std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
// Helper function to create runtime sampler for minddata dataset
// Converts the user-facing SamplerObj chain into the mindrecord ShardOperator
// chain consumed by MindRecordOp.  The sampler chain is walked via GetChildOp()
// (parent to deepest child), each op pushed onto a stack, then popped into
// *operators_ -- so the output lists the deepest child first.
Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
                                                  std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
                                                  int64_t num_padded) {
  std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
  // A null result means this sampler type has no mindrecord equivalent.
  if (op == nullptr) {
    std::string err_msg =
      "MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
      "SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
  while (op != nullptr) {
    // Only a ShardDistributedSample needs to know about padded samples.
    auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
    if (sampler_op && num_padded > 0) {
      sampler_op->SetNumPaddedSamples(num_padded);
      stack_ops.push(sampler_op);
    } else {
      stack_ops.push(op);
    }
    op = op->GetChildOp();
  }
  // Drain the stack, reversing the traversal order into *operators_.
  while (!stack_ops.empty()) {
    operators_->push_back(stack_ops.top());
    stack_ops.pop();
  }
  return Status::OK();
}
// Helper function to set sample_bytes from py::byte type
// Copies the caller's column-name -> raw-bytes map into sample_bytes_
// (raw byte values cannot travel through the nlohmann::json padded sample).
void MindDataNode::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) { sample_bytes_ = *sample_bytes; }
| std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| std::vector<std::shared_ptr<ShardOperator>> operators_; | |||||
| RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); | |||||
| std::shared_ptr<MindRecordOp> mindrecord_op; | |||||
| // If pass a string to MindData(), it will be treated as a pattern to search for matched files, | |||||
| // else if pass a vector to MindData(), it will be treated as specified files to be read | |||||
| if (search_for_pattern_) { | |||||
| std::vector<std::string> dataset_file_vec_ = {dataset_file_}; | |||||
| mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_file_vec_, | |||||
| search_for_pattern_, connector_que_size_, columns_list_, operators_, | |||||
| num_padded_, padded_sample_, sample_bytes_); | |||||
| } else { | |||||
| mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_files_, search_for_pattern_, | |||||
| connector_que_size_, columns_list_, operators_, num_padded_, | |||||
| padded_sample_, sample_bytes_); | |||||
| } | |||||
| RETURN_EMPTY_IF_ERROR(mindrecord_op->Init()); | |||||
| node_ops.push_back(mindrecord_op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \brief IR leaf node representing a MindRecord ("MindData") dataset.
class MindDataNode : public Dataset {
 public:
  /// \brief Constructor
  /// \note This overload takes an explicit list of mindrecord files.
  MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
               const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);

  /// \brief Constructor
  /// \note This overload takes a single path used as a search pattern for matching files.
  MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
               const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);

  /// \brief Destructor
  ~MindDataNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

  /// \brief Build sampler chain for minddata dataset
  /// \note operators_ is an output parameter; it receives the ShardOperator chain.
  /// \return Status Status::OK() if input sampler is valid
  Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
                                      std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
                                      int64_t num_padded);

  /// \brief Set sample_bytes when padded_sample has py::byte value
  /// \note Pybind will use this function to set sample_bytes into MindDataNode
  void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);

 private:
  std::string dataset_file_;                // search_for_pattern_ will be true in this mode
  std::vector<std::string> dataset_files_;  // search_for_pattern_ will be false in this mode
  bool search_for_pattern_;                 // selects which of the two members above is in use
  std::vector<std::string> columns_list_;   // columns to load; empty loads all
  std::shared_ptr<SamplerObj> sampler_;     // sampler used to select rows
  nlohmann::json padded_sample_;            // optional sample used to pad the dataset
  std::map<std::string, std::string> sample_bytes_;  // enable in python
  int64_t num_padded_;                      // number of padded samples to append
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MINDDATA_NODE_H_ | |||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||||
// Validates the dataset directory, the sampler, and that usage is one of
// "train", "test" or "all".  Returns Status::OK() when everything is valid.
Status MnistNode::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"}));

  return Status::OK();
}
// Creates the runtime MnistOp (and optional cache op) for this node.
// Returns an empty vector if schema construction fails.
std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // Do internal Schema generation.
  // Fixed output schema: a rank-1 uint8 "image" column and a scalar uint32
  // "label" column.
  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  // Cache op (if any) is appended before the MnistOp, as in the other leaf nodes.
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
                                               connector_que_size_, std::move(schema), std::move(sampler_->Build())));

  return node_ops;
}
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \brief IR leaf node representing the MNIST dataset.
class MnistNode : public Dataset {
 public:
  /// \brief Constructor
  MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
            std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~MnistNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;              // directory containing the MNIST files
  std::string usage_;                    // "train", "test" or "all" (see ValidateParams)
  std::shared_ptr<SamplerObj> sampler_;  // sampler used to select rows
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_MNIST_NODE_H_ | |||||
| @@ -0,0 +1,104 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // ValidateParams for RandomNode | |||||
| Status RandomNode::ValidateParams() { | |||||
| if (total_rows_ < 0) { | |||||
| std::string err_msg = | |||||
| "RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomNode", sampler_)); | |||||
| if (!columns_list_.empty()) { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| int32_t RandomNode::GenRandomInt(int32_t min, int32_t max) { | |||||
| std::uniform_int_distribution<int32_t> uniDist(min, max); | |||||
| return uniDist(rand_gen_); | |||||
| } | |||||
| // Build for RandomNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| rand_gen_.seed(GetSeed()); // seed the random generator | |||||
| // If total rows was not given, then randomly pick a number | |||||
| std::shared_ptr<SchemaObj> schema_obj; | |||||
| if (!schema_path_.empty()) { | |||||
| schema_obj = Schema(schema_path_); | |||||
| if (schema_obj == nullptr) { | |||||
| return {}; | |||||
| } | |||||
| } | |||||
| std::string schema_json_string, schema_file_path; | |||||
| if (schema_ != nullptr) { | |||||
| schema_->set_dataset_type("Random"); | |||||
| if (total_rows_ != 0) { | |||||
| schema_->set_num_rows(total_rows_); | |||||
| } | |||||
| schema_json_string = schema_->to_json(); | |||||
| } else { | |||||
| schema_file_path = schema_path_; | |||||
| } | |||||
| std::unique_ptr<DataSchema> data_schema; | |||||
| std::vector<std::string> columns_to_load; | |||||
| if (columns_list_.size() > 0) { | |||||
| columns_to_load = columns_list_; | |||||
| } | |||||
| if (!schema_file_path.empty() || !schema_json_string.empty()) { | |||||
| data_schema = std::make_unique<DataSchema>(); | |||||
| if (!schema_file_path.empty()) { | |||||
| data_schema->LoadSchemaFile(schema_file_path, columns_to_load); | |||||
| } else if (!schema_json_string.empty()) { | |||||
| data_schema->LoadSchemaString(schema_json_string, columns_to_load); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<RandomDataOp> op; | |||||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | |||||
| std::move(data_schema), std::move(sampler_->Build())); | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class RandomNode
/// \brief A Dataset derived class representing a source of generated data.
class RandomNode : public Dataset {
 public:
  // Some constants to provide limits to random generation.
  static constexpr int32_t kMaxNumColumns = 4;
  static constexpr int32_t kMaxRank = 4;
  static constexpr int32_t kMaxDimValue = 32;

  /// \brief Constructor
  /// \note This overload takes an in-memory schema object; schema_path_ is left empty.
  // NOTE(review): 'sampler' is taken by const&, so the std::move in the init list actually
  // copies rather than moves — consider taking the shared_ptr by value if a move is intended.
  RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
             const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
      : Dataset(std::move(cache)),
        total_rows_(total_rows),
        schema_path_(""),
        schema_(std::move(schema)),
        columns_list_(columns_list),
        sampler_(std::move(sampler)) {}

  /// \brief Constructor
  /// \note This overload takes a path to a schema file; schema_ stays default (nullptr).
  RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
             const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
      : Dataset(std::move(cache)),
        total_rows_(total_rows),
        schema_path_(schema_path),
        columns_list_(columns_list),
        sampler_(std::move(sampler)) {}

  /// \brief Destructor
  ~RandomNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  /// \brief A quick inline for producing a random number between (and including) min/max
  /// \param[in] min minimum number that can be generated.
  /// \param[in] max maximum number that can be generated.
  /// \return The generated random number
  int32_t GenRandomInt(int32_t min, int32_t max);

  int32_t total_rows_;                     // requested number of rows; ValidateParams requires >= 0
  std::string schema_path_;                // path to a schema file (empty when schema_ is set)
  std::shared_ptr<SchemaObj> schema_;      // in-memory schema (nullptr when constructed from a path)
  std::vector<std::string> columns_list_;  // columns to load from the schema
  std::shared_ptr<SamplerObj> sampler_;    // sampler used when building the runtime op
  std::mt19937 rand_gen_;                  // random engine used by GenRandomInt
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RANDOM_NODE_H_ | |||||
| @@ -0,0 +1,100 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for TextFileNode | |||||
| TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, | |||||
| int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(dataset_files), | |||||
| num_samples_(num_samples), | |||||
| shuffle_(shuffle), | |||||
| num_shards_(num_shards), | |||||
| shard_id_(shard_id) {} | |||||
| Status TextFileNode::ValidateParams() { | |||||
| RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_)); | |||||
| if (num_samples_ < 0) { | |||||
| std::string err_msg = "TextFileNode: Invalid number of samples: " + std::to_string(num_samples_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileNode", num_shards_, shard_id_)); | |||||
| return Status::OK(); | |||||
| } | |||||
// Function to build TextFileNode: creates the TextFileOp, and optionally a
// ShuffleOp and cache op in front of it. Returns an empty vector on failure.
std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // File-level shuffling applies for both kGlobal and kFiles modes.
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

  // Sort the dataset files in a lexicographical order
  std::vector<std::string> sorted_dataset_files = dataset_files_;
  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());

  // Do internal Schema generation: the op exposes a single flexible "text" column.
  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));

  // Create and initialize TextFileOp
  std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
    connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
  RETURN_EMPTY_IF_ERROR(text_file_op->Init());

  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));

    // Add the shuffle op after this op
    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                       rows_per_buffer_, &shuffle_op));
    node_ops.push_back(shuffle_op);
  }
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  // Add TextFileOp
  node_ops.push_back(text_file_op);
  return node_ops;
}
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class TextFileNode
/// \brief A Dataset derived class to represent TextFile dataset
class TextFileNode : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_files List of text files to read.
  /// \param[in] num_samples Number of samples to read (ValidateParams requires >= 0).
  /// \param[in] shuffle Shuffle mode for files/rows.
  /// \param[in] num_shards Number of shards (validated against shard_id).
  /// \param[in] shard_id Id of the shard this node reads.
  /// \param[in] cache Dataset cache, forwarded to the Dataset base class.
  TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
               int32_t shard_id, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~TextFileNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> dataset_files_;  // text files to read
  int32_t num_samples_;                     // number of samples to read
  int32_t num_shards_;                      // number of shards
  int32_t shard_id_;                        // id of the shard to read
  ShuffleMode shuffle_;                     // requested shuffle mode
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TEXT_FILE_NODE_H_ | |||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Validator for TFRecordNode | |||||
| Status TFRecordNode::ValidateParams() { return Status::OK(); } | |||||
// Function to build TFRecordNode: creates the TFReaderOp, and optionally a
// ShuffleOp and cache op in front of it. Returns an empty vector on failure.
std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // Sort the datasets file in a lexicographical order
  std::vector<std::string> sorted_dir_files = dataset_files_;
  std::sort(sorted_dir_files.begin(), sorted_dir_files.end());

  // Create Schema Object: prefer the schema file when a path was given, otherwise
  // fall back to the in-memory schema object (if any); errors abort the build.
  std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>();
  if (!schema_path_.empty()) {
    RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaFile(schema_path_, columns_list_));
  } else if (schema_obj_ != nullptr) {
    std::string schema_json_string = schema_obj_->to_json();
    RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, columns_list_));
  }

  // File-level shuffling applies for both kGlobal and kFiles modes.
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

  // Create and initialize TFReaderOp
  std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>(
    num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema),
    connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_, nullptr);

  RETURN_EMPTY_IF_ERROR(tf_reader_op->Init());

  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));

    // Add the shuffle op after this op
    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                       rows_per_buffer_, &shuffle_op));
    node_ops.push_back(shuffle_op);
  }
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  // Add TFReaderOp
  node_ops.push_back(tf_reader_op);
  return node_ops;
}
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class TFRecordNode
/// \brief A Dataset derived class to represent TFRecord dataset
class TFRecordNode : public Dataset {
 public:
  /// \brief Constructor
  /// \note Parameter 'schema' is the path to the schema file
  TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
               const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
               int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
      : Dataset(std::move(cache)),
        dataset_files_(dataset_files),
        schema_path_(schema),
        columns_list_(columns_list),
        num_samples_(num_samples),
        shuffle_(shuffle),
        num_shards_(num_shards),
        shard_id_(shard_id),
        shard_equal_rows_(shard_equal_rows) {}

  /// \brief Constructor
  /// \note Parameter 'schema' is shared pointer to Schema object
  TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
               const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
               int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
      : Dataset(std::move(cache)),
        dataset_files_(dataset_files),
        schema_obj_(schema),
        columns_list_(columns_list),
        num_samples_(num_samples),
        shuffle_(shuffle),
        num_shards_(num_shards),
        shard_id_(shard_id),
        shard_equal_rows_(shard_equal_rows) {}

  /// \brief Destructor
  ~TFRecordNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> dataset_files_;  // TFRecord files to read
  std::string schema_path_;  // schema_path_ path to schema file. It is set when type of schema parameter is string
  std::shared_ptr<SchemaObj> schema_obj_;  // schema_obj_ schema object.
  std::vector<std::string> columns_list_;  // columns to load from the schema
  int64_t num_samples_;                    // number of samples to read
  ShuffleMode shuffle_;                    // requested shuffle mode
  int32_t num_shards_;                     // number of shards
  int32_t shard_id_;                       // id of the shard to read
  bool shard_equal_rows_;                  // forwarded to TFReaderOp in Build()
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TF_RECORD_NODE_H_ | |||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for VOCNode | |||||
| VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_dir_(dataset_dir), | |||||
| task_(task), | |||||
| usage_(usage), | |||||
| class_index_(class_indexing), | |||||
| decode_(decode), | |||||
| sampler_(sampler) {} | |||||
// Validate the parameters of VOCNode: dataset directory, sampler, task and usage.
Status VOCNode::ValidateParams() {
  Path dir(dataset_dir_);

  // The dataset directory must exist.
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_));

  if (task_ == "Segmentation") {
    // class_indexing is only meaningful for the Detection task; reject it here.
    if (!class_index_.empty()) {
      std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task.";
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }

    // 'usage' selects an ImageSets list file, which must exist on disk.
    Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
    if (!imagesets_file.Exists()) {
      std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
      MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  } else if (task_ == "Detection") {
    // Same usage check as above, but against the "Main" list files.
    Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
    if (!imagesets_file.Exists()) {
      std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
      MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  } else {
    // Only "Segmentation" and "Detection" are accepted tasks.
    std::string err_msg = "VOCNode: Invalid task: " + task_;
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}
| // Function to build VOCNode | |||||
| std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| auto schema = std::make_unique<DataSchema>(); | |||||
| VOCOp::TaskType task_type_; | |||||
| if (task_ == "Segmentation") { | |||||
| task_type_ = VOCOp::TaskType::Segmentation; | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||||
| } else if (task_ == "Detection") { | |||||
| task_type_ = VOCOp::TaskType::Detection; | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| RETURN_EMPTY_IF_ERROR(schema->AddColumn( | |||||
| ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||||
| } | |||||
| std::shared_ptr<VOCOp> voc_op; | |||||
| voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | |||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| node_ops.push_back(voc_op); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
/// \class VOCNode
/// \brief A Dataset derived class to represent a VOC-format dataset.
class VOCNode : public Dataset {
 public:
  /// \brief Constructor
  /// \param[in] dataset_dir Root directory of the VOC dataset.
  /// \param[in] task Task selector; ValidateParams accepts "Segmentation" or "Detection".
  /// \param[in] usage Name of the ImageSets list file (without ".txt") to read.
  /// \param[in] class_indexing Class-name-to-id map; must be empty for the Segmentation task.
  /// \param[in] decode Decode flag forwarded to the runtime VOCOp.
  /// \param[in] sampler Sampler used when building the runtime op.
  /// \param[in] cache Dataset cache, forwarded to the Dataset base class.
  VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
          const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
          std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~VOCNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  // Output column names used when Build() constructs the schema.
  const std::string kColumnImage = "image";
  const std::string kColumnTarget = "target";
  const std::string kColumnBbox = "bbox";
  const std::string kColumnLabel = "label";
  const std::string kColumnDifficult = "difficult";
  const std::string kColumnTruncate = "truncate";
  std::string dataset_dir_;                    // dataset root directory
  std::string task_;                           // "Segmentation" or "Detection"
  std::string usage_;                          // ImageSets list selector
  std::map<std::string, int32_t> class_index_; // class name -> id (Detection only)
  bool decode_;                                // decode flag for VOCOp
  std::shared_ptr<SamplerObj> sampler_;        // sampler for the runtime op
};
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_VOC_NODE_H_ | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -51,4 +51,4 @@ class TakeNode : public Dataset { | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
| #include "minddata/dataset/engine/datasetops/zip_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) { | |||||
| for (auto dataset : datasets_) { | |||||
| this->children.push_back(dataset); | |||||
| } | |||||
| } | |||||
| Status ZipNode::ValidateParams() { | |||||
| if (datasets_.empty()) { | |||||
| std::string err_msg = "ZipNode: datasets to zip are not specified."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { | |||||
| std::string err_msg = "ZipNode: zip datasets should not be null."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_)); | |||||
| return node_ops; | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class ZipNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets); | |||||
| /// \brief Destructor | |||||
| ~ZipNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<Dataset>> datasets_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ZIP_NODE_H_ | |||||
| @@ -106,13 +106,20 @@ class ZipNode; | |||||
| } \ | } \ | ||||
| } while (false) | } while (false) | ||||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op); | |||||
| // Helper function to validate dataset files parameter | |||||
| Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files); | |||||
| // Helper function to validate dataset num_shards and shard_id parameters | // Helper function to validate dataset num_shards and shard_id parameters | ||||
| Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id); | Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id); | ||||
| // Helper function to validate dataset sampler parameter | // Helper function to validate dataset sampler parameter | ||||
| Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler); | Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler); | ||||
| Status ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings); | |||||
| Status ValidateStringValue(const std::string &dataset_name, const std::string &str, | |||||
| const std::unordered_set<std::string> &valid_strings); | |||||
| // Helper function to validate dataset input/output column parameterCD - | // Helper function to validate dataset input/output column parameterCD - | ||||
| Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param, | ||||
| @@ -815,551 +822,8 @@ class SchemaObj { | |||||
| /* ####################################### Derived Dataset classes ################################# */ | /* ####################################### Derived Dataset classes ################################# */ | ||||
| // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS | |||||
| // (In alphabetical order) | |||||
| class AlbumNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| AlbumNode(const std::string &dataset_dir, const std::string &data_schema, | |||||
| const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler); | |||||
| /// \brief Destructor | |||||
| ~AlbumNode() = default; | |||||
| /// \brief a base class override function to create a runtime dataset op object from this class | |||||
| /// \return shared pointer to the newly created DatasetOp | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string schema_path_; | |||||
| std::vector<std::string> column_names_; | |||||
| bool decode_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| class CelebANode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||||
| const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache); | |||||
| /// \brief Destructor | |||||
| ~CelebANode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string usage_; | |||||
| bool decode_; | |||||
| std::set<std::string> extensions_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS | |||||
| // (In alphabetical order) | |||||
| class Cifar10Node : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~Cifar10Node() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| class Cifar100Node : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~Cifar100Node() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| /// \class CLUENode | |||||
| /// \brief A Dataset derived class to represent CLUE dataset | |||||
| class CLUENode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples, | |||||
| ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~CLUENode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| /// \brief Split string based on a character delimiter | |||||
| /// \return A string vector | |||||
| std::vector<std::string> split(const std::string &s, char delim); | |||||
| std::vector<std::string> dataset_files_; | |||||
| std::string task_; | |||||
| std::string usage_; | |||||
| int64_t num_samples_; | |||||
| ShuffleMode shuffle_; | |||||
| int32_t num_shards_; | |||||
| int32_t shard_id_; | |||||
| }; | |||||
| class CocoNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task, | |||||
| const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~CocoNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string annotation_file_; | |||||
| std::string task_; | |||||
| bool decode_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| /// \brief Record type for CSV | |||||
| enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; | |||||
| /// \brief Base class of CSV Record | |||||
| class CsvBase { | |||||
| public: | |||||
| CsvBase() = default; | |||||
| explicit CsvBase(CsvType t) : type(t) {} | |||||
| virtual ~CsvBase() {} | |||||
| CsvType type; | |||||
| }; | |||||
| /// \brief CSV Record that can represent integer, float and string. | |||||
| template <typename T> | |||||
| class CsvRecord : public CsvBase { | |||||
| public: | |||||
| CsvRecord() = default; | |||||
| CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {} | |||||
| ~CsvRecord() {} | |||||
| T value; | |||||
| }; | |||||
| class CSVNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| CSVNode(const std::vector<std::string> &dataset_files, char field_delim, | |||||
| const std::vector<std::shared_ptr<CsvBase>> &column_defaults, const std::vector<std::string> &column_names, | |||||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~CSVNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> dataset_files_; | |||||
| char field_delim_; | |||||
| std::vector<std::shared_ptr<CsvBase>> column_defaults_; | |||||
| std::vector<std::string> column_names_; | |||||
| int64_t num_samples_; | |||||
| ShuffleMode shuffle_; | |||||
| int32_t num_shards_; | |||||
| int32_t shard_id_; | |||||
| }; | |||||
| #ifndef ENABLE_ANDROID | |||||
| class ManifestNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~ManifestNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_file_; | |||||
| std::string usage_; | |||||
| bool decode_; | |||||
| std::map<std::string, int32_t> class_index_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| #endif | |||||
| #ifndef ENABLE_ANDROID | |||||
| class MindDataNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list, | |||||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded); | |||||
| /// \brief Constructor | |||||
| MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list, | |||||
| const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded); | |||||
| /// \brief Destructor | |||||
| ~MindDataNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| /// \brief Build sampler chain for minddata dataset | |||||
| /// \return Status Status::OK() if input sampler is valid | |||||
| Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler, | |||||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_, | |||||
| int64_t num_padded); | |||||
| /// \brief Set sample_bytes when padded_sample has py::byte value | |||||
| /// \note Pybind will use this function to set sample_bytes into MindDataNode | |||||
| void SetSampleBytes(std::map<std::string, std::string> *sample_bytes); | |||||
| private: | |||||
| std::string dataset_file_; // search_for_pattern_ will be true in this mode | |||||
| std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode | |||||
| bool search_for_pattern_; | |||||
| std::vector<std::string> columns_list_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| nlohmann::json padded_sample_; | |||||
| std::map<std::string, std::string> sample_bytes_; // enable in python | |||||
| int64_t num_padded_; | |||||
| }; | |||||
| #endif | |||||
| class MnistNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~MnistNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::string dataset_dir_; | |||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| class RandomNode : public Dataset { | |||||
| public: | |||||
| // Some constants to provide limits to random generation. | |||||
| static constexpr int32_t kMaxNumColumns = 4; | |||||
| static constexpr int32_t kMaxRank = 4; | |||||
| static constexpr int32_t kMaxDimValue = 32; | |||||
| /// \brief Constructor | |||||
| RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list, | |||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| total_rows_(total_rows), | |||||
| schema_path_(""), | |||||
| schema_(std::move(schema)), | |||||
| columns_list_(columns_list), | |||||
| sampler_(std::move(sampler)) {} | |||||
| /// \brief Constructor | |||||
| RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list, | |||||
| const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| total_rows_(total_rows), | |||||
| schema_path_(schema_path), | |||||
| columns_list_(columns_list), | |||||
| sampler_(std::move(sampler)) {} | |||||
| /// \brief Destructor | |||||
| ~RandomNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| /// \brief A quick inline for producing a random number between (and including) min/max | |||||
| /// \param[in] min minimum number that can be generated. | |||||
| /// \param[in] max maximum number that can be generated. | |||||
| /// \return The generated random number | |||||
| int32_t GenRandomInt(int32_t min, int32_t max); | |||||
| int32_t total_rows_; | |||||
| std::string schema_path_; | |||||
| std::shared_ptr<SchemaObj> schema_; | |||||
| std::vector<std::string> columns_list_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| std::mt19937 rand_gen_; | |||||
| }; | |||||
| /// \class TextFileNode | |||||
| /// \brief A Dataset derived class to represent TextFile dataset | |||||
| class TextFileNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||||
| int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~TextFileNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> dataset_files_; | |||||
| int32_t num_samples_; | |||||
| int32_t num_shards_; | |||||
| int32_t shard_id_; | |||||
| ShuffleMode shuffle_; | |||||
| }; | |||||
| /// \class TFRecordNode | |||||
| /// \brief A Dataset derived class to represent TFRecord dataset | |||||
| class TFRecordNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \note Parameter 'schema' is the path to the schema file | |||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema, | |||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | |||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(dataset_files), | |||||
| schema_path_(schema), | |||||
| columns_list_(columns_list), | |||||
| num_samples_(num_samples), | |||||
| shuffle_(shuffle), | |||||
| num_shards_(num_shards), | |||||
| shard_id_(shard_id), | |||||
| shard_equal_rows_(shard_equal_rows) {} | |||||
| /// \brief Constructor | |||||
| /// \note Parameter 'schema' is shared pointer to Schema object | |||||
| TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema, | |||||
| const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle, | |||||
| int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache) | |||||
| : Dataset(std::move(cache)), | |||||
| dataset_files_(dataset_files), | |||||
| schema_obj_(schema), | |||||
| columns_list_(columns_list), | |||||
| num_samples_(num_samples), | |||||
| shuffle_(shuffle), | |||||
| num_shards_(num_shards), | |||||
| shard_id_(shard_id), | |||||
| shard_equal_rows_(shard_equal_rows) {} | |||||
| /// \brief Destructor | |||||
| ~TFRecordNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> dataset_files_; | |||||
| std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string | |||||
| std::shared_ptr<SchemaObj> schema_obj_; // schema_obj_ schema object. | |||||
| std::vector<std::string> columns_list_; | |||||
| int64_t num_samples_; | |||||
| ShuffleMode shuffle_; | |||||
| int32_t num_shards_; | |||||
| int32_t shard_id_; | |||||
| bool shard_equal_rows_; | |||||
| }; | |||||
| #ifndef ENABLE_ANDROID | |||||
| class VOCNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage, | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler, | |||||
| std::shared_ptr<DatasetCache> cache); | |||||
| /// \brief Destructor | |||||
| ~VOCNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| const std::string kColumnImage = "image"; | |||||
| const std::string kColumnTarget = "target"; | |||||
| const std::string kColumnBbox = "bbox"; | |||||
| const std::string kColumnLabel = "label"; | |||||
| const std::string kColumnDifficult = "difficult"; | |||||
| const std::string kColumnTruncate = "truncate"; | |||||
| std::string dataset_dir_; | |||||
| std::string task_; | |||||
| std::string usage_; | |||||
| std::map<std::string, int32_t> class_index_; | |||||
| bool decode_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| }; | |||||
| #endif | |||||
| // DERIVED DATASET CLASSES FOR DATASET OPS | |||||
| // (In alphabetical order) | |||||
| class MapNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations, | |||||
| std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | |||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | |||||
| /// \brief Destructor | |||||
| ~MapNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<TensorOperation>> operations_; | |||||
| std::vector<std::string> input_columns_; | |||||
| std::vector<std::string> output_columns_; | |||||
| std::vector<std::string> project_columns_; | |||||
| }; | |||||
| class SkipNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| /// \brief Destructor | |||||
| ~SkipNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t skip_count_; | |||||
| }; | |||||
| class ZipNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets); | |||||
| /// \brief Destructor | |||||
| ~ZipNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<Dataset>> datasets_; | |||||
| }; | |||||
| } // namespace api | } // namespace api | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASETS_H_ | ||||
| @@ -81,18 +81,18 @@ AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image/lite_cv MINDDATA_KERNELS_IMA | |||||
| if (BUILD_MINDDATA STREQUAL "full") | if (BUILD_MINDDATA STREQUAL "full") | ||||
| list(REMOVE_ITEM MINDDATA_API_SRC_FILES | |||||
| "${MINDDATA_DIR}/api/text.cc" | |||||
| list(REMOVE_ITEM MINDDATA_API_SRC_FILES | |||||
| "${MINDDATA_DIR}/api/text.cc" | |||||
| "${MINDDATA_DIR}/api/de_tensor.cc" | "${MINDDATA_DIR}/api/de_tensor.cc" | ||||
| "${MINDDATA_DIR}/api/execute.cc" | "${MINDDATA_DIR}/api/execute.cc" | ||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES | |||||
| "${MINDDATA_DIR}/callback/py_ds_callback.cc" | |||||
| list(REMOVE_ITEM MINDDATA_CALLBACK_SRC_FILES | |||||
| "${MINDDATA_DIR}/callback/py_ds_callback.cc" | |||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SRC_FILES | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SRC_FILES | |||||
| "${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc" | "${MINDDATA_DIR}/engine/datasetops/build_sentence_piece_vocab_op.cc" | ||||
| "${MINDDATA_DIR}/engine/datasetops/filter_op.cc" | "${MINDDATA_DIR}/engine/datasetops/filter_op.cc" | ||||
| "${MINDDATA_DIR}/engine/datasetops/barrier_op.cc" | "${MINDDATA_DIR}/engine/datasetops/barrier_op.cc" | ||||
| @@ -104,7 +104,7 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| "${MINDDATA_DIR}/engine/datasetops/cache_op.cc" | "${MINDDATA_DIR}/engine/datasetops/cache_op.cc" | ||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||||
| "${MINDDATA_DIR}/engine/datasetops/source/generator_op.cc" | "${MINDDATA_DIR}/engine/datasetops/source/generator_op.cc" | ||||
| "${MINDDATA_DIR}/engine/datasetops/source/voc_op.cc" | "${MINDDATA_DIR}/engine/datasetops/source/voc_op.cc" | ||||
| "${MINDDATA_DIR}/engine/datasetops/source/manifest_op.cc" | "${MINDDATA_DIR}/engine/datasetops/source/manifest_op.cc" | ||||
| @@ -131,6 +131,10 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/ir/datasetops/source/generator_node.cc" | "${MINDDATA_DIR}/engine/ir/datasetops/source/generator_node.cc" | ||||
| "${MINDDATA_DIR}/engine/ir/datasetops/source/manifest_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/source/minddata_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/source/tf_record_node.cc" | |||||
| "${MINDDATA_DIR}/engine/ir/datasetops/source/voc_node.cc" | |||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES | ||||
| @@ -184,7 +188,7 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| opencv_imgproc | opencv_imgproc | ||||
| mindspore::json | mindspore::json | ||||
| ) | ) | ||||
| # ref: https://github.com/android/ndk/issues/1202 | # ref: https://github.com/android/ndk/issues/1202 | ||||
| if (PLATFORM_ARM32) | if (PLATFORM_ARM32) | ||||
| file(GLOB_RECURSE LIBCLANG_RT_LIB $ENV{ANDROID_NDK}/libclang_rt.builtins-arm-android.a) | file(GLOB_RECURSE LIBCLANG_RT_LIB $ENV{ANDROID_NDK}/libclang_rt.builtins-arm-android.a) | ||||
| @@ -206,7 +210,7 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| elseif (BUILD_MINDDATA STREQUAL "lite") | elseif (BUILD_MINDDATA STREQUAL "lite") | ||||
| list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES "${MINDDATA_DIR}/core/client.cc") | list(REMOVE_ITEM MINDDATA_CORE_SRC_FILES "${MINDDATA_DIR}/core/client.cc") | ||||
| list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | list(REMOVE_ITEM MINDDATA_KERNELS_SRC_FILES "${MINDDATA_DIR}/kernels/py_func_op.cc") | ||||
| add_library(minddata_eager_mid OBJECT | |||||
| add_library(minddata_eager_mid OBJECT | |||||
| ${MINDDATA_DIR}/api/de_tensor.cc | ${MINDDATA_DIR}/api/de_tensor.cc | ||||
| ${MINDDATA_DIR}/api/execute.cc | ${MINDDATA_DIR}/api/execute.cc | ||||
| ) | ) | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| #if defined(__ANDROID__) || defined(ANDROID) | #if defined(__ANDROID__) || defined(ANDROID) | ||||
| #include <android/log.h> | #include <android/log.h> | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/util/path.h" | #include "minddata/dataset/util/path.h" | ||||
| using Dataset = mindspore::dataset::api::Dataset; | using Dataset = mindspore::dataset::api::Dataset; | ||||
| @@ -16,6 +16,19 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -16,6 +16,22 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -18,10 +18,24 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::GlobalContext; | |||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::GlobalContext; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -49,11 +63,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence1"), row.end()); | EXPECT_NE(row.find("sentence1"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "蚂蚁借呗等额还款能否换成先息后本", | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"蚂蚁借呗等额还款能否换成先息后本", "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -75,11 +86,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) { | |||||
| // test | // test | ||||
| usage = "test"; | usage = "test"; | ||||
| expected_result = { | |||||
| "借呗取消的时间", | |||||
| "网商贷用什么方法转变成借呗", | |||||
| "我的借呗为什么开通不了" | |||||
| }; | |||||
| expected_result = {"借呗取消的时间", "网商贷用什么方法转变成借呗", "我的借呗为什么开通不了"}; | |||||
| ds = CLUE({test_file}, task, usage, 0, ShuffleMode::kFalse); | ds = CLUE({test_file}, task, usage, 0, ShuffleMode::kFalse); | ||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| iter = ds->CreateIterator(); | iter = ds->CreateIterator(); | ||||
| @@ -100,11 +107,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetAFQMC) { | |||||
| // eval | // eval | ||||
| usage = "eval"; | usage = "eval"; | ||||
| expected_result = { | |||||
| "你有花呗吗", | |||||
| "吃饭能用花呗吗", | |||||
| "蚂蚁花呗支付金额有什么限制" | |||||
| }; | |||||
| expected_result = {"你有花呗吗", "吃饭能用花呗吗", "蚂蚁花呗支付金额有什么限制"}; | |||||
| ds = CLUE({eval_file}, task, usage, 0, ShuffleMode::kFalse); | ds = CLUE({eval_file}, task, usage, 0, ShuffleMode::kFalse); | ||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| iter = ds->CreateIterator(); | iter = ds->CreateIterator(); | ||||
| @@ -179,11 +182,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetCMNLI) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence1"), row.end()); | EXPECT_NE(row.find("sentence1"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "你应该给这件衣服定一个价格。", | |||||
| "我怎么知道他要说什么", | |||||
| "向左。" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"你应该给这件衣服定一个价格。", "我怎么知道他要说什么", "向左。"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -224,11 +223,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetCSL) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("abst"), row.end()); | EXPECT_NE(row.find("abst"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "这是一段长文本", | |||||
| "这是一段长文本", | |||||
| "这是一段长文本" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"这是一段长文本", "这是一段长文本", "这是一段长文本"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -337,11 +332,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetIFLYTEK) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence"), row.end()); | EXPECT_NE(row.find("sentence"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "第一个文本", | |||||
| "第二个文本", | |||||
| "第三个文本" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"第一个文本", "第二个文本", "第三个文本"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -396,14 +387,12 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesA) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence1"), row.end()); | EXPECT_NE(row.find("sentence1"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "你有花呗吗", | |||||
| "吃饭能用花呗吗", | |||||
| "蚂蚁花呗支付金额有什么限制", | |||||
| "蚂蚁借呗等额还款能否换成先息后本", | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"你有花呗吗", | |||||
| "吃饭能用花呗吗", | |||||
| "蚂蚁花呗支付金额有什么限制", | |||||
| "蚂蚁借呗等额还款能否换成先息后本", | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -463,14 +452,12 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesB) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence1"), row.end()); | EXPECT_NE(row.find("sentence1"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "你有花呗吗", | |||||
| "吃饭能用花呗吗", | |||||
| "蚂蚁花呗支付金额有什么限制", | |||||
| "蚂蚁借呗等额还款能否换成先息后本", | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"你有花呗吗", | |||||
| "吃饭能用花呗吗", | |||||
| "蚂蚁花呗支付金额有什么限制", | |||||
| "蚂蚁借呗等额还款能否换成先息后本", | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -523,11 +510,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleGlobal) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence1"), row.end()); | EXPECT_NE(row.find("sentence1"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "蚂蚁花呗说我违约了", | |||||
| "帮我看看本月花呗账单结清了没", | |||||
| "蚂蚁借呗等额还款能否换成先息后本" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"蚂蚁花呗说我违约了", "帮我看看本月花呗账单结清了没", | |||||
| "蚂蚁借呗等额还款能否换成先息后本"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| auto text = row["sentence1"]; | auto text = row["sentence1"]; | ||||
| @@ -572,11 +556,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetTNEWS) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("sentence"), row.end()); | EXPECT_NE(row.find("sentence"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "新闻1", | |||||
| "新闻2", | |||||
| "新闻3" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"新闻1", "新闻2", "新闻3"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -617,11 +597,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetWSC) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("text"), row.end()); | EXPECT_NE(row.find("text"), row.end()); | ||||
| std::vector<std::string> expected_result = { | |||||
| "小明呢,他在哪?", | |||||
| "小红刚刚看到小明,他在操场", | |||||
| "等小明回来,小张你叫他交作业" | |||||
| }; | |||||
| std::vector<std::string> expected_result = {"小明呢,他在哪?", "小红刚刚看到小明,他在操场", | |||||
| "等小明回来,小张你叫他交作业"}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -16,10 +16,40 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::dsize_t; | |||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::TensorShape; | using mindspore::dataset::TensorShape; | ||||
| using mindspore::dataset::dsize_t; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -79,12 +109,14 @@ TEST_F(MindDataTestPipeline, TestCocoDetection) { | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623", | |||||
| "000000309022"}; | |||||
| std::string expect_file[] = {"000000391895", "000000318219", "000000554625", | |||||
| "000000574769", "000000060623", "000000309022"}; | |||||
| std::vector<std::vector<float>> expect_bbox_vector = {{10.0, 10.0, 10.0, 10.0, 70.0, 70.0, 70.0, 70.0}, | std::vector<std::vector<float>> expect_bbox_vector = {{10.0, 10.0, 10.0, 10.0, 70.0, 70.0, 70.0, 70.0}, | ||||
| {20.0, 20.0, 20.0, 20.0, 80.0, 80.0, 80.0, 80.0}, | {20.0, 20.0, 20.0, 20.0, 80.0, 80.0, 80.0, 80.0}, | ||||
| {30.0, 30.0, 30.0, 30.0}, {40.0, 40.0, 40.0, 40.0}, | |||||
| {50.0, 50.0, 50.0, 50.0}, {60.0, 60.0, 60.0, 60.0}}; | |||||
| {30.0, 30.0, 30.0, 30.0}, | |||||
| {40.0, 40.0, 40.0, 40.0}, | |||||
| {50.0, 50.0, 50.0, 50.0}, | |||||
| {60.0, 60.0, 60.0, 60.0}}; | |||||
| std::vector<std::vector<uint32_t>> expect_catagoryid_list = {{1, 7}, {2, 8}, {3}, {4}, {5}, {6}}; | std::vector<std::vector<uint32_t>> expect_catagoryid_list = {{1, 7}, {2, 8}, {3}, {4}, {5}, {6}}; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -148,13 +180,13 @@ TEST_F(MindDataTestPipeline, TestCocoKeypoint) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| std::string expect_file[] = {"000000391895", "000000318219"}; | std::string expect_file[] = {"000000391895", "000000318219"}; | ||||
| std::vector<std::vector<float>> expect_keypoint_vector = | |||||
| {{368.0, 61.0, 1.0, 369.0, 52.0, 2.0, 0.0, 0.0, 0.0, 382.0, 48.0, 2.0, 0.0, 0.0, 0.0, 368.0, 84.0, 2.0, 435.0, | |||||
| 81.0, 2.0, 362.0, 125.0, 2.0, 446.0, 125.0, 2.0, 360.0, 153.0, 2.0, 0.0, 0.0, 0.0, 397.0, 167.0, 1.0, 439.0, | |||||
| 166.0, 1.0, 369.0, 193.0, 2.0, 461.0, 234.0, 2.0, 361.0, 246.0, 2.0, 474.0, 287.0, 2.0}, | |||||
| {244.0, 139.0, 2.0, 0.0, 0.0, 0.0, 226.0, 118.0, 2.0, 0.0, 0.0, 0.0, 154.0, 159.0, 2.0, 143.0, 261.0, 2.0, 135.0, | |||||
| 312.0, 2.0, 271.0, 423.0, 2.0, 184.0, 530.0, 2.0, 261.0, 280.0, 2.0, 347.0, 592.0, 2.0, 0.0, 0.0, 0.0, 123.0, | |||||
| 596.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; | |||||
| std::vector<std::vector<float>> expect_keypoint_vector = { | |||||
| {368.0, 61.0, 1.0, 369.0, 52.0, 2.0, 0.0, 0.0, 0.0, 382.0, 48.0, 2.0, 0.0, 0.0, 0.0, 368.0, 84.0, 2.0, | |||||
| 435.0, 81.0, 2.0, 362.0, 125.0, 2.0, 446.0, 125.0, 2.0, 360.0, 153.0, 2.0, 0.0, 0.0, 0.0, 397.0, 167.0, 1.0, | |||||
| 439.0, 166.0, 1.0, 369.0, 193.0, 2.0, 461.0, 234.0, 2.0, 361.0, 246.0, 2.0, 474.0, 287.0, 2.0}, | |||||
| {244.0, 139.0, 2.0, 0.0, 0.0, 0.0, 226.0, 118.0, 2.0, 0.0, 0.0, 0.0, 154.0, 159.0, 2.0, 143.0, 261.0, 2.0, | |||||
| 135.0, 312.0, 2.0, 271.0, 423.0, 2.0, 184.0, 530.0, 2.0, 261.0, 280.0, 2.0, 347.0, 592.0, 2.0, 0.0, 0.0, 0.0, | |||||
| 123.0, 596.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; | |||||
| std::vector<std::vector<dsize_t>> expect_size = {{1, 51}, {1, 51}}; | std::vector<std::vector<dsize_t>> expect_size = {{1, 51}, {1, 51}}; | ||||
| std::vector<std::vector<uint32_t>> expect_num_keypoints_list = {{14}, {10}}; | std::vector<std::vector<uint32_t>> expect_num_keypoints_list = {{14}, {10}}; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| @@ -258,17 +290,17 @@ TEST_F(MindDataTestPipeline, TestCocoStuff) { | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| std::string expect_file[] = {"000000391895", "000000318219", "000000554625", "000000574769", "000000060623", | |||||
| "000000309022"}; | |||||
| std::vector<std::vector<float>> expect_segmentation_vector = | |||||
| {{10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, | |||||
| 70.0, 72.0, 73.0, 74.0, 75.0, -1.0, -1.0, -1.0, -1.0, -1.0}, | |||||
| {20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, | |||||
| 10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0}, | |||||
| {40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 40.0, 41.0, 42.0}, | |||||
| {50.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}, | |||||
| {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}, | |||||
| {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}}; | |||||
| std::string expect_file[] = {"000000391895", "000000318219", "000000554625", | |||||
| "000000574769", "000000060623", "000000309022"}; | |||||
| std::vector<std::vector<float>> expect_segmentation_vector = { | |||||
| {10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, | |||||
| 70.0, 72.0, 73.0, 74.0, 75.0, -1.0, -1.0, -1.0, -1.0, -1.0}, | |||||
| {20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, | |||||
| 10.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, -1.0}, | |||||
| {40.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 40.0, 41.0, 42.0}, | |||||
| {50.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0}, | |||||
| {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}, | |||||
| {60.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0}}; | |||||
| std::vector<std::vector<dsize_t>> expect_size = {{2, 10}, {2, 11}, {1, 12}, {1, 13}, {1, 14}, {2, 7}}; | std::vector<std::vector<dsize_t>> expect_size = {{2, 10}, {2, 11}, {1, 12}, {1, 13}, {1, 14}, {2, 7}}; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -18,13 +18,17 @@ | |||||
| #include "minddata/dataset/include/config.h" | #include "minddata/dataset/include/config.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -18,10 +18,40 @@ | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::GlobalContext; | |||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::GlobalContext; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -98,12 +128,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | std::vector<std::vector<std::string>> expected_result = { | ||||
| {"17", "18", "19", "20"}, | |||||
| {"1", "2", "3", "4"}, | |||||
| {"5", "6", "7", "8"}, | |||||
| {"13", "14", "15", "16"}, | |||||
| {"21", "22", "23", "24"}, | |||||
| {"9", "10", "11", "12"}, | |||||
| {"17", "18", "19", "20"}, {"1", "2", "3", "4"}, {"5", "6", "7", "8"}, | |||||
| {"13", "14", "15", "16"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"}, | |||||
| }; | }; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| @@ -148,10 +174,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetNumSamples) { | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | |||||
| {"1", "2", "3", "4"}, | |||||
| {"5", "6", "7", "8"} | |||||
| }; | |||||
| std::vector<std::vector<std::string>> expected_result = {{"1", "2", "3", "4"}, {"5", "6", "7", "8"}}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -191,10 +214,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDistribution) { | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | |||||
| {"1", "2", "3", "4"}, | |||||
| {"5", "6", "7", "8"} | |||||
| }; | |||||
| std::vector<std::vector<std::string>> expected_result = {{"1", "2", "3", "4"}, {"5", "6", "7", "8"}}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -386,12 +406,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | std::vector<std::vector<std::string>> expected_result = { | ||||
| {"13", "14", "15", "16"}, | |||||
| {"1", "2", "3", "4"}, | |||||
| {"17", "18", "19", "20"}, | |||||
| {"5", "6", "7", "8"}, | |||||
| {"21", "22", "23", "24"}, | |||||
| {"9", "10", "11", "12"}, | |||||
| {"13", "14", "15", "16"}, {"1", "2", "3", "4"}, {"17", "18", "19", "20"}, | |||||
| {"5", "6", "7", "8"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"}, | |||||
| }; | }; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| @@ -445,12 +461,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | std::vector<std::vector<std::string>> expected_result = { | ||||
| {"13", "14", "15", "16"}, | |||||
| {"1", "2", "3", "4"}, | |||||
| {"17", "18", "19", "20"}, | |||||
| {"5", "6", "7", "8"}, | |||||
| {"21", "22", "23", "24"}, | |||||
| {"9", "10", "11", "12"}, | |||||
| {"13", "14", "15", "16"}, {"1", "2", "3", "4"}, {"17", "18", "19", "20"}, | |||||
| {"5", "6", "7", "8"}, {"21", "22", "23", "24"}, {"9", "10", "11", "12"}, | |||||
| }; | }; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| @@ -505,10 +517,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) { | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| EXPECT_NE(row.find("col1"), row.end()); | EXPECT_NE(row.find("col1"), row.end()); | ||||
| std::vector<std::vector<std::string>> expected_result = { | std::vector<std::vector<std::string>> expected_result = { | ||||
| {"5", "6", "7", "8"}, | |||||
| {"9", "10", "11", "12"}, | |||||
| {"1", "2", "3", "4"} | |||||
| }; | |||||
| {"5", "6", "7", "8"}, {"9", "10", "11", "12"}, {"1", "2", "3", "4"}}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| @@ -1,3 +1,4 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | * Copyright 2020 Huawei Technologies Co., Ltd | ||||
| * | * | ||||
| @@ -16,12 +17,21 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -16,6 +16,19 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -16,6 +16,20 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -57,7 +71,6 @@ TEST_F(MindDataTestPipeline, TestMindDataSuccess1) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestMindDataSuccess2) { | TEST_F(MindDataTestPipeline, TestMindDataSuccess2) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file."; | ||||
| @@ -18,16 +18,22 @@ | |||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -20,11 +20,41 @@ | |||||
| #include "mindspore/core/ir/dtype/type_id.h" | #include "mindspore/core/ir/dtype/type_id.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::DataType; | |||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::TensorShape; | using mindspore::dataset::TensorShape; | ||||
| using mindspore::dataset::DataType; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -14,10 +14,23 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| @@ -18,12 +18,19 @@ | |||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| @@ -16,10 +16,24 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::DataType; | |||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::TensorShape; | using mindspore::dataset::TensorShape; | ||||
| using mindspore::dataset::DataType; | |||||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | class MindDataTestPipeline : public UT::DatasetOpTesting { | ||||
| protected: | protected: | ||||
| @@ -16,13 +16,22 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -25,6 +25,36 @@ | |||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/include/text.h" | #include "minddata/dataset/include/text.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::DataType; | using mindspore::dataset::DataType; | ||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| @@ -304,8 +334,8 @@ TEST_F(MindDataTestPipeline, TestVocabFromDatasetFail4) { | |||||
| // Create vocab from dataset | // Create vocab from dataset | ||||
| // Expected failure: special tokens are already in the dataset | // Expected failure: special tokens are already in the dataset | ||||
| std::shared_ptr<Vocab> vocab = ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, | |||||
| std::numeric_limits<int64_t>::max(), {"world"}); | |||||
| std::shared_ptr<Vocab> vocab = | |||||
| ds->BuildVocab({"text"}, {0, std::numeric_limits<int64_t>::max()}, std::numeric_limits<int64_t>::max(), {"world"}); | |||||
| EXPECT_EQ(vocab, nullptr); | EXPECT_EQ(vocab, nullptr); | ||||
| } | } | ||||
| @@ -18,12 +18,35 @@ | |||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::BorderType; | using mindspore::dataset::BorderType; | ||||
| @@ -18,13 +18,21 @@ | |||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | #include "minddata/dataset/engine/ir/datasetops/project_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::BorderType; | using mindspore::dataset::BorderType; | ||||
| @@ -20,8 +20,22 @@ | |||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||