
Add IR nodes for shuffle, take, repeat, bucket_batch_by_length, build_vocab, project, concat, and rename

concat, bucketbatch, project, rename

fix ci round 1

fix ci round 2

fix up

fix ci
tags/v1.1.0
Zirui Wu 5 years ago
parent
commit
d471552fc5
29 changed files with 1085 additions and 493 deletions
  1. +14 -307  mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc
  3. +19 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt
  4. +121 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc
  5. +64 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h
  6. +83 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc
  7. +61 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h
  8. +60 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc
  9. +53 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h
  10. +57 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc
  11. +54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h
  12. +59 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc
  13. +56 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h
  14. +54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc
  15. +56 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  16. +59 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc
  17. +52 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h
  18. +1 -1  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc
  19. +55 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc
  20. +54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h
  21. +0 -177  mindspore/ccsrc/minddata/dataset/include/datasets.h
  22. +7 -0  tests/ut/cpp/dataset/c_api_dataset_config_test.cc
  23. +5 -0  tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
  24. +10 -1  tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  25. +5 -1  tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc
  26. +7 -1  tests/ut/cpp/dataset/c_api_datasets_test.cc
  27. +6 -1  tests/ut/cpp/dataset/c_api_samplers_test.cc
  28. +5 -0  tests/ut/cpp/dataset/c_api_transforms_test.cc
  29. +7 -1  tests/ut/cpp/dataset/c_api_vision_test.cc

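For orientation, below is a rough sketch of how the relocated IR nodes are reached from the public C++ API; the factory names (ImageFolder, RandomSampler, Repeat, Shuffle, Take, Project, Rename, CreateIterator) and their signatures are assumed from include/datasets.h of this release, and the path, sampler arguments, and column names are placeholders.

// Illustrative sketch (not from the commit): each chained call below constructs one of the
// IR node classes split out by this change (RepeatNode, ShuffleNode, TakeNode, ProjectNode,
// RenameNode) on top of the ImageFolderNode leaf.
#include <memory>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset::api;

int main() {
  // Leaf IR node (ImageFolderNode) behind the ImageFolder() factory; arguments are placeholders.
  std::shared_ptr<Dataset> ds = ImageFolder("/path/to/imagefolder", true, RandomSampler(false, 10));

  ds = ds->Repeat(2);                   // RepeatNode
  ds = ds->Shuffle(4);                  // ShuffleNode (shuffle_size must be > 1)
  ds = ds->Take(5);                     // TakeNode
  ds = ds->Project({"image"});          // ProjectNode
  ds = ds->Rename({"image"}, {"img"});  // RenameNode

  // Each node's ValidateParams() guards the arguments above and rejects invalid input.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  return iter == nullptr ? 1 : 0;
}
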
+14 -307  mindspore/ccsrc/minddata/dataset/api/datasets.cc

@@ -41,19 +41,8 @@
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#endif
// Dataset operator headers (in alphabetical order)
#include "minddata/dataset/engine/datasetops/batch_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#endif
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/concat_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"

// Sampler headers (in alphabetical order)
@@ -61,8 +50,21 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"

// IR nodes
// IR non-leaf nodes
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"

#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#endif

// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"

#include "minddata/dataset/core/config_manager.h"
@@ -1759,175 +1761,9 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
#endif

#ifndef ENABLE_ANDROID
BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder)
: column_names_(column_names),
bucket_boundaries_(bucket_boundaries),
bucket_batch_sizes_(bucket_batch_sizes),
element_length_function_(element_length_function),
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder) {
this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::shared_ptr<TensorOp> c_func;
if (element_length_function_ != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function_);
} else {
c_func = nullptr;
}
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
c_func, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_));
return node_ops;
}

Status BucketBatchByLengthNode::ValidateParams() {
if (element_length_function_ == nullptr && column_names_.size() != 1) {
std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
std::to_string(column_names_.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

// Check bucket_boundaries: must be positive and strictly increasing
if (bucket_boundaries_.empty()) {
std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (int i = 0; i < bucket_boundaries_.size(); i++) {
if (bucket_boundaries_[i] <= 0) {
std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: ";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
<< i << " was: " << bucket_boundaries_[i];
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing.";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
<< i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
<< " respectively.";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

// Check bucket_batch_sizes: must be positive
if (bucket_batch_sizes_.empty()) {
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
std::string err_msg =
"BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) {
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
: vocab_(vocab),
columns_(columns),
freq_range_(freq_range),
top_k_(top_k),
special_tokens_(special_tokens),
special_first_(special_first) {
this->children.push_back(child);
}

// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
node_ops.push_back(build_vocab_op);
return node_ops;
}

Status BuildVocabNode::ValidateParams() {
if (vocab_ == nullptr) {
std::string err_msg = "BuildVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (top_k_ <= 0) {
std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
<< "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (!columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
}

return Status::OK();
}
#endif

// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}

Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}

MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns)
@@ -1984,110 +1820,6 @@ Status MapNode::ValidateParams() {
return Status::OK();
}

// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
this->children.push_back(child);
}

Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}

// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);
}

Status RenameNode::ValidateParams() {
if (input_columns_.size() != output_columns_.size()) {
std::string err_msg = "RenameNode: input and output columns must be the same size";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}

RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}

Status RepeatNode::ValidateParams() {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
}

// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
return node_ops;
}

// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
if (shuffle_size_ <= 1) {
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
this->children.push_back(child);
@@ -2113,31 +1845,6 @@ Status SkipNode::ValidateParams() {
return Status::OK();
}

// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
}

// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
return node_ops;
}

// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg =
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

// Function to build ZipOp
ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
for (auto dataset : datasets_) {


+1 -1  mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+19 -2  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt

@@ -1,5 +1,22 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_subdirectory(source)
add_library(engine-ir-datasetops OBJECT
batch_node.cc)

set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
batch_node.cc
concat_node.cc
project_node.cc
rename_node.cc
repeat_node.cc
shuffle_node.cc
take_node.cc
)

if (NOT ENABLE_ANDROID)
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES}
bucket_batch_by_length_node.cc
build_vocab_node.cc)
endif ()

add_library(engine-ir-datasetops OBJECT ${DATASET_ENGINE_IR_DATASETOPS_SRC_FILES})

+121 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.cc

@@ -0,0 +1,121 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
BucketBatchByLengthNode::BucketBatchByLengthNode(
std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
bool drop_remainder)
: column_names_(column_names),
bucket_boundaries_(bucket_boundaries),
bucket_batch_sizes_(bucket_batch_sizes),
element_length_function_(element_length_function),
pad_info_(pad_info),
pad_to_bucket_boundary_(pad_to_bucket_boundary),
drop_remainder_(drop_remainder) {
this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::shared_ptr<TensorOp> c_func;
if (element_length_function_ != nullptr) {
c_func = std::make_shared<CFuncOp>(element_length_function_);
} else {
c_func = nullptr;
}
node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
c_func, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_));
return node_ops;
}

Status BucketBatchByLengthNode::ValidateParams() {
if (element_length_function_ == nullptr && column_names_.size() != 1) {
std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
std::to_string(column_names_.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

// Check bucket_boundaries: must be positive and strictly increasing
if (bucket_boundaries_.empty()) {
std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

for (int i = 0; i < bucket_boundaries_.size(); i++) {
if (bucket_boundaries_[i] <= 0) {
std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: ";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
<< i << " was: " << bucket_boundaries_[i];
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing.";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
<< i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
<< " respectively.";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}

// Check bucket_batch_sizes: must be positive
if (bucket_batch_sizes_.empty()) {
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
std::string err_msg =
"BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) {
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore
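
A small, test-style illustration of the constraints ValidateParams() enforces above: bucket_boundaries must be non-empty, positive, and strictly increasing, and bucket_batch_sizes must hold exactly bucket_boundaries.size() + 1 positive entries. This hedged sketch (not from the commit) constructs the node directly with placeholder data; the null child is only a stand-in, since this check does not inspect the child.

// Hedged sketch: exercising BucketBatchByLengthNode::ValidateParams() with placeholder values.
std::shared_ptr<Dataset> child = nullptr;  // placeholder; ValidateParams() ignores the child

// 3 boundaries (positive, strictly increasing) and 3 + 1 = 4 positive batch sizes: accepted.
auto ok = std::make_shared<BucketBatchByLengthNode>(
    child, std::vector<std::string>{"image"},
    std::vector<int32_t>{10, 20, 30}, std::vector<int32_t>{4, 4, 4, 4});
Status rc_ok = ok->ValidateParams();    // expected: Status::OK()

// Only 3 batch sizes for 3 boundaries: rejected with a syntax-error status.
auto bad = std::make_shared<BucketBatchByLengthNode>(
    child, std::vector<std::string>{"image"},
    std::vector<int32_t>{10, 20, 30}, std::vector<int32_t>{4, 4, 4});
Status rc_bad = bad->ValidateParams();  // expected: error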

+64 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h

@@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {
class BucketBatchByLengthNode : public Dataset {
public:
/// \brief Constructor
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false);

/// \brief Destructor
~BucketBatchByLengthNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
std::function<TensorRow(TensorRow)> element_length_function_;
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUCKET_BATCH_BY_LENGTH_NODE_H_

+83 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.cc

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/build_vocab_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
: vocab_(vocab),
columns_(columns),
freq_range_(freq_range),
top_k_(top_k),
special_tokens_(special_tokens),
special_first_(special_first) {
this->children.push_back(child);
}

// Function to build BuildVocabNode
std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
node_ops.push_back(build_vocab_op);
return node_ops;
}

Status BuildVocabNode::ValidateParams() {
if (vocab_ == nullptr) {
std::string err_msg = "BuildVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (top_k_ <= 0) {
std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
<< "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (!columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
}

return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore
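
For the checks above (vocab must be non-null, top_k positive, freq_range within 0 <= a <= b, and non-empty column names when given), here is a hedged sketch (not from the commit) of an argument set that passes. The null child is a placeholder because ValidateParams() does not inspect it, and the Vocab instance is assumed to be default-constructible purely for illustration.

// Hedged sketch: a valid argument set for BuildVocabNode::ValidateParams().
std::shared_ptr<Dataset> child = nullptr;                  // placeholder child
std::shared_ptr<Vocab> vocab = std::make_shared<Vocab>();  // assumed default-constructible

auto node = std::make_shared<BuildVocabNode>(
    child, vocab,
    std::vector<std::string>{"text"},            // columns: non-empty names
    std::pair<int64_t, int64_t>{0, 100},         // freq_range: 0 <= 0 <= 100
    10,                                          // top_k must be positive
    std::vector<std::string>{"<pad>", "<unk>"},  // special_tokens
    true);                                       // special_first
Status rc = node->ValidateParams();              // expected: Status::OK()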

+61 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class BuildVocabNode : public Dataset {
public:
/// \brief Constructor
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first);

/// \brief Destructor
~BuildVocabNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_;
std::pair<int64_t, int64_t> freq_range_;
int64_t top_k_;
std::vector<std::string> special_tokens_;
bool special_first_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_BUILD_VOCAB_NODE_H_

+60 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.cc

@@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/concat_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/concat_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build ConcatOp
ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
this->children = datasets_;
}

Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+53 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h

@@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class ConcatNode : public Dataset {
public:
/// \brief Constructor
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);

/// \brief Destructor
~ConcatNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::shared_ptr<Dataset>> datasets_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CONCAT_NODE_H_

+57 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/project_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/project_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Function to build ProjectOp
ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
this->children.push_back(child);
}

Status ProjectNode::ValidateParams() {
if (columns_.empty()) {
std::string err_msg = "ProjectNode: No columns are specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> ProjectNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ProjectOp>(columns_));
return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

class ProjectNode : public Dataset {
public:
/// \brief Constructor
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);

/// \brief Destructor
~ProjectNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_PROJECT_NODE_H_

+59 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.cc

@@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/rename_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/rename_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Function to build RenameOp
RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);
}

Status RenameNode::ValidateParams() {
if (input_columns_.size() != output_columns_.size()) {
std::string err_msg = "RenameNode: input and output columns must be the same size";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));

RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));

return Status::OK();
}

std::vector<std::shared_ptr<DatasetOp>> RenameNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
return node_ops;
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+56 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/rename_node.h

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

class RenameNode : public Dataset {
public:
/// \brief Constructor
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);

/// \brief Destructor
~RenameNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_RENAME_NODE_H_

+54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.cc

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/repeat_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

RepeatNode::RepeatNode(std::shared_ptr<Dataset> child, int32_t count) : repeat_count_(count) {
this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
return node_ops;
}

Status RepeatNode::ValidateParams() {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

+56 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

class RepeatNode : public Dataset {
public:
/// \brief Constructor
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);

/// \brief Destructor
~RepeatNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
int32_t repeat_count_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_REPEAT_NODE_H_

+59 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.cc

@@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Constructor for ShuffleNode
ShuffleNode::ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch)
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
this->children.push_back(child);
}

// Function to build the ShuffleOp
std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
return node_ops;
}

// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
if (shuffle_size_ <= 1) {
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore

+52 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/shuffle_node.h

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

class ShuffleNode : public Dataset {
public:
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);

~ShuffleNode() = default;

std::vector<std::shared_ptr<DatasetOp>> Build() override;

Status ValidateParams() override;

private:
int32_t shuffle_size_;
uint32_t shuffle_seed_;
bool reset_every_epoch_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_SHUFFLE_NODE_H_

+1 -1  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+55 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc

@@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/dataset/engine/ir/datasetops/take_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/take_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {
// Constructor for TakeNode
TakeNode::TakeNode(std::shared_ptr<Dataset> child, int32_t count) : take_count_(count) {
this->children.push_back(child);
}

// Function to build the TakeOp
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;

node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
return node_ops;
}

// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg =
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore
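
As the check above states, take_count must be -1 (take everything) or a positive number. A brief hedged sketch (not from the commit) of how this surfaces through the Dataset::Take factory, whose signature is assumed from include/datasets.h; the dataset and counts are placeholders.

// Hedged sketch: TakeNode count semantics via the public API.
std::shared_ptr<Dataset> ds = ImageFolder("/path/to/data", true, RandomSampler(false, 10));
auto take_five = ds->Take(5);    // TakeNode with take_count_ == 5: keep the first 5 rows
auto take_all  = ds->Take(-1);   // -1 is accepted and means "take all rows"
// Take(0) or any other non-positive count other than -1 fails TakeNode::ValidateParams().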

+54 -0  mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {

class TakeNode : public Dataset {
public:
/// \brief Constructor
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);

/// \brief Destructor
~TakeNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
int32_t take_count_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TAKE_NODE_H_

+0 -177  mindspore/ccsrc/minddata/dataset/include/datasets.h

@@ -1201,85 +1201,6 @@ class VOCNode : public Dataset {
// DERIVED DATASET CLASSES FOR DATASET OPS
// (In alphabetical order)

#ifndef ENABLE_ANDROID
class BucketBatchByLengthNode : public Dataset {
public:
/// \brief Constructor
BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},
bool pad_to_bucket_boundary = false, bool drop_remainder = false);

/// \brief Destructor
~BucketBatchByLengthNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> column_names_;
std::vector<int32_t> bucket_boundaries_;
std::vector<int32_t> bucket_batch_sizes_;
std::function<TensorRow(TensorRow)> element_length_function_;
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_;
bool pad_to_bucket_boundary_;
bool drop_remainder_;
};

class BuildVocabNode : public Dataset {
public:
/// \brief Constructor
BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first);

/// \brief Destructor
~BuildVocabNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::shared_ptr<Vocab> vocab_;
std::vector<std::string> columns_;
std::pair<int64_t, int64_t> freq_range_;
int64_t top_k_;
std::vector<std::string> special_tokens_;
bool special_first_;
};
#endif

class ConcatNode : public Dataset {
public:
/// \brief Constructor
explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);

/// \brief Destructor
~ConcatNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::shared_ptr<Dataset>> datasets_;
};

class MapNode : public Dataset {
public:
/// \brief Constructor
@@ -1305,84 +1226,6 @@ class MapNode : public Dataset {
std::vector<std::string> project_columns_;
};

class ProjectNode : public Dataset {
public:
/// \brief Constructor
explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);

/// \brief Destructor
~ProjectNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> columns_;
};

class RenameNode : public Dataset {
public:
/// \brief Constructor
explicit RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns);

/// \brief Destructor
~RenameNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
};

class RepeatNode : public Dataset {
public:
/// \brief Constructor
explicit RepeatNode(std::shared_ptr<Dataset> child, int32_t count);

/// \brief Destructor
~RepeatNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
int32_t repeat_count_;
};

class ShuffleNode : public Dataset {
public:
/// \brief Constructor
ShuffleNode(std::shared_ptr<Dataset> child, int32_t shuffle_size, bool reset_every_epoch);

/// \brief Destructor
~ShuffleNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
int32_t shuffle_size_;
uint32_t shuffle_seed_;
bool reset_every_epoch_;
};

class SkipNode : public Dataset {
public:
/// \brief Constructor
@@ -1403,26 +1246,6 @@ class SkipNode : public Dataset {
int32_t skip_count_;
};

class TakeNode : public Dataset {
public:
/// \brief Constructor
explicit TakeNode(std::shared_ptr<Dataset> child, int32_t count);

/// \brief Destructor
~TakeNode() = default;

/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;

/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;

private:
int32_t take_count_;
};

class ZipNode : public Dataset {
public:
/// \brief Constructor


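The node classes moved out of datasets.h above all share the same two-step contract: ValidateParams() returns Status::OK() only when the stored arguments are valid, and Build() creates the runtime DatasetOps for that stage. As a rough sketch of driving two of these nodes directly, assuming they stay in the mindspore::dataset::api namespace after the move (as the tests' using-directives below suggest), with the child pipeline, column names, and take count left as placeholders rather than anything taken from this PR:

// Sketch only: "child", the column names, and the take count are placeholders.
// It relies solely on the constructors declared above, RenameNode(child,
// input_columns, output_columns) and TakeNode(child, count).
#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"

namespace api = mindspore::dataset::api;
using mindspore::dataset::Status;

Status SketchRenameThenTake(std::shared_ptr<api::Dataset> child) {
  // Rename the "image" column to "data" (column names invented for the example).
  auto rename = std::make_shared<api::RenameNode>(
    child, std::vector<std::string>{"image"}, std::vector<std::string>{"data"});
  Status rc = rename->ValidateParams();  // Status::OK() when the column lists are valid
  if (!rc.IsOk()) {
    return rc;
  }
  // Keep only the first 10 rows of the renamed stream.
  auto take = std::make_shared<api::TakeNode>(rename, 10);
  return take->ValidateParams();
}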
+ 7
- 0
tests/ut/cpp/dataset/c_api_dataset_config_test.cc View File

@@ -18,6 +18,13 @@
#include "minddata/dataset/include/config.h"
#include "minddata/dataset/include/datasets.h"

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::ShuffleMode;
using mindspore::dataset::Tensor;


+ 5
- 0
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc View File

@@ -17,6 +17,11 @@
#include "minddata/dataset/include/datasets.h"

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;


+ 10
- 1
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -18,8 +18,17 @@
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/vision.h"

#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;



+ 5
- 1
tests/ut/cpp/dataset/c_api_dataset_tfrecord_test.cc View File

@@ -16,10 +16,14 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/vision.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

using namespace mindspore::dataset;
using namespace mindspore::dataset::api;


+ 7
- 1
tests/ut/cpp/dataset/c_api_datasets_test.cc View File

@@ -16,8 +16,14 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"

#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;


+ 6
- 1
tests/ut/cpp/dataset/c_api_samplers_test.cc View File

@@ -16,8 +16,13 @@
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"

#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;


+ 5
- 0
tests/ut/cpp/dataset/c_api_transforms_test.cc View File

@@ -19,6 +19,11 @@
#include "minddata/dataset/include/vision.h"

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::BorderType;


+ 7
- 1
tests/ut/cpp/dataset/c_api_vision_test.cc View File

@@ -18,8 +18,14 @@
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/vision.h"

#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"

using namespace mindspore::dataset::api;
using mindspore::dataset::BorderType;
using mindspore::dataset::Tensor;
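
The include updates above are mechanical: since the node classes no longer live in include/datasets.h, every test translation unit that names them has to include the matching header under engine/ir/datasetops. For completeness, here is a rough sketch, not taken from these tests (the helper name, the choice of RepeatNode, and the child pipeline are assumptions), of how a test could drive a node through both halves of its contract, ValidateParams() followed by Build():

// Illustrative sketch: only the RepeatNode(child, count), ValidateParams(), and
// Build() signatures shown earlier in this diff are relied upon.
#include <memory>
#include <vector>

#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"

namespace api = mindspore::dataset::api;
using mindspore::dataset::DatasetOp;
using mindspore::dataset::Status;

// Validates a RepeatNode over `child` and, if it is well formed, returns the
// runtime DatasetOps that Build() creates for the repeat stage.
std::vector<std::shared_ptr<DatasetOp>> SketchBuildRepeat(std::shared_ptr<api::Dataset> child,
                                                          int32_t count) {
  auto repeat = std::make_shared<api::RepeatNode>(child, count);
  Status rc = repeat->ValidateParams();
  if (!rc.IsOk()) {
    return {};  // invalid parameters: nothing to build
  }
  return repeat->Build();
}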

