Commit messages: "concat, bucketbatch, project, rename"; "fix ci round 1"; "fix ci round 2"; "fix up"; "fix ci". Tag: v1.1.0
@@ -41,19 +41,8 @@
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#endif
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
// Sampler headers (in alphabetical order)
@@ -61,8 +50,21 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
// IR nodes
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#endif
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/core/config_manager.h"
@@ -1759,175 +1761,9 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
#endif

#ifndef ENABLE_ANDROID
BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(element_length_function),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::shared_ptr<TensorOp> c_func;
  if (element_length_function_ != nullptr) {
    c_func = std::make_shared<CFuncOp>(element_length_function_);
  } else {
    c_func = nullptr;
  }
  node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
                                                             c_func, pad_info_, pad_to_bucket_boundary_,
                                                             drop_remainder_, connector_que_size_));
  return node_ops;
}

Status BucketBatchByLengthNode::ValidateParams() {
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Check bucket_boundaries: must be positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  for (int i = 0; i < bucket_boundaries_.size(); i++) {
    if (bucket_boundaries_[i] <= 0) {
      std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: ";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
        << i << " was: " << bucket_boundaries_[i];
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing.";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
        << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
        << " respectively.";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  // Check bucket_batch_sizes: must be positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(vocab),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  this->children.push_back(child);
}

// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::shared_ptr<BuildVocabOp> build_vocab_op;
  build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
                                                  special_first_, num_workers_, connector_que_size_);
  node_ops.push_back(build_vocab_op);
  return node_ops;
}

Status BuildVocabNode::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (top_k_ <= 0) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
    MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
                  << "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }

  return Status::OK();
}
#endif

// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  this->children = datasets_;
}

Status ConcatNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
    std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
  return node_ops;
}

MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                 std::vector<std::string> input_columns, std::vector<std::string> output_columns,
                 const std::vector<std::string> &project_columns)
@@ -1984,110 +1820,6 @@ Status MapNode::ValidateParams() {
  return Status::OK();
}

// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
  this->children.push_back(child);
}

Status ProjectNode::ValidateParams() {
  if (columns_.empty()) {
    std::string err_msg = "ProjectNode: No columns are specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ProjectOp>(columns_));
  return node_ops;
}

// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
                       const std::vector<std::string> &output_columns)
    : input_columns_(input_columns), output_columns_(output_columns) {
  this->children.push_back(child);
}

Status RenameNode::ValidateParams() {
  if (input_columns_.size() != output_columns_.size()) {
    std::string err_msg = "RenameNode: input and output columns must be the same size";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
  return node_ops;
}

RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
  this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
  return node_ops;
}

Status RepeatNode::ValidateParams() {
  if (repeat_count_ <= 0 && repeat_count_ != -1) {
    std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
                          std::to_string(repeat_count_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
    : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
  this->children.push_back(child);
}

// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
                                                 rows_per_buffer_));
  return node_ops;
}

// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
  if (shuffle_size_ <= 1) {
    std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
  this->children.push_back(child);
}
@@ -2113,31 +1845,6 @@ Status SkipNode::ValidateParams() {
  return Status::OK();
}

// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
  this->children.push_back(child);
}

// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
  return node_ops;
}

// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
  if (take_count_ <= 0 && take_count_ != -1) {
    std::string err_msg =
      "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

// Function to build ZipOp
ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  for (auto dataset : datasets_) {
@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,22 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(source)

add_library(engine-ir-datasetops OBJECT
    batch_node.cc)

set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
    batch_node.cc
    concat_node.cc
    project_node.cc
    rename_node.cc
    repeat_node.cc
    shuffle_node.cc
    take_node.cc
    )

if (NOT ENABLE_ANDROID)
    set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
        ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}
        bucket_batch_by_length_node.cc
        build_vocab_node.cc)
endif ()

add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES})
@@ -0,0 +1,121 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

BucketBatchByLengthNode::BucketBatchByLengthNode(
  std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
  const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  std::function<TensorRow(TensorRow)> element_length_function,
  const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  bool drop_remainder)
    : column_names_(column_names),
      bucket_boundaries_(bucket_boundaries),
      bucket_batch_sizes_(bucket_batch_sizes),
      element_length_function_(element_length_function),
      pad_info_(pad_info),
      pad_to_bucket_boundary_(pad_to_bucket_boundary),
      drop_remainder_(drop_remainder) {
  this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::shared_ptr<TensorOp> c_func;
  if (element_length_function_ != nullptr) {
    c_func = std::make_shared<CFuncOp>(element_length_function_);
  } else {
    c_func = nullptr;
  }
  node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
                                                             c_func, pad_info_, pad_to_bucket_boundary_,
                                                             drop_remainder_, connector_que_size_));
  return node_ops;
}

Status BucketBatchByLengthNode::ValidateParams() {
  if (element_length_function_ == nullptr && column_names_.size() != 1) {
    std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
                          std::to_string(column_names_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // Check bucket_boundaries: must be positive and strictly increasing
  if (bucket_boundaries_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  for (int i = 0; i < bucket_boundaries_.size(); i++) {
    if (bucket_boundaries_[i] <= 0) {
      std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: ";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
        << i << " was: " << bucket_boundaries_[i];
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
      std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing.";
      MS_LOG(ERROR)
        << "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
        << i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
        << " respectively.";
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }

  // Check bucket_batch_sizes: must be positive
  if (bucket_batch_sizes_.empty()) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
    std::string err_msg =
      "BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) {
    std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
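
A hypothetical usage sketch follows (not part of this patch). It assumes a previously built child IR dataset and only uses the constructor and the default arguments declared in the header below; it illustrates the two key validation rules: bucket_boundaries must be positive and strictly increasing, and bucket_batch_sizes must hold exactly bucket_boundaries.size() + 1 entries.

// Hypothetical sketch (not part of this patch), assuming this file's includes and
// the mindspore::dataset::api namespace.
Status TryBucketBatchByLength(std::shared_ptr<Dataset> child) {
  std::vector<int32_t> bucket_boundaries = {50, 100, 200};    // all positive, strictly increasing
  std::vector<int32_t> bucket_batch_sizes = {64, 32, 16, 8};  // bucket_boundaries.size() + 1 entries
  auto node = std::make_shared<BucketBatchByLengthNode>(child, std::vector<std::string>{"text"}, bucket_boundaries,
                                                        bucket_batch_sizes);
  // With exactly one column name, element_length_function may stay null, so validation passes.
  return node->ValidateParams();
}
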
@@ -0,0 +1,64 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class BucketBatchByLengthNode : public Dataset {
 public:
  /// \brief Constructor
  BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
                          const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
                          std::function<TensorRow(TensorRow)> element_length_function = nullptr,
                          const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
                          bool pad_to_bucket_boundary = false, bool drop_remainder = false);

  /// \brief Destructor
  ~BucketBatchByLengthNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> column_names_;
  std::vector<int32_t> bucket_boundaries_;
  std::vector<int32_t> bucket_batch_sizes_;
  std::function<TensorRow(TensorRow)> element_length_function_;
  std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
  bool pad_to_bucket_boundary_;
  bool drop_remainder_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
@@ -0,0 +1,83 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
                               const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
                               int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
    : vocab_(vocab),
      columns_(columns),
      freq_range_(freq_range),
      top_k_(top_k),
      special_tokens_(special_tokens),
      special_first_(special_first) {
  this->children.push_back(child);
}

// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::shared_ptr<BuildVocabOp> build_vocab_op;
  build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
                                                  special_first_, num_workers_, connector_que_size_);
  node_ops.push_back(build_vocab_op);
  return node_ops;
}

Status BuildVocabNode::ValidateParams() {
  if (vocab_ == nullptr) {
    std::string err_msg = "BuildVocabNode: vocab is null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (top_k_ <= 0) {
    std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
    std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
    MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
                  << "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (!columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
  }

  return Status::OK();
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
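
For orientation, a hypothetical sketch (not part of this patch) of parameters that satisfy the checks above: vocab must be non-null, top_k positive, and freq_range must obey 0 <= first <= second <= kDeMaxFreq; an empty columns list skips the column-name validation.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryBuildVocab(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab) {
  std::pair<int64_t, int64_t> freq_range = {1, 5000};  // 0 <= first <= second
  int64_t top_k = 10000;                               // must be > 0
  std::vector<std::string> special_tokens = {"<pad>", "<unk>"};
  auto node = std::make_shared<BuildVocabNode>(child, vocab, std::vector<std::string>{}, freq_range, top_k,
                                               special_tokens, true);
  return node->ValidateParams();
}
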
@@ -0,0 +1,61 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class BuildVocabNode : public Dataset {
 public:
  /// \brief Constructor
  BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
                 const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
                 const std::vector<std::string> &special_tokens, bool special_first);

  /// \brief Destructor
  ~BuildVocabNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::shared_ptr<Vocab> vocab_;
  std::vector<std::string> columns_;
  std::pair<int64_t, int64_t> freq_range_;
  int64_t top_k_;
  std::vector<std::string> special_tokens_;
  bool special_first_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
@@ -0,0 +1,60 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/concat_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  this->children = datasets_;
}

Status ConcatNode::ValidateParams() {
  if (datasets_.empty()) {
    std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
    std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
  return node_ops;
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
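
A short sketch (not part of this patch): ConcatNode wires the input datasets in as its children, so validation only has to reject an empty list or a null entry.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryConcat(std::shared_ptr<Dataset> ds1, std::shared_ptr<Dataset> ds2) {
  auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<Dataset>>{ds1, ds2});
  // ValidateParams() returns a syntax error if the vector is empty or contains a nullptr.
  return node->ValidateParams();
}
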
@@ -0,0 +1,53 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class ConcatNode : public Dataset {
 public:
  /// \brief Constructor
  explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);

  /// \brief Destructor
  ~ConcatNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::shared_ptr<Dataset>> datasets_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/project_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
  this->children.push_back(child);
}

Status ProjectNode::ValidateParams() {
  if (columns_.empty()) {
    std::string err_msg = "ProjectNode: No columns are specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ProjectOp>(columns_));
  return node_ops;
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
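
A minimal sketch (not part of this patch): ProjectNode only needs a non-empty, well-formed column list, and Build() emits a single ProjectOp restricted to those columns.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryProject(std::shared_ptr<Dataset> child) {
  auto node = std::make_shared<ProjectNode>(child, std::vector<std::string>{"image", "label"});
  RETURN_IF_NOT_OK(node->ValidateParams());
  node->Build();  // returns a single ProjectOp that keeps only "image" and "label"
  return Status::OK();
}
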
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class ProjectNode : public Dataset {
 public:
  /// \brief Constructor
  explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);

  /// \brief Destructor
  ~ProjectNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> columns_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/rename_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
                       const std::vector<std::string> &output_columns)
    : input_columns_(input_columns), output_columns_(output_columns) {
  this->children.push_back(child);
}

Status RenameNode::ValidateParams() {
  if (input_columns_.size() != output_columns_.size()) {
    std::string err_msg = "RenameNode: input and output columns must be the same size";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
  RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
  return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
  return node_ops;
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
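
A small sketch (not part of this patch): the only structural requirement here is that input_columns and output_columns have the same length and contain valid column names.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryRename(std::shared_ptr<Dataset> child) {
  std::vector<std::string> in_cols = {"image", "label"};
  std::vector<std::string> out_cols = {"data", "target"};  // must match in_cols in size
  auto node = std::make_shared<RenameNode>(child, in_cols, out_cols);
  return node->ValidateParams();
}
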
@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class RenameNode : public Dataset {
 public:
  /// \brief Constructor
  explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
                      const std::vector<std::string> &output_columns);

  /// \brief Destructor
  ~RenameNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::string> input_columns_;
  std::vector<std::string> output_columns_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
@@ -0,0 +1,54 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
  this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
  return node_ops;
}

Status RepeatNode::ValidateParams() {
  if (repeat_count_ <= 0 && repeat_count_ != -1) {
    std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
                          std::to_string(repeat_count_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
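
Sketch (not part of this patch): ValidateParams() accepts any positive count or the sentinel -1; the actual looping behaviour (where -1 appears to mean "repeat indefinitely") lives in RepeatOp, not in this IR node.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryRepeat(std::shared_ptr<Dataset> child) {
  auto twice = std::make_shared<RepeatNode>(child, 2);  // positive counts pass validation
  RETURN_IF_NOT_OK(twice->ValidateParams());
  auto forever = std::make_shared<RepeatNode>(child, -1);  // -1 is the only accepted non-positive value
  return forever->ValidateParams();
}
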
@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class RepeatNode : public Dataset {
 public:
  /// \brief Constructor
  explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);

  /// \brief Destructor
  ~RepeatNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  int32_t repeat_count_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
    : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
  this->children.push_back(child);
}

// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
                                                 rows_per_buffer_));
  return node_ops;
}

// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
  if (shuffle_size_ <= 1) {
    std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
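
Sketch (not part of this patch): shuffle_size must be greater than 1, and the shuffle seed is captured from GetSeed() at construction time rather than passed in by the caller.

// Hypothetical sketch (not part of this patch), assuming this file's includes and namespace.
Status TryShuffle(std::shared_ptr<Dataset> child) {
  auto node = std::make_shared<ShuffleNode>(child, 10000, true);  // 10000-row buffer, reshuffle each epoch
  // A shuffle_size of 0 or 1 would be rejected by ValidateParams().
  return node->ValidateParams();
}
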
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class ShuffleNode : public Dataset {
 public:
  ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);

  ~ShuffleNode() = default;

  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  Status ValidateParams() override;

 private:
  int32_t shuffle_size_;
  uint32_t shuffle_seed_;
  bool reset_every_epoch_;
};

}  // namespace api
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/datasetops/take_op.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| // Constructor for TakeNode | |||||
| TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) { | |||||
| this->children.push_back(child); | |||||
| } | |||||
| // Function to build the TakeOp | |||||
| std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() { | |||||
| // A vector containing shared pointers to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_)); | |||||
| return node_ops; | |||||
| } | |||||
| // Function to validate the parameters for TakeNode | |||||
| Status TakeNode::ValidateParams() { | |||||
| if (take_count_ <= 0 && take_count_ != -1) { | |||||
| std::string err_msg = | |||||
| "TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
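As a hedged illustration of the validation rules above, assuming the Dataset::Take() wrapper in datasets.h forwards its count argument to this node (ds stands for any previously built api::Dataset pipeline, as in the earlier sketch):

    ds = ds->Take(5);    // positive counts pass TakeNode::ValidateParams()
    ds = ds->Take(-1);   // -1 is the "take all rows" sentinel and also passes validation
    // A count of 0, or any negative value other than -1, is rejected with a syntax error.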
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace api { | |||||
| class TakeNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| /// \brief Destructor | |||||
| ~TakeNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t take_count_; | |||||
| }; | |||||
| } // namespace api | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TAKE_NODE_H_ | |||||
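For completeness, a small sketch of exercising the IR node directly rather than through the Dataset wrappers. The ImageFolder() factory call and its path are placeholder assumptions; the point is only to show ValidateParams() and Build() on TakeNode as declared above, inside a function that already includes take_node.h and the needed dataset headers.

    std::shared_ptr<Dataset> child = ImageFolder("/path/to/imagefolder");  // placeholder upstream node
    auto bad = std::make_shared<TakeNode>(child, 0);
    Status rc = bad->ValidateParams();  // fails: 0 is neither -1 nor a positive count
    auto good = std::make_shared<TakeNode>(child, 8);
    std::vector<std::shared_ptr<DatasetOp>> ops = good->Build();  // contains a single TakeOp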
| @@ -1201,85 +1201,6 @@ class VOCNode : public Dataset { | |||||
| // DERIVED DATASET CLASSES FOR DATASET OPS | // DERIVED DATASET CLASSES FOR DATASET OPS | ||||
| // (In alphabetical order) | // (In alphabetical order) | ||||
| #ifndef ENABLE_ANDROID | |||||
| class BucketBatchByLengthNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names, | |||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | |||||
| std::function<TensorRow(TensorRow)> element_length_function = nullptr, | |||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | |||||
| bool pad_to_bucket_boundary = false, bool drop_remainder = false); | |||||
| /// \brief Destructor | |||||
| ~BucketBatchByLengthNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> column_names_; | |||||
| std::vector<int32_t> bucket_boundaries_; | |||||
| std::vector<int32_t> bucket_batch_sizes_; | |||||
| std::function<TensorRow(TensorRow)> element_length_function_; | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_; | |||||
| bool pad_to_bucket_boundary_; | |||||
| bool drop_remainder_; | |||||
| }; | |||||
| class BuildVocabNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns, | |||||
| const std::pair<int64_t, int64_t> &freq_range, int64_t top_k, | |||||
| const std::vector<std::string> &special_tokens, bool special_first); | |||||
| /// \brief Destructor | |||||
| ~BuildVocabNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::shared_ptr<Vocab> vocab_; | |||||
| std::vector<std::string> columns_; | |||||
| std::pair<int64_t, int64_t> freq_range_; | |||||
| int64_t top_k_; | |||||
| std::vector<std::string> special_tokens_; | |||||
| bool special_first_; | |||||
| }; | |||||
| #endif | |||||
| class ConcatNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets); | |||||
| /// \brief Destructor | |||||
| ~ConcatNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::shared_ptr<Dataset>> datasets_; | |||||
| }; | |||||
| class MapNode : public Dataset { | class MapNode : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -1305,84 +1226,6 @@ class MapNode : public Dataset { | |||||
| std::vector<std::string> project_columns_; | std::vector<std::string> project_columns_; | ||||
| }; | }; | ||||
| class ProjectNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns); | |||||
| /// \brief Destructor | |||||
| ~ProjectNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> columns_; | |||||
| }; | |||||
| class RenameNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns, | |||||
| const std::vector<std::string> &output_columns); | |||||
| /// \brief Destructor | |||||
| ~RenameNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::vector<std::string> input_columns_; | |||||
| std::vector<std::string> output_columns_; | |||||
| }; | |||||
| class RepeatNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| /// \brief Destructor | |||||
| ~RepeatNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return The list of shared pointers to the newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t repeat_count_; | |||||
| }; | |||||
| class ShuffleNode : public Dataset { | |||||
| public: | |||||
| ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch); | |||||
| ~ShuffleNode() = default; | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t shuffle_size_; | |||||
| uint32_t shuffle_seed_; | |||||
| bool reset_every_epoch_; | |||||
| }; | |||||
| class SkipNode : public Dataset { | class SkipNode : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -1403,26 +1246,6 @@ class SkipNode : public Dataset { | |||||
| int32_t skip_count_; | int32_t skip_count_; | ||||
| }; | }; | ||||
| class TakeNode : public Dataset { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count); | |||||
| /// \brief Destructor | |||||
| ~TakeNode() = default; | |||||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||||
| /// \return shared pointer to the list of newly created DatasetOps | |||||
| std::vector<std::shared_ptr<DatasetOp>> Build() override; | |||||
| /// \brief Parameters validation | |||||
| /// \return Status Status::OK() if all the parameters are valid | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| int32_t take_count_; | |||||
| }; | |||||
| class ZipNode : public Dataset { | class ZipNode : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -18,6 +18,13 @@ | |||||
| #include "minddata/dataset/include/config.h" | #include "minddata/dataset/include/config.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::ShuffleMode; | using mindspore::dataset::ShuffleMode; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -17,6 +17,11 @@ | |||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -18,8 +18,17 @@ | |||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -16,10 +16,14 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| @@ -16,8 +16,14 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| using mindspore::dataset::TensorShape; | using mindspore::dataset::TensorShape; | ||||
| @@ -16,8 +16,13 @@ | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "minddata/dataset/include/datasets.h" | #include "minddata/dataset/include/datasets.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||
| @@ -19,6 +19,11 @@ | |||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::BorderType; | using mindspore::dataset::BorderType; | ||||
| @@ -18,8 +18,14 @@ | |||||
| #include "minddata/dataset/include/transforms.h" | #include "minddata/dataset/include/transforms.h" | ||||
| #include "minddata/dataset/include/vision.h" | #include "minddata/dataset/include/vision.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| using namespace mindspore::dataset::api; | using namespace mindspore::dataset::api; | ||||
| using mindspore::dataset::BorderType; | using mindspore::dataset::BorderType; | ||||
| using mindspore::dataset::Tensor; | using mindspore::dataset::Tensor; | ||||