Browse Source

!9482 Remove Repeat(1), Take(-1), and Skip(0) in IR optimizer

From: @nsyca
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a13f9acb79
80 changed files with 527 additions and 426 deletions
  1. +0
    -13
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +9
    -9
      mindspore/ccsrc/minddata/dataset/engine/data_schema.h
  3. +6
    -6
      mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h
  4. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h
  5. +15
    -15
      mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
  6. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h
  7. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h
  8. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h
  9. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
  10. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h
  11. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h
  12. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h
  13. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
  14. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
  15. +10
    -10
      mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h
  16. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h
  17. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
  18. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h
  19. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h
  20. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
  21. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h
  22. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h
  23. +20
    -20
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h
  24. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
  25. +9
    -9
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
  26. +22
    -22
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h
  27. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h
  28. +11
    -11
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
  29. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h
  30. +13
    -13
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h
  31. +3
    -3
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h
  32. +16
    -16
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
  33. +8
    -8
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
  34. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
  35. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h
  36. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h
  37. +7
    -7
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
  38. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h
  39. +15
    -15
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
  40. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h
  41. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h
  42. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
  43. +10
    -10
      mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
  44. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h
  45. +9
    -9
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h
  46. +9
    -9
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h
  47. +18
    -18
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h
  48. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h
  49. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h
  50. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/gnn/node.h
  51. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h
  52. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h
  53. +4
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h
  54. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h
  55. +6
    -6
      mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
  56. +14
    -14
      mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
  57. +20
    -20
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h
  58. +21
    -21
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
  59. +10
    -10
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h
  60. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h
  61. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h
  62. +31
    -1
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc
  63. +22
    -4
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h
  64. +4
    -4
      mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
  65. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h
  66. +1
    -1
      mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h
  67. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h
  68. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h
  69. +5
    -5
      mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h
  70. +1
    -1
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  71. +5
    -5
      mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
  72. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h
  73. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h
  74. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h
  75. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h
  76. +1
    -1
      mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h
  77. +6
    -6
      mindspore/ccsrc/minddata/dataset/util/data_helper.h
  78. +0
    -4
      mindspore/dataset/engine/datasets.py
  79. +45
    -0
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  80. +16
    -3
      tests/ut/python/dataset/test_take.py

+ 0
- 13
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -490,12 +490,6 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s
#endif #endif


RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) { RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
// Workaround for repeat == 1, do not inject repeat.
if (count == 1) {
ir_node_ = input->IRNode();
return;
}

auto ds = std::make_shared<RepeatNode>(input->IRNode(), count); auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);


ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@@ -516,13 +510,6 @@ SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
} }


TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) { TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
// If count is greater than the number of element in dataset or equal to -1,
// all the element in dataset will be taken
if (count == -1) {
ir_node_ = input->IRNode();
return;
}

auto ds = std::make_shared<TakeNode>(input->IRNode(), count); auto ds = std::make_shared<TakeNode>(input->IRNode(), count);


ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);


+ 9
- 9
mindspore/ccsrc/minddata/dataset/engine/data_schema.h View File

@@ -66,7 +66,7 @@ class ColDescriptor {
/// an unknown dimension, then the output shape returned shall resolve dimensions as needed. /// an unknown dimension, then the output shape returned shall resolve dimensions as needed.
/// \param[in] num_elements - The number of elements in the data for a Tensor /// \param[in] num_elements - The number of elements in the data for a Tensor
/// \param[inout] out_shape - The materialized output Tensor shape /// \param[inout] out_shape - The materialized output Tensor shape
/// \return Status - The error code return
/// \return Status The status code returned
Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const;


/// \brief << Stream output operator overload /// \brief << Stream output operator overload
@@ -124,13 +124,13 @@ class DataSchema {
/// \brief Parses a schema json file and populates the columns and meta info. /// \brief Parses a schema json file and populates the columns and meta info.
/// \param[in] schema_file_path - the schema file that has the column's info to load /// \param[in] schema_file_path - the schema file that has the column's info to load
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
/// \return Status - The error code return
/// \return Status The status code returned
Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load); Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load);


/// \brief Parses a schema JSON string and populates the columns and meta info. /// \brief Parses a schema JSON string and populates the columns and meta info.
/// \param[in] schema_json_string - the schema file that has the column's info to load /// \param[in] schema_json_string - the schema file that has the column's info to load
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
/// \return Status - The error code return
/// \return Status The status code returned
Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load); Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load);


/// \brief A print method typically used for debugging /// \brief A print method typically used for debugging
@@ -148,7 +148,7 @@ class DataSchema {


/// \brief Adds a column descriptor to the schema /// \brief Adds a column descriptor to the schema
/// \param[in] cd - The ColDescriptor to add /// \param[in] cd - The ColDescriptor to add
/// \return Status - The error code return
/// \return Status The status code returned
Status AddColumn(const ColDescriptor &cd); Status AddColumn(const ColDescriptor &cd);


/// \brief getter /// \brief getter
@@ -169,7 +169,7 @@ class DataSchema {


/// \brief Loops through all columns in the schema and returns a map with the column name to column index number. /// \brief Loops through all columns in the schema and returns a map with the column name to column index number.
/// \param[inout] out_column_name_map - The output map of columns names to column index /// \param[inout] out_column_name_map - The output map of columns names to column index
/// \return Status - The error code return
/// \return Status The status code returned
Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map); Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);


private: private:
@@ -177,7 +177,7 @@ class DataSchema {
/// does not follow any particular order (json standard does not enforce any ordering protocol). /// does not follow any particular order (json standard does not enforce any ordering protocol).
/// This one produces a schema that contains all of the columns from the schema file. /// This one produces a schema that contains all of the columns from the schema file.
/// \param[in] column_tree - The nlohmann tree from the json file to parse /// \param[in] column_tree - The nlohmann tree from the json file to parse
/// \return Status - The error code return
/// \return Status The status code returned
Status AnyOrderLoad(nlohmann::json column_tree); Status AnyOrderLoad(nlohmann::json column_tree);


/// \brief Internal helper function. For each input column name, perform a lookup to the json document to /// \brief Internal helper function. For each input column name, perform a lookup to the json document to
@@ -185,18 +185,18 @@ class DataSchema {
/// descriptor and add to the schema in the order in which the input column names are given. /// descriptor and add to the schema in the order in which the input column names are given.
/// \param[in] column_tree - The nlohmann tree from the json file to parse /// \param[in] column_tree - The nlohmann tree from the json file to parse
/// \param[in] columns_to_load - list of strings for the columns to add to the schema /// \param[in] columns_to_load - list of strings for the columns to add to the schema
/// \return Status - The error code return
/// \return Status The status code returned
Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load); Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load);


/// \brief Internal helper function. Given the json tree for a given column, load it into our schema. /// \brief Internal helper function. Given the json tree for a given column, load it into our schema.
/// \param[in] columnTree - The nlohmann child tree for a given column to load. /// \param[in] columnTree - The nlohmann child tree for a given column to load.
/// \param[in] col_name - The string name of the column for that subtree. /// \param[in] col_name - The string name of the column for that subtree.
/// \return Status - The error code return
/// \return Status The status code returned
Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name);


/// \brief Internal helper function. Performs sanity checks on the json file setup. /// \brief Internal helper function. Performs sanity checks on the json file setup.
/// \param[in] js - The nlohmann tree for the schema file /// \param[in] js - The nlohmann tree for the schema file
/// \return Status - The error code return
/// \return Status The status code returned
Status PreLoadExceptionCheck(const nlohmann::json &js); Status PreLoadExceptionCheck(const nlohmann::json &js);


std::vector<ColDescriptor> col_descs_; // Vector of column descriptors std::vector<ColDescriptor> col_descs_; // Vector of column descriptors


+ 6
- 6
mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h View File

@@ -53,7 +53,7 @@ class IteratorBase {
// functionality exists in the derived versions of this function. // functionality exists in the derived versions of this function.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
// @note The position of a Tensor/column might be different from the initial column order // @note The position of a Tensor/column might be different from the initial column order
// in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
// the column ordering. // the column ordering.
@@ -97,17 +97,17 @@ class DatasetIterator : public IteratorBase {
// from the tree root node directly. // from the tree root node directly.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
Status FetchNextTensorRow(TensorRow *out_row) override; Status FetchNextTensorRow(TensorRow *out_row) override;


// Fetches the next tensor row into device row, and returns it's shape. // Fetches the next tensor row into device row, and returns it's shape.
// @param out_shapes - A vector of tensor shapes (one shape per column) // @param out_shapes - A vector of tensor shapes (one shape per column)
// @return Status - The error code return
// @return Status The status code returned
Status GetOutputShapes(std::vector<TensorShape> *out_shapes); Status GetOutputShapes(std::vector<TensorShape> *out_shapes);


// Fetches the next tensor row into device row, and returns it's shape. // Fetches the next tensor row into device row, and returns it's shape.
// @param outShapes - A vector of tensor shapes (one shape per column) // @param outShapes - A vector of tensor shapes (one shape per column)
// @return Status - The error code return
// @return Status The status code returned
Status GetOutputTypes(std::vector<DataType> *out_types); Status GetOutputTypes(std::vector<DataType> *out_types);


// Getter // Getter
@@ -140,12 +140,12 @@ class ChildIterator : public IteratorBase {
// only from the child/worker id as given from the constructor. // only from the child/worker id as given from the constructor.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
Status FetchNextTensorRow(TensorRow *out_row) override; Status FetchNextTensorRow(TensorRow *out_row) override;


// This function drains buffer until next eoe has been received. // This function drains buffer until next eoe has been received.
// It will be a no-op if the previous row returned is empty. // It will be a no-op if the previous row returned is empty.
// @return Status - The error code return
// @return Status The status code returned
Status Drain(); Status Drain();


// Getter // Getter


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h View File

@@ -134,7 +134,7 @@ class BarrierOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Handles preprocessing of the main loop, used when starting new epoch // Handles preprocessing of the main loop, used when starting new epoch


+ 15
- 15
mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h View File

@@ -112,12 +112,12 @@ class BatchOp : public ParallelOp {
#endif #endif


// @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
// @return Status - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BatchOp> *); Status Build(std::shared_ptr<BatchOp> *);


private: private:
// Sanity check for builder class args // Sanity check for builder class args
// @return Status - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


bool builder_drop_; bool builder_drop_;
@@ -167,11 +167,11 @@ class BatchOp : public ParallelOp {
~BatchOp() {} ~BatchOp() {}


// @param int32_t workerId // @param int32_t workerId
// @return Status - The error code return
// @return Status The status code returned
Status EofReceived(int32_t) override; Status EofReceived(int32_t) override;


// @param int32_t workerId // @param int32_t workerId
// @return Status - The error code return
// @return Status The status code returned
Status EoeReceived(int32_t) override; Status EoeReceived(int32_t) override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -190,7 +190,7 @@ class BatchOp : public ParallelOp {
} }


// Main loop of batch // Main loop of batch
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.
@@ -214,14 +214,14 @@ class BatchOp : public ParallelOp {
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
// @param int32_t size - batch_size // @param int32_t size - batch_size
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
// @return Status The status code returned
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size); dsize_t batch_size);


// @param table // @param table
// @param const PadInfo &pad_info pad info // @param const PadInfo &pad_info pad info
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
// @return Status The status code returned
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map); const std::unordered_map<std::string, int32_t> &column_name_id_map);


@@ -233,18 +233,18 @@ class BatchOp : public ParallelOp {
private: private:
// Worker thread for doing the memcpy of batch // Worker thread for doing the memcpy of batch
// @param int32_t param workerId // @param int32_t param workerId
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Generate buffer with batched tensors // Generate buffer with batched tensors
// @return Status - The error code return
// @return Status The status code returned
Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
std::unique_ptr<DataBuffer> *db); std::unique_ptr<DataBuffer> *db);


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
// Function that calls pyfunc to perform map on batch // Function that calls pyfunc to perform map on batch
// @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
// @return Status - The error code return
// @return Status The status code returned
Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair); Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
#endif #endif


@@ -253,7 +253,7 @@ class BatchOp : public ParallelOp {
// @param std::set<int32_t> *cols, col ids to perform pad on // @param std::set<int32_t> *cols, col ids to perform pad on
// @param std::vector<float> *vals, default padding value for each column // @param std::vector<float> *vals, default padding value for each column
// @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
// @return Status - The error code return
// @return Status The status code returned
static Status UnpackPadInfo(const PadInfo &pad_info, static Status UnpackPadInfo(const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map, const std::unordered_map<std::string, int32_t> &column_name_id_map,
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
@@ -264,20 +264,20 @@ class BatchOp : public ParallelOp {
int32_t num_consumers() const override { return 1; } int32_t num_consumers() const override { return 1; }


// get the batch size for next batch // get the batch size for next batch
// @return Status - The error code return
// @return Status The status code returned
Status GetBatchSize(int32_t *batch_size, CBatchInfo info); Status GetBatchSize(int32_t *batch_size, CBatchInfo info);


// Do the initialization of all queues then start all worker threads // Do the initialization of all queues then start all worker threads
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
// Invoke batch size function with current BatchInfo to generate batch size. // Invoke batch size function with current BatchInfo to generate batch size.
// @return Status - The error code return
// @return Status The status code returned
Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);


// Invoke batch map function with current BatchInfo to generate tensors to batch. // Invoke batch map function with current BatchInfo to generate tensors to batch.
// @return Status - The error code return
// @return Status The status code returned
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
#endif #endif




+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h View File

@@ -107,7 +107,7 @@ class BucketBatchByLengthOp : public PipelineOp {


// Might need to batch remaining buckets after receiving eoe, so override this method. // Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId // @param int32_t workerId
// @return Status - The error code returned
// @return Status The status code returned
Status EoeReceived(int32_t) override; Status EoeReceived(int32_t) override;


std::string Name() const override { return kBucketBatchByLengthOp; } std::string Name() const override { return kBucketBatchByLengthOp; }
@@ -123,7 +123,7 @@ class BucketBatchByLengthOp : public PipelineOp {
} }


// Main loop of batch // Main loop of batch
// @return Status - The error code returned
// @return Status The status code returned
Status operator()() override; Status operator()() override;


private: private:


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h View File

@@ -104,7 +104,7 @@ class BuildSentencePieceVocabOp : public PipelineOp {


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op); Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op);


private: private:


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h View File

@@ -110,7 +110,7 @@ class BuildVocabOp : public ParallelOp {


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BuildVocabOp> *op); Status Build(std::shared_ptr<BuildVocabOp> *op);


private: private:


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h View File

@@ -53,7 +53,7 @@ class CacheBase : public ParallelOp {
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
/// info from it's previous execution and then initializes itself so that it can be executed /// info from it's previous execution and then initializes itself so that it can be executed
/// again. /// again.
/// \return Status - The error code return
/// \return Status The status code returned
Status Reset() override; Status Reset() override;


/// \brief A print method typically used for debugging /// \brief A print method typically used for debugging


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h View File

@@ -80,7 +80,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
std::shared_ptr<SamplerRT> build_sampler_; std::shared_ptr<SamplerRT> build_sampler_;


// Check if the required parameters are set by the builder. // Check if the required parameters are set by the builder.
// \return Status The error code return
// \return Status The status code returned
Status SanityCheck() const; Status SanityCheck() const;
}; };
/// \brief Constructor /// \brief Constructor


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h View File

@@ -136,7 +136,7 @@ class CacheMergeOp : public ParallelOp {
std::shared_ptr<SamplerRT> build_sampler_; std::shared_ptr<SamplerRT> build_sampler_;


/// Check if the required parameters are set by the builder. /// Check if the required parameters are set by the builder.
/// \return Status The error code return
/// \return Status The status code returned
Status SanityCheck() const; Status SanityCheck() const;
}; };


@@ -189,7 +189,7 @@ class CacheMergeOp : public ParallelOp {


/// \brief Base-class override for handling cases when an eof is received. /// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;


protected: protected:


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h View File

@@ -99,7 +99,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
std::shared_ptr<SamplerRT> build_sampler_; std::shared_ptr<SamplerRT> build_sampler_;


/// \brief Check if the required parameters are set by the builder. /// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
/// \return Status The status code returned
Status SanityCheck() const; Status SanityCheck() const;
}; };


@@ -119,7 +119,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \brief Base-class override for special eoe handler. /// \brief Base-class override for special eoe handler.
/// CacheOp must override this because it shall not perform default handling of eoe. Instead /// CacheOp must override this because it shall not perform default handling of eoe. Instead
/// the CacheOp manages actions related to the end of the epoch. /// the CacheOp manages actions related to the end of the epoch.
/// \return Status - The error code return
/// \return Status The status code returned
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor /// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit /// \param[in] p The node to visit
@@ -133,7 +133,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
Status Accept(NodePass *p, bool *modified) override; Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for handling cases when an eof is received. /// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;
Status operator()() override; Status operator()() override;
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;
@@ -159,7 +159,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
Status CacheAllRows(int32_t worker_id); Status CacheAllRows(int32_t worker_id);
Status RegisterResources() override; Status RegisterResources() override;
/// \brief Private function for cache setup/init work just after construction /// \brief Private function for cache setup/init work just after construction
/// \return Status The error code return
/// \return Status The status code returned
Status InitCache(); Status InitCache();
}; };
} // namespace dataset } // namespace dataset


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h View File

@@ -94,7 +94,7 @@ class ConcatOp : public PipelineOp {


// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Op name getter // Op name getter


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h View File

@@ -146,14 +146,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// DatasetOps operate by launching a thread (see ExecutionTree). /// DatasetOps operate by launching a thread (see ExecutionTree).
/// This pure virtual version makes the requirement that derived classes must provide a functor /// This pure virtual version makes the requirement that derived classes must provide a functor
/// that will execute their main runtime loop code. /// that will execute their main runtime loop code.
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status operator()() = 0; virtual Status operator()() = 0;


/// \brief Gets the next buffer from the given child /// \brief Gets the next buffer from the given child
/// \notes See GetNextInput for similar function that has built-in message handling /// \notes See GetNextInput for similar function that has built-in message handling
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) { virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
return GetNextBuffer(p_buffer, worker_id, false); return GetNextBuffer(p_buffer, worker_id, false);
} }
@@ -161,7 +161,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \brief Gets the next buffer from the given child /// \brief Gets the next buffer from the given child
/// \notes See GetNextInput for similar function that has built-in message handling /// \notes See GetNextInput for similar function that has built-in message handling
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); }


/// \brief Gets the next buffer from the given child /// \brief Gets the next buffer from the given child
@@ -169,7 +169,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe); virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe);


/// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof
@@ -177,7 +177,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// those messages are received. /// those messages are received.
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);


/// \brief Gets the batch size /// \brief Gets the batch size
@@ -200,19 +200,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// The base class implementation simply flows the eoe message to output. Derived classes /// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling. /// may override if they need to perform special eoe handling.
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status EoeReceived(int32_t worker_id); virtual Status EoeReceived(int32_t worker_id);


/// \brief Performs handling for when an eof message is received. /// \brief Performs handling for when an eof message is received.
/// The base class implementation simply flows the eof message to output. Derived classes /// The base class implementation simply flows the eof message to output. Derived classes
/// may override if they need to perform special eof handling. /// may override if they need to perform special eof handling.
/// \param worker_id - The worker id /// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status EofReceived(int32_t worker_id); virtual Status EofReceived(int32_t worker_id);


/// \brief Derived classes may implement the reset function if the operator is stateful and needs /// \brief Derived classes may implement the reset function if the operator is stateful and needs
/// specific reset handling that is not contained in this common code version of the reset /// specific reset handling that is not contained in this common code version of the reset
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status Reset(); virtual Status Reset();


/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on


+ 10
- 10
mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h View File

@@ -79,7 +79,7 @@ class FilterOp : public ParallelOp {


private: private:
// Sanity check for builder class args. // Sanity check for builder class args.
// @return Status - The error code return.
// @return Status The status code returned.
Status SanityCheck(); Status SanityCheck();
std::vector<std::string> build_in_col_names_; std::vector<std::string> build_in_col_names_;
std::shared_ptr<TensorOp> builder_predicate_func_; std::shared_ptr<TensorOp> builder_predicate_func_;
@@ -105,15 +105,15 @@ class FilterOp : public ParallelOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree),This class functor will // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will
// provide the master loop that drives the logic for performing the work. // provide the master loop that drives the logic for performing the work.
// @return Status The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// @param int32_t workerId. // @param int32_t workerId.
// @return Status - The error code return.
// @return Status The status code returned.
Status EofReceived(int32_t) override; Status EofReceived(int32_t) override;


// @param int32_t workerId. // @param int32_t workerId.
// @return Status - The error code return.
// @return Status The status code returned.
Status EoeReceived(int32_t) override; Status EoeReceived(int32_t) override;


// A print method typically used for debugging. // A print method typically used for debugging.
@@ -151,34 +151,34 @@ class FilterOp : public ParallelOp {
// logic of FilterOp, getting the data from previous Op, validating user specified column names, // logic of FilterOp, getting the data from previous Op, validating user specified column names,
// applying predicate to each of the data, filter the data when predicate result is false. // applying predicate to each of the data, filter the data when predicate result is false.
// @param worker_id The id assigned to this thread/worker upon creation. // @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return.
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_


// Filter the data by predicate function . // Filter the data by predicate function .
// @param in_buffer input data buffer. // @param in_buffer input data buffer.
// @param to_proess_indices Indices of columns to be processed. // @param to_proess_indices Indices of columns to be processed.
// @param out data buffer that are filtered by predicate. // @param out data buffer that are filtered by predicate.
// @return Status The error code return.
// @return Status The status code returned
Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out); Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);


// Collector databuffer. // Collector databuffer.
// @return Status The error code return.
// @return Status The status code returned
Status Collector(); Status Collector();


// @param input tensor vector. // @param input tensor vector.
// @return Status - The error code return.
// @return Status The status code returned.
Status CheckInput(const TensorRow &input) const; Status CheckInput(const TensorRow &input) const;


// Invoke python func. // Invoke python func.
// @param input tensor vector. // @param input tensor vector.
// @param the result of predicate. // @param the result of predicate.
// @return Status - The error code return.
// @return Status The status code returned.
Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);


// Private function for validating if each of the user specified input column names // Private function for validating if each of the user specified input column names
// exist in the DataBuffer. // exist in the DataBuffer.
// @param input_columns The vector of input column names used in the current thread. // @param input_columns The vector of input column names used in the current thread.
// @return Status The error code return.
// @return Status The status code returned
Status ValidateInColumns(const std::vector<std::string> *input_columns); Status ValidateInColumns(const std::vector<std::string> *input_columns);


// Private function for checking the column legality // Private function for checking the column legality


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op/map_op.h View File

@@ -133,7 +133,7 @@ class MapOp : public ParallelOp {
int32_t build_op_connector_size_; int32_t build_op_connector_size_;


// Check if the required parameters are set by the builder. // Check if the required parameters are set by the builder.
// @return Status The error code return
// @return Status The status code returned
Status sanityCheck() const; Status sanityCheck() const;
}; };


@@ -170,7 +170,7 @@ class MapOp : public ParallelOp {
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// This main thread creates local queues, pulls databuffers from the previous // This main thread creates local queues, pulls databuffers from the previous
// op's Connector and distributes them to the local queues. Workers pull from the local queues. // op's Connector and distributes them to the local queues. Workers pull from the local queues.
// @return Status The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Getter // Getter
@@ -239,7 +239,7 @@ class MapOp : public ParallelOp {
// applying a list of TensorOps to each of the data, process the results and then // applying a list of TensorOps to each of the data, process the results and then
// pushing them back to MapOp's output Connector to be fetched by the next Op. // pushing them back to MapOp's output Connector to be fetched by the next Op.
// @param worker_id The id assigned to this thread/worker upon creation. // @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_


// Private function for worker thread to perform TensorOp's compute function and get the result. // Private function for worker thread to perform TensorOp's compute function and get the result.


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h View File

@@ -89,7 +89,7 @@ class ParallelOp : public DatasetOp {
} }


// Override base class reset to provide reset actions specific to the ParallelOp class. // Override base class reset to provide reset actions specific to the ParallelOp class.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Getter // Getter
@@ -115,7 +115,7 @@ class ParallelOp : public DatasetOp {
protected: protected:
// Interface for derived classes to implement. All derived classes must provide the entry // Interface for derived classes to implement. All derived classes must provide the entry
// function with the main execution loop for worker threads. // function with the main execution loop for worker threads.
// @return Status - The error code return
// @return Status The status code returned
virtual Status WorkerEntry(int32_t workerId) = 0; virtual Status WorkerEntry(int32_t workerId) = 0;


/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h View File

@@ -75,7 +75,7 @@ class ProjectOp : public PipelineOp {
// However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to // functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error). // ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code returned.
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node.
@@ -93,12 +93,12 @@ class ProjectOp : public PipelineOp {


// Base-class override for special eoe handler. // Base-class override for special eoe handler.
// Inline operators must override this because there is no connector to push eoe onto. // Inline operators must override this because there is no connector to push eoe onto.
// @return Status - The error code returned.
// @return Status The status code returned
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;


// Base-class override for special eof handler. // Base-class override for special eof handler.
// Inline operators must override this because there is no connector to push eof onto. // Inline operators must override this because there is no connector to push eof onto.
// @return Status - The error code returned.
// @return Status The status code returned
Status EofReceived(int32_t worker_id) override; Status EofReceived(int32_t worker_id) override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h View File

@@ -107,7 +107,7 @@ class RenameOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h View File

@@ -78,7 +78,7 @@ class RepeatOp : public PipelineOp {
// However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to // functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error). // ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// This function returns the buffer that is at the top of our output connector. The caller is // This function returns the buffer that is at the top of our output connector. The caller is
@@ -90,7 +90,7 @@ class RepeatOp : public PipelineOp {
// @param p_buffer - output pointer to the buffer that it will fetch. // @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id // @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
// @return Status The status code returned
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;


// Base-class override for handling cases when an eoe is received. // Base-class override for handling cases when an eoe is received.
@@ -130,7 +130,7 @@ class RepeatOp : public PipelineOp {
int32_t num_repeats() { return num_repeats_; } int32_t num_repeats() { return num_repeats_; }


/// \brief reset Op /// \brief reset Op
/// \@return Status - The error code return
/// \@return Status The status code returned
Status Reset() override; Status Reset() override;


int64_t GetTreeRepeatCount() override; int64_t GetTreeRepeatCount() override;


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h View File

@@ -146,13 +146,13 @@ class ShuffleOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Base-class override for special eoe handler. // Base-class override for special eoe handler.
// ShuffleOp must override this because it shall not perform default handling of eoe. Instead // ShuffleOp must override this because it shall not perform default handling of eoe. Instead
// the ShuffleOp needs to manage actions related to the end of the epoch itself. // the ShuffleOp needs to manage actions related to the end of the epoch itself.
// @return Status - The error code return
// @return Status The status code returned
Status EoeReceived(int32_t worker_id) override; Status EoeReceived(int32_t worker_id) override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.
@@ -167,17 +167,17 @@ class ShuffleOp : public PipelineOp {


private: private:
// Private function to add a new row to the shuffle buffer. // Private function to add a new row to the shuffle buffer.
// @return Status - The error code return
// @return Status The status code returned
Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);


// Private function to populate the shuffle buffer initially by fetching from the child output // Private function to populate the shuffle buffer initially by fetching from the child output
// connector until the shuffle buffer is full (or there is no more data coming). // connector until the shuffle buffer is full (or there is no more data coming).
// @return Status - The error code return
// @return Status The status code returned
Status InitShuffleBuffer(); Status InitShuffleBuffer();


// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
// itself rather than waiting for the reset driven from operators above it in the pipeline. // itself rather than waiting for the reset driven from operators above it in the pipeline.
// @return Status - The error code return
// @return Status The status code returned
Status SelfReset(); Status SelfReset();


int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows)


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h View File

@@ -63,7 +63,7 @@ class SkipOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Base-class override for handling cases when an eoe is received. // Base-class override for handling cases when an eoe is received.


+ 20
- 20
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/album_op.h View File

@@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
} }


/// \brief Check validity of input args /// \brief Check validity of input args
/// \return - The error code returned
/// \return Status The status code returned
Status SanityCheck(); Status SanityCheck();


/// \brief The builder "build" method creates the final object. /// \brief The builder "build" method creates the final object.
/// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp /// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp
/// \return - The error code returned
/// \return Status The status code returned
Status Build(std::shared_ptr<AlbumOp> *op); Status Build(std::shared_ptr<AlbumOp> *op);


private: private:
@@ -168,18 +168,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
~AlbumOp() = default; ~AlbumOp() = default;


/// \brief Initialize AlbumOp related var, calls the function to walk all files /// \brief Initialize AlbumOp related var, calls the function to walk all files
/// \return - The error code returned
/// \return Status The status code returned
Status PrescanEntry(); Status PrescanEntry();


/// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector /// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
/// \param[in] int32_t workerId - id of each worker /// \param[in] int32_t workerId - id of each worker
/// \return Status - The error code returned
/// \return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


/// \brief Main Loop of AlbumOp /// \brief Main Loop of AlbumOp
/// Master thread: Fill IOBlockQueue, then goes to sleep /// Master thread: Fill IOBlockQueue, then goes to sleep
/// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector /// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
/// \return Status - The error code returned
/// \return Status The status code returned
Status operator()() override; Status operator()() override;


/// \brief A print method typically used for debugging /// \brief A print method typically used for debugging
@@ -204,93 +204,93 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {


private: private:
/// \brief Initialize Sampler, calls sampler->Init() within /// \brief Initialize Sampler, calls sampler->Init() within
/// \return Status The error code returned
/// \return Status The status code returned
Status InitSampler(); Status InitSampler();


/// \brief Load image to tensor row /// \brief Load image to tensor row
/// \param[in] image_file Image name of file /// \param[in] image_file Image name of file
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row); Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row);


/// \brief Load vector of ints to tensor, append tensor to tensor row /// \brief Load vector of ints to tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing multi-dimensional label /// \param[in] json_obj Json object containing multi-dimensional label
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load vector of floatss to tensor, append tensor to tensor row /// \brief Load vector of floatss to tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing array data /// \param[in] json_obj Json object containing array data
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load string array into a tensor, append tensor to tensor row /// \brief Load string array into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor /// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load string into a tensor, append tensor to tensor row /// \brief Load string into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor /// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load float value to tensor row /// \brief Load float value to tensor row
/// \param[in] json_obj Json object containing float /// \param[in] json_obj Json object containing float
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load int value to tensor row /// \brief Load int value to tensor row
/// \param[in] json_obj Json object containing int /// \param[in] json_obj Json object containing int
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);


/// \brief Load emtpy tensor to tensor row /// \brief Load emtpy tensor to tensor row
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadEmptyTensor(uint32_t col_num, TensorRow *row); Status LoadEmptyTensor(uint32_t col_num, TensorRow *row);


/// \brief Load id from file name to tensor row /// \brief Load id from file name to tensor row
/// \param[in] file The file name to get ID from /// \param[in] file The file name to get ID from
/// \param[in] col_num Column num in schema /// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to /// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row); Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row);


/// \brief Load a tensor row according to a json file /// \brief Load a tensor row according to a json file
/// \param[in] row_id_type row_id - id for this tensor row /// \param[in] row_id_type row_id - id for this tensor row
/// \param[in] ImageColumns file Json file location /// \param[in] ImageColumns file Json file location
/// \param[inout] TensorRow row Json content stored into a tensor row /// \param[inout] TensorRow row Json content stored into a tensor row
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row); Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row);


/// \param[in] const std::vector<int64_t> &keys Keys in ioblock /// \param[in] const std::vector<int64_t> &keys Keys in ioblock
/// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to /// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


/// \brief Called first when function is called /// \brief Called first when function is called
/// \return Status The error code returned
/// \return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


/// \brief reset Op /// \brief reset Op
/// \return Status The error code return
/// \return Status The status code returned
Status Reset() override; Status Reset() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return Status The error code returned
// @return Status The status code returned
Status ComputeColMap() override; Status ComputeColMap() override;


int32_t rows_per_buffer_; int32_t rows_per_buffer_;


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h View File

@@ -116,12 +116,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
return *this; return *this;
} }
// Check validity of input args // Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<CelebAOp> *op - DatasetOp // @param std::shared_ptr<CelebAOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CelebAOp> *op); Status Build(std::shared_ptr<CelebAOp> *op);


private: private:
@@ -151,12 +151,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Main Loop of CelebAOp // Main Loop of CelebAOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker // @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -166,7 +166,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {


// Method in operator(), to fill IOBlockQueue // Method in operator(), to fill IOBlockQueue
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
// @return Status - The error code return
// @return Status The status code returned
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer); Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);


/// \brief Base-class override for NodePass visitor acceptor /// \brief Base-class override for NodePass visitor acceptor
@@ -199,14 +199,14 @@ class CelebAOp : public ParallelOp, RandomAccessOp {


// @param const std::vector<int64_t> &keys - keys in ioblock // @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Load a tensor row according to a pair // Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param std::pair - <image_file,<label>> // @param std::pair - <image_file,<label>>
// @param TensorRow row - image & label read into this tensor row // @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label, Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label,
TensorRow *row); TensorRow *row);


@@ -215,7 +215,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
bool CheckDatasetTypeValid(); bool CheckDatasetTypeValid();


// reset Op // reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 9
- 9
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h View File

@@ -109,12 +109,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
} }


// Check validity of input args // Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<CifarOp> *op - DatasetOp // @param std::shared_ptr<CifarOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CifarOp> *op); Status Build(std::shared_ptr<CifarOp> *op);


private: private:
@@ -144,13 +144,13 @@ class CifarOp : public ParallelOp, public RandomAccessOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param uint32_t workerId - id of each worker // @param uint32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Main Loop of CifarOp // Main Loop of CifarOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -177,18 +177,18 @@ class CifarOp : public ParallelOp, public RandomAccessOp {


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Load a tensor row according to a pair // Load a tensor row according to a pair
// @param uint64_t index - index need to load // @param uint64_t index - index need to load
// @param TensorRow row - image & label read into this tensor row // @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(uint64_t index, TensorRow *row); Status LoadTensorRow(uint64_t index, TensorRow *row);


// @param const std::vector<uint64_t> &keys - keys in ioblock // @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Read block data from cifar file // Read block data from cifar file
@@ -200,7 +200,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// reset Op // reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Get cifar files in dir // Get cifar files in dir
@@ -221,7 +221,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {


// Method derived from RandomAccess Op, enable Sampler to get all ids for each calss // Method derived from RandomAccess Op, enable Sampler to get all ids for each calss
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 22
- 22
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h View File

@@ -133,12 +133,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
} }


// Check validity of input args // Check validity of input args
// @return = The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "Build" method creates the final object. // The builder "Build" method creates the final object.
// @param std::shared_ptr<CocoOp> *op - DatasetOp // @param std::shared_ptr<CocoOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CocoOp> *op); Status Build(std::shared_ptr<CocoOp> *op);


private: private:
@@ -173,13 +173,13 @@ class CocoOp : public ParallelOp, public RandomAccessOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker // @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Main Loop of CocoOp // Main Loop of CocoOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -214,19 +214,19 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::string Name() const override { return "CocoOp"; } std::string Name() const override { return "CocoOp"; }


/// \brief Gets the class indexing /// \brief Gets the class indexing
/// \return Status - The status code return
/// \return Status The status code returned
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override; Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Load a tensor row according to image id // Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id // @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row // @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);


// Load a tensor row with vector which a vector to a tensor // Load a tensor row with vector which a vector to a tensor
@@ -235,7 +235,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor // @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row // @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow); std::shared_ptr<Tensor> coordinate, TensorRow *trow);


@@ -245,7 +245,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor // @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row // @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow); std::shared_ptr<Tensor> coordinate, TensorRow *trow);


@@ -255,69 +255,69 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor // @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row // @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow); std::shared_ptr<Tensor> coordinate, TensorRow *trow);


// @param const std::string &path - path to the image file // @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype // @param const ColDescriptor &col - contains tensor implementation and datatype
// @param std::shared_ptr<Tensor> tensor - return // @param std::shared_ptr<Tensor> tensor - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);


// @param const std::vector<uint64_t> &keys - keys in ioblock // @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Read annotation from Annotation folder // Read annotation from Annotation folder
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationIds(); Status ParseAnnotationIds();


// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
// @param std::vector<int64_t> *keys - image id // @param std::vector<int64_t> *keys - image id
// @return Status - The error code return
// @return Status The status code returned
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);


// Called first when function is called // Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// Reset dataset state // Reset dataset state
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// @param nlohmann::json image_tree - image tree of json // @param nlohmann::json image_tree - image tree of json
// @param std::vector<std::string> *image_vec - image id list of json // @param std::vector<std::string> *image_vec - image id list of json
// @return Status - The error code return
// @return Status The status code returned
Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec); Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec);


// @param nlohmann::json categories_tree - categories tree of json // @param nlohmann::json categories_tree - categories tree of json
// return Status - The error code return
// @return Status The status code returned
Status CategoriesColumnLoad(const nlohmann::json &categories_tree); Status CategoriesColumnLoad(const nlohmann::json &categories_tree);


// @param nlohmann::json categories_tree - categories tree of json // @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation // @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation // @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);


// @param nlohmann::json categories_tree - categories tree of json // @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation // @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation // @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);


// @param nlohmann::json categories_tree - categories tree of json // @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation // @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation // @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);


// @param nlohmann::json categories_tree - categories tree of json // @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation // @param const std::string &image_file - current image name in annotation
// @param const int32_t &image_id - current unique id of annotation // @param const int32_t &image_id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file,
const int32_t &image_id); const int32_t &image_id);




+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h View File

@@ -115,13 +115,13 @@ class GeneratorOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work. // provide the master loop that drives the logic for performing the work.
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Overrides base class reset method. When an operator does a reset, it cleans up any state // Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed // info from it's previous execution and then initializes itself so that it can be executed
// again. // again.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.


+ 11
- 11
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h View File

@@ -135,12 +135,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
} }


// Check validity of input args // Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<ImageFolderOp> *op - DatasetOp // @param std::shared_ptr<ImageFolderOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<ImageFolderOp> *op); Status Build(std::shared_ptr<ImageFolderOp> *op);


private: private:
@@ -172,28 +172,28 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {


// Initialize ImageFOlderOp related var, calls the function to walk all files // Initialize ImageFOlderOp related var, calls the function to walk all files
// @param - std::string dir file directory to ImageNetFolder // @param - std::string dir file directory to ImageNetFolder
// @return - The error code return
// @return Status The status code returned
Status PrescanMasterEntry(const std::string &dir); Status PrescanMasterEntry(const std::string &dir);


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker // @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker // @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status PrescanWorkerEntry(int32_t worker_id); Status PrescanWorkerEntry(int32_t worker_id);


// Main Loop of ImageFolderOp // Main Loop of ImageFolderOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Method derived from RandomAccess Op, enable Sampler to get all ids for each class // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -224,19 +224,19 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Load a tensor row according to a pair // Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label> // @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row // @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row);


// @param const std::vector<int64_t> &keys - keys in ioblock // @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// @param std::string & dir - dir to walk all images // @param std::string & dir - dir to walk all images
@@ -253,7 +253,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// reset Op // reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h View File

@@ -58,12 +58,12 @@ class IOBlock {
// Fetches the first key from the block. // Fetches the first key from the block.
// @note Only useful if you know the block only has 1 key. // @note Only useful if you know the block only has 1 key.
// @return A copy of the first key from the block // @return A copy of the first key from the block
// @return Status - The error code return
// @return Status The status code returned
Status GetKey(int64_t *out_key) const; Status GetKey(int64_t *out_key) const;


// Fetches the list of keys from this block. // Fetches the list of keys from this block.
// @param out_keys - A copy of the vector of keys from the block. // @param out_keys - A copy of the vector of keys from the block.
// @return Status - The error code return
// @return Status The status code returned
Status GetKeys(std::vector<int64_t> *out_keys) const; Status GetKeys(std::vector<int64_t> *out_keys) const;


// Does this block have the eoe flag turned on? // Does this block have the eoe flag turned on?
@@ -110,7 +110,7 @@ class FilenameBlock : public IOBlock {
// Gets the filename from the block using the provided index container // Gets the filename from the block using the provided index container
// @param out_filename - The filename to add to the block // @param out_filename - The filename to add to the block
// @param index - The index to perform lookup against // @param index - The index to perform lookup against
// @return Status - The error code return
// @return Status The status code returned
Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const; Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const;


// Get the start offset of file // Get the start offset of file


+ 13
- 13
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h View File

@@ -110,12 +110,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
} }


// Check validity of input args // Check validity of input args
// @return Status - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "build" method creates the final object. // The builder "build" method creates the final object.
// @param std::shared_ptr<ManifestOp> *op - DatasetOp // @param std::shared_ptr<ManifestOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<ManifestOp> *op); Status Build(std::shared_ptr<ManifestOp> *op);


private: private:
@@ -145,18 +145,18 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker // @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Main Loop of ManifestOp // Main Loop of ManifestOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Method derived from RandomAccess Op, enable Sampler to get all ids for each class // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -201,37 +201,37 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Method in operator(), to fill IOBlockQueue // Method in operator(), to fill IOBlockQueue
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
// @return Status - The error code return
// @return Status The status code returned
Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer); Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer);


// Load a tensor row according to a pair // Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>> // @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>>
// @param TensorRow row - image & label read into this tensor row // @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data, Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data,
TensorRow *row); TensorRow *row);


// @param const std::vector<int64_t> &keys - keys in ioblock // @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Parse manifest file to get image path and label and so on. // Parse manifest file to get image path and label and so on.
// @return Status - The error code return
// @return Status The status code returned
Status ParseManifestFile(); Status ParseManifestFile();


// Called first when function is called // Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// reset Op // reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Check if image ia valid.Only support JPEG/PNG/GIF/BMP // Check if image ia valid.Only support JPEG/PNG/GIF/BMP
@@ -239,7 +239,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
Status CheckImageType(const std::string &file_name, bool *valid); Status CheckImageType(const std::string &file_name, bool *valid);


// Count label index,num rows and num samples // Count label index,num rows and num samples
// @return Status - The error code return
// @return Status The status code returned
Status CountDatasetInfo(); Status CountDatasetInfo();


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 3
- 3
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h View File

@@ -164,13 +164,13 @@ class MindRecordOp : public ParallelOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker // @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Class functor operator () override. // Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work. // provide the master loop that drives the logic for performing the work.
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Called first when function is called // Called first when function is called
@@ -180,7 +180,7 @@ class MindRecordOp : public ParallelOp {
// Overrides base class reset method. When an operator does a reset, it cleans up any state // Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed // info from it's previous execution and then initializes itself so that it can be executed
// again. // again.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Getter method // Getter method


+ 16
- 16
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h View File

@@ -99,12 +99,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
return *this; return *this;
} }
// Check validity of input args // Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "Build" method creates the final object. // The builder "Build" method creates the final object.
// @param std::shared_ptr<MnistOp> *op - DatasetOp // @param std::shared_ptr<MnistOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<MnistOp> *op); Status Build(std::shared_ptr<MnistOp> *op);


private: private:
@@ -133,18 +133,18 @@ class MnistOp : public ParallelOp, public RandomAccessOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker // @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Main Loop of MnistOp // Main Loop of MnistOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Method derived from RandomAccess Op, enable Sampler to get all ids for each class // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -170,39 +170,39 @@ class MnistOp : public ParallelOp, public RandomAccessOp {


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Load a tensor row according to a pair // Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label> // @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row // @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row);


// @param const std::vector<int64_t> &keys - keys in ioblock // @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Iterate through all members in sampleIds and fill them into IOBlock. // Iterate through all members in sampleIds and fill them into IOBlock.
// @param std::shared_ptr<Tensor> sample_ids - // @param std::shared_ptr<Tensor> sample_ids -
// @param std::vector<int64_t> *keys - keys in ioblock // @param std::vector<int64_t> *keys - keys in ioblock
// @return Status - The error code return
// @return Status The status code returned
Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);


// Check image file stream. // Check image file stream.
// @param const std::string *file_name - image file name // @param const std::string *file_name - image file name
// @param std::ifstream *image_reader - image file stream // @param std::ifstream *image_reader - image file stream
// @param uint32_t num_images - returns the number of images // @param uint32_t num_images - returns the number of images
// @return Status - The error code return
// @return Status The status code returned
Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);


// Check label stream. // Check label stream.
// @param const std::string &file_name - label file name // @param const std::string &file_name - label file name
// @param std::ifstream *label_reader - label file stream // @param std::ifstream *label_reader - label file stream
// @param uint32_t num_labels - returns the number of labels // @param uint32_t num_labels - returns the number of labels
// @return Status - The error code return
// @return Status The status code returned
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);


// Read 4 bytes of data from a file stream. // Read 4 bytes of data from a file stream.
@@ -219,23 +219,23 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param std::ifstream *image_reader - image file stream // @param std::ifstream *image_reader - image file stream
// @param std::ifstream *label_reader - label file stream // @param std::ifstream *label_reader - label file stream
// @param int64_t read_num - number of image to read // @param int64_t read_num - number of image to read
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);


// Parse all mnist dataset files // Parse all mnist dataset files
// @return Status - The error code return
// @return Status The status code returned
Status ParseMnistData(); Status ParseMnistData();


// Read all files in the directory // Read all files in the directory
// @return Status - The error code return
// @return Status The status code returned
Status WalkAllFiles(); Status WalkAllFiles();


// Called first when function is called // Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// reset Op // reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 8
- 8
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h View File

@@ -63,7 +63,7 @@ class RandomDataOp : public ParallelOp {
/** /**
* The build method that produces the instantiated RandomDataOp as a shared pointer * The build method that produces the instantiated RandomDataOp as a shared pointer
* @param out_op - The output RandomDataOperator that was constructed * @param out_op - The output RandomDataOperator that was constructed
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status Build(std::shared_ptr<RandomDataOp> *out_op); Status Build(std::shared_ptr<RandomDataOp> *out_op);


@@ -128,7 +128,7 @@ class RandomDataOp : public ParallelOp {
private: private:
/** /**
* Check if the required parameters are set by the builder. * Check if the required parameters are set by the builder.
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status SanityCheck() const; Status SanityCheck() const;


@@ -182,7 +182,7 @@ class RandomDataOp : public ParallelOp {
* Class functor operator () override. * Class functor operator () override.
* All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
* provide the master loop that drives the logic for performing the work. * provide the master loop that drives the logic for performing the work.
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status operator()() override; Status operator()() override;


@@ -190,7 +190,7 @@ class RandomDataOp : public ParallelOp {
* Overrides base class reset method. When an operator does a reset, it cleans up any state * Overrides base class reset method. When an operator does a reset, it cleans up any state
* info from it's previous execution and then initializes itself so that it can be executed * info from it's previous execution and then initializes itself so that it can be executed
* again. * again.
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status Reset() override; Status Reset() override;


@@ -207,7 +207,7 @@ class RandomDataOp : public ParallelOp {
/** /**
* The entry point code for when workers are launched * The entry point code for when workers are launched
* @param worker_id - The worker id * @param worker_id - The worker id
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


@@ -219,7 +219,7 @@ class RandomDataOp : public ParallelOp {
/** /**
* Performs a synchronization between workers at the end of an epoch * Performs a synchronization between workers at the end of an epoch
* @param worker_id - The worker id * @param worker_id - The worker id
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status EpochSync(int32_t worker_id, bool *quitting); Status EpochSync(int32_t worker_id, bool *quitting);


@@ -227,7 +227,7 @@ class RandomDataOp : public ParallelOp {
* A helper function to stuff the tensor table into a buffer and send it to output connector * A helper function to stuff the tensor table into a buffer and send it to output connector
* @param worker_id - The worker id * @param worker_id - The worker id
* @param in_table - The tensor table to pack and send * @param in_table - The tensor table to pack and send
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table); Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table);


@@ -235,7 +235,7 @@ class RandomDataOp : public ParallelOp {
* A helper function to create random data for the row * A helper function to create random data for the row
* @param worker_id - The worker id * @param worker_id - The worker id
* @param new_row - The output row to produce * @param new_row - The output row to produce
* @return Status - The error code return
* @return Status The status code returned
*/ */
Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);




+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h View File

@@ -40,7 +40,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED


// @param std::unique_ptr<DataBuffer pBuffer // @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId // @param int32_t workerId
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;


// first handshake between leaf source op and Sampler. This func will determine the amount of data // first handshake between leaf source op and Sampler. This func will determine the amount of data
@@ -53,7 +53,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
Status InitSampler() override; Status InitSampler() override;


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override; Status ResetSampler() override;


// Printer for debugging purposes. // Printer for debugging purposes.


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h View File

@@ -41,13 +41,13 @@ class PythonSamplerRT : public SamplerRT {
Status InitSampler() override; Status InitSampler() override;


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override; Status ResetSampler() override;


// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used // @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;


// Printer for debugging purposes. // Printer for debugging purposes.


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h View File

@@ -40,14 +40,14 @@ class RandomSamplerRT : public SamplerRT {
// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used // @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;


// meant to be called by base class or python // meant to be called by base class or python
Status InitSampler() override; Status InitSampler() override;


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override; Status ResetSampler() override;


void SamplerPrint(std::ostream &out, bool show_all) const override; void SamplerPrint(std::ostream &out, bool show_all) const override;


+ 7
- 7
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h View File

@@ -35,12 +35,12 @@ class RandomAccessOp {
public: public:
// Sampler get number of rows in the dataset // Sampler get number of rows in the dataset
// @param int64_t num - return number of rows for this dataset // @param int64_t num - return number of rows for this dataset
// @return - The error code return
// @return Status The status code returned
Status GetNumRowsInDataset(int64_t *num_rows) const; Status GetNumRowsInDataset(int64_t *num_rows) const;


// sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
// @param std::map<int64_t, std::vector<int64_t>> * map // @param std::map<int64_t, std::vector<int64_t>> * map
// @return - The error code return
// @return Status The status code returned
virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK");
} }
@@ -71,7 +71,7 @@ class SamplerRT {
// @note It is Sampler responsibility to make sure that the id is not out of bound. // @note It is Sampler responsibility to make sure that the id is not out of bound.
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used // @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;


// This function only called by python layer. Not needed by Android. // This function only called by python layer. Not needed by Android.
@@ -81,7 +81,7 @@ class SamplerRT {
#endif #endif


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
virtual Status ResetSampler() = 0; virtual Status ResetSampler() = 0;


// first handshake between leaf source op and Sampler. This func will determine the amount of data // first handshake between leaf source op and Sampler. This func will determine the amount of data
@@ -114,13 +114,13 @@ class SamplerRT {


// Adds a sampler to become our child. // Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child. // @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
// @return Status The status code returned
Status AddChild(std::shared_ptr<SamplerRT> child); Status AddChild(std::shared_ptr<SamplerRT> child);


// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds // @param std::shared_ptr<Tensor>* sampleIds
// @param int64_t numElements - must be a non 0 number // @param int64_t numElements - must be a non 0 number
// @return - The error code returned.
// @return Status The status code returned
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);


// A print method typically used for debugging // A print method typically used for debugging
@@ -146,7 +146,7 @@ class SamplerRT {
// associated id. // associated id.
// @param int64_t* out_associated_id - Out parameter, contains the associated id. // @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index to get the associated child id. // @param int64_t id - The id used as an index to get the associated child id.
// @return - The error code returned.
// @return Status The status code returned
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);


protected: protected:


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h View File

@@ -40,13 +40,13 @@ class SequentialSamplerRT : public SamplerRT {
Status InitSampler() override; Status InitSampler() override;


// for next epoch of sampleIds // for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override; Status ResetSampler() override;


// Op calls this to get next Buffer that contains all the sampleIds // Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used // @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;


// Printer for debugging purposes. // Printer for debugging purposes.


+ 15
- 15
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h View File

@@ -132,12 +132,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
} }


// Check validity of input args // Check validity of input args
// @return = The error code return
// @return Status The status code returned
Status SanityCheck(); Status SanityCheck();


// The builder "Build" method creates the final object. // The builder "Build" method creates the final object.
// @param std::shared_ptr<VOCOp> *op - DatasetOp // @param std::shared_ptr<VOCOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<VOCOp> *op); Status Build(std::shared_ptr<VOCOp> *op);


private: private:
@@ -173,13 +173,13 @@ class VOCOp : public ParallelOp, public RandomAccessOp {


// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker // @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; Status WorkerEntry(int32_t worker_id) override;


// Main Loop of VOCOp // Main Loop of VOCOp
// Master thread: Fill IOBlockQueue, then goes to sleep // Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// A print method typically used for debugging // A print method typically used for debugging
@@ -222,55 +222,55 @@ class VOCOp : public ParallelOp, public RandomAccessOp {


private: private:
// Initialize Sampler, calls sampler->Init() within // Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler(); Status InitSampler();


// Load a tensor row according to image id // Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row // @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id // @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row // @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);


// @param const std::string &path - path to the image file // @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype // @param const ColDescriptor &col - contains tensor implementation and datatype
// @param std::shared_ptr<Tensor> tensor - return // @param std::shared_ptr<Tensor> tensor - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);


// @param const std::string &path - path to the image file // @param const std::string &path - path to the image file
// @param TensorRow *row - return // @param TensorRow *row - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); Status ReadAnnotationToTensor(const std::string &path, TensorRow *row);


// @param const std::vector<uint64_t> &keys - keys in ioblock // @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db // @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);


// Read image list from ImageSets // Read image list from ImageSets
// @return Status - The error code return
// @return Status The status code returned
Status ParseImageIds(); Status ParseImageIds();


// Read annotation from Annotation folder // Read annotation from Annotation folder
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationIds(); Status ParseAnnotationIds();


// @param const std::string &path - path to annotation xml // @param const std::string &path - path to annotation xml
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationBbox(const std::string &path); Status ParseAnnotationBbox(const std::string &path);


// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
// @param std::vector<int64_t> *keys - image id // @param std::vector<int64_t> *keys - image id
// @return Status - The error code return
// @return Status The status code returned
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);


// Called first when function is called // Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp(); Status LaunchThreadsAndInitOp();


// Reset dataset state // Reset dataset state
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override; Status Reset() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h View File

@@ -75,7 +75,7 @@ class TakeOp : public PipelineOp {


// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


// Base-class override for NodePass visitor acceptor. // Base-class override for NodePass visitor acceptor.


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h View File

@@ -101,7 +101,7 @@ class ZipOp : public PipelineOp {
// Class functor operator () override. // Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work // provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override; Status operator()() override;


/// \brief Base-class override for NodePass pre-visit acceptor /// \brief Base-class override for NodePass pre-visit acceptor


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc View File

@@ -226,7 +226,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return
// @return Status The status code returned
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs; num_epochs_ = num_epochs;
partially_prepare_ = partial; partially_prepare_ = partial;


+ 10
- 10
mindspore/ccsrc/minddata/dataset/engine/execution_tree.h View File

@@ -115,16 +115,16 @@ class ExecutionTree {
// provides it with a link to the tree. A node cannot form any relationships (parent/child) with // provides it with a link to the tree. A node cannot form any relationships (parent/child) with
// other nodes unless they are associated with the same tree. // other nodes unless they are associated with the same tree.
// @param op - The operator to associate // @param op - The operator to associate
// @return Status - The error code return
// @return Status The status code returned
Status AssociateNode(const std::shared_ptr<DatasetOp> &op); Status AssociateNode(const std::shared_ptr<DatasetOp> &op);


// Sets the root node of the tree // Sets the root node of the tree
// @param op - The operator to assign as root // @param op - The operator to assign as root
// @return Status - The error code return
// @return Status The status code returned
Status AssignRoot(const std::shared_ptr<DatasetOp> &op); Status AssignRoot(const std::shared_ptr<DatasetOp> &op);


// Start the execution of the tree // Start the execution of the tree
// @return Status - The error code return
// @return Status The status code returned
Status Launch(); Status Launch();


/// A print method typically used for debugging /// A print method typically used for debugging
@@ -155,7 +155,7 @@ class ExecutionTree {
// wrapper for the TaskGroup handling that is stored inside the execution tree. // wrapper for the TaskGroup handling that is stored inside the execution tree.
// @param num_workers - The number of workers to launch // @param num_workers - The number of workers to launch
// @param func - The function entry point that workers will execute // @param func - The function entry point that workers will execute
// @return Status - The error code return
// @return Status The status code returned
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");


// Getter method // Getter method
@@ -181,32 +181,32 @@ class ExecutionTree {
// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// For example, repeatOp inlining // For example, repeatOp inlining
// //
// @return Status - The error code return
// @return Status The status code returned
Status Prepare(int num_epochs = -1, bool partial = false); Status Prepare(int num_epochs = -1, bool partial = false);


// Compulsory transformation/action pre optimization. // Compulsory transformation/action pre optimization.
// @return Status - The error code return
// @return Status The status code returned
Status PreAction(); Status PreAction();


// Compulsory transformation/action post optimization. // Compulsory transformation/action post optimization.
// @return Status - The error code return
// @return Status The status code returned
Status PostAction(); Status PostAction();


// Optimization transformation/action, optional. // Optimization transformation/action, optional.
// @return Status - The error code return
// @return Status The status code returned
Status Optimize(); Status Optimize();


// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get // walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution. // it ready for execution.
// @param Total number of epochs that will be run on this tree // @param Total number of epochs that will be run on this tree
// @return Status - The error code return
// @return Status The status code returned
Status PrepareDeprecated(); Status PrepareDeprecated();


// Recursive function used during prepare phase to visit a node and drive any pre- and post- // Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk. // node actions during a tree walk.
// @param op - The dataset op to work on // @param op - The dataset op to work on
// @return Status - The error code return
// @return Status The status code returned
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);


// Return the pointer to the TaskGroup // Return the pointer to the TaskGroup


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h View File

@@ -51,7 +51,7 @@ class Edge {
// Get the feature of a edge // Get the feature of a edge
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;


// Get nodes on the edge // Get nodes on the edge
@@ -71,7 +71,7 @@ class Edge {


// Update feature of edge // Update feature of edge
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;


protected: protected:


+ 9
- 9
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h View File

@@ -47,19 +47,19 @@ class GraphData {
// Get all nodes from the graph. // Get all nodes from the graph.
// @param NodeType node_type - type of node // @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id // @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0; virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;


// Get all edges from the graph. // Get all edges from the graph.
// @param NodeType edge_type - type of edge // @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids // @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0; virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;


// Get the node id from the edge. // Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges // @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids // @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;


// All neighbors of the acquisition node. // All neighbors of the acquisition node.
@@ -68,7 +68,7 @@ class GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1. // is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) = 0; std::shared_ptr<Tensor> *out) = 0;


@@ -77,7 +77,7 @@ class GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
@@ -87,7 +87,7 @@ class GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled // @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor. // @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0; NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;


@@ -98,7 +98,7 @@ class GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm // @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id // @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node, float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) = 0; std::shared_ptr<Tensor> *out) = 0;
@@ -108,7 +108,7 @@ class GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param TensorRow *out - Returned features // @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0; TensorRow *out) = 0;


@@ -117,7 +117,7 @@ class GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param Tensor *out - Returned features // @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0; TensorRow *out) = 0;




+ 9
- 9
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h View File

@@ -57,19 +57,19 @@ class GraphDataClient : public GraphData {
// Get all nodes from the graph. // Get all nodes from the graph.
// @param NodeType node_type - type of node // @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id // @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;


// Get all edges from the graph. // Get all edges from the graph.
// @param NodeType edge_type - type of edge // @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids // @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;


// Get the node id from the edge. // Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges // @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids // @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;


// All neighbors of the acquisition node. // All neighbors of the acquisition node.
@@ -78,7 +78,7 @@ class GraphDataClient : public GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1. // is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override; std::shared_ptr<Tensor> *out) override;


@@ -87,7 +87,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;


@@ -96,7 +96,7 @@ class GraphDataClient : public GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled // @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor. // @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;


@@ -107,7 +107,7 @@ class GraphDataClient : public GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm // @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id // @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node, float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override; std::shared_ptr<Tensor> *out) override;
@@ -117,7 +117,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param TensorRow *out - Returned features // @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override; TensorRow *out) override;


@@ -126,7 +126,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param Tensor *out - Returned features // @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override; TensorRow *out) override;




+ 18
- 18
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h View File

@@ -51,19 +51,19 @@ class GraphDataImpl : public GraphData {
// Get all nodes from the graph. // Get all nodes from the graph.
// @param NodeType node_type - type of node // @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id // @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;


// Get all edges from the graph. // Get all edges from the graph.
// @param NodeType edge_type - type of edge // @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids // @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;


// Get the node id from the edge. // Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges // @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids // @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;


// All neighbors of the acquisition node. // All neighbors of the acquisition node.
@@ -72,7 +72,7 @@ class GraphDataImpl : public GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1. // is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override; std::shared_ptr<Tensor> *out) override;


@@ -81,7 +81,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;


@@ -90,7 +90,7 @@ class GraphDataImpl : public GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled // @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor. // @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;


@@ -101,7 +101,7 @@ class GraphDataImpl : public GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm // @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id // @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node, float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override; std::shared_ptr<Tensor> *out) override;
@@ -111,7 +111,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param TensorRow *out - Returned features // @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override; TensorRow *out) override;


@@ -123,7 +123,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param Tensor *out - Returned features // @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override; TensorRow *out) override;


@@ -132,7 +132,7 @@ class GraphDataImpl : public GraphData {


// Get meta information of graph // Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information // @param MetaInfo *meta_info - Returned meta information
// @return Status - The error code return
// @return Status The status code returned
Status GetMetaInfo(MetaInfo *meta_info); Status GetMetaInfo(MetaInfo *meta_info);


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
@@ -202,14 +202,14 @@ class GraphDataImpl : public GraphData {
}; };


// Load graph data from mindrecord file // Load graph data from mindrecord file
// @return Status - The error code return
// @return Status The status code returned
Status LoadNodeAndEdge(); Status LoadNodeAndEdge();


// Create Tensor By Vector // Create Tensor By Vector
// @param std::vector<std::vector<T>> &data - // @param std::vector<std::vector<T>> &data -
// @param DataType type - // @param DataType type -
// @param std::shared_ptr<Tensor> *out - // @param std::shared_ptr<Tensor> *out -
// @return Status - The error code return
// @return Status The status code returned
template <typename T> template <typename T>
Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out); Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);


@@ -217,32 +217,32 @@ class GraphDataImpl : public GraphData {
// @param std::vector<std::vector<T>> *data - To be completed vector // @param std::vector<std::vector<T>> *data - To be completed vector
// @param size_t max_size - The size of the completed vector // @param size_t max_size - The size of the completed vector
// @param T default_value - Filled default // @param T default_value - Filled default
// @return Status - The error code return
// @return Status The status code returned
template <typename T> template <typename T>
Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value); Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);


// Get the default feature of a node // Get the default feature of a node
// @param FeatureType feature_type - // @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);


// Get the default feature of a edge // Get the default feature of a edge
// @param FeatureType feature_type - // @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);


// Find node object using node id // Find node object using node id
// @param NodeIdType id - // @param NodeIdType id -
// @param std::shared_ptr<Node> *node - Returned node object // @param std::shared_ptr<Node> *node - Returned node object
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);


// Find edge object using edge id // Find edge object using edge id
// @param EdgeIdType id - // @param EdgeIdType id -
// @param std::shared_ptr<Node> *edge - Returned edge object // @param std::shared_ptr<Node> *edge - Returned edge object
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge); Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge);


// Negative sampling // Negative sampling
@@ -250,7 +250,7 @@ class GraphDataImpl : public GraphData {
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
// @param int32_t samples_num - // @param int32_t samples_num -
// @param std::vector<NodeIdType> *out_samples - Sampling results returned // @param std::vector<NodeIdType> *out_samples - Sampling results returned
// @return Status - The error code return
// @return Status The status code returned
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids, Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples); std::vector<NodeIdType> *out_samples);


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h View File

@@ -43,12 +43,12 @@ class LocalEdge : public Edge {
// Get the feature of a edge // Get the feature of a edge
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;


// Update feature of edge // Update feature of edge
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;


private: private:


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h View File

@@ -40,13 +40,13 @@ class LocalNode : public Node {
// Get the feature of a node // Get the feature of a node
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;


// Get the all neighbors of a node // Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) override; bool exclude_itself = false) override;


@@ -54,18 +54,18 @@ class LocalNode : public Node {
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired // @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) override; std::vector<NodeIdType> *out_neighbors) override;


// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -
// @return Status - The error code return
// @return Status The status code returned
Status AddNeighbor(const std::shared_ptr<Node> &node) override; Status AddNeighbor(const std::shared_ptr<Node> &node) override;


// Update feature of node // Update feature of node
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;


private: private:


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/gnn/node.h View File

@@ -49,13 +49,13 @@ class Node {
// Get the feature of a node // Get the feature of a node
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;


// Get the all neighbors of a node // Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) = 0; bool exclude_itself = false) = 0;


@@ -63,18 +63,18 @@ class Node {
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired // @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) = 0; std::vector<NodeIdType> *out_neighbors) = 0;


// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -
// @return Status - The error code return
// @return Status The status code returned
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;


// Update feature of node // Update feature of node
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;


protected: protected:


+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/repeat_node.h View File

@@ -57,6 +57,10 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Getter
/// \return Number of cycles to repeat the execution
const int32_t Count() const { return repeat_count_; }

/// \brief Base-class override for GetDatasetSize /// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter /// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting


+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h View File

@@ -55,6 +55,10 @@ class SkipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Getter
/// \return Number of rows to skip
const int32_t Count() const { return skip_count_; }

/// \brief Base-class override for GetDatasetSize /// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter /// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting


+ 4
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h View File

@@ -55,6 +55,10 @@ class TakeNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid /// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override; Status ValidateParams() override;


/// \brief Getter
/// \return Number of rows to output
const int32_t Count() const { return take_count_; }

/// \brief Base-class override for GetDatasetSize /// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter /// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h View File

@@ -29,7 +29,7 @@ class TensorOpFusionPass : public NodePass {
/// \brief Identifies and fuses tensor ops within MapOp /// \brief Identifies and fuses tensor ops within MapOp
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited /// \param[inout] *modified indicates whether the node has been visited
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset


+ 6
- 6
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h View File

@@ -135,7 +135,7 @@ class IRTreePass : public IRPass {
/// "modified" flag needs to be set to true if tree is modified during the pass execution. /// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate if the tree was modified. /// \param[inout] Indicate if the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
}; };


@@ -170,14 +170,14 @@ class IRNodePass : public IRPass {
/// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all /// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }


/// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
/// "modified" flag needs to be set to true if node is modified during the pass execution /// "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all. /// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }


// Visit()/VisitAfter() method to be overridden. // Visit()/VisitAfter() method to be overridden.
@@ -266,7 +266,7 @@ class TreePass : public Pass {
/// "modified" flag needs to be set to true if tree is modified during the pass execution. /// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
}; };


@@ -301,14 +301,14 @@ class NodePass : public Pass {
/// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all /// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }


/// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation
/// "modified" flag needs to be set to true if tree is modified during the pass execution /// "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all. /// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }


// Visit methods to be overridden. // Visit methods to be overridden.


+ 14
- 14
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h View File

@@ -41,80 +41,80 @@ class RepeatPass : public NodePass {
/// \brief Identifies the subtree below this node as being in a repeated path of the tree. /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;


/// \brief Identifies the subtree below this node as being in a repeated path of the tree. /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;


/// \brief Identifies the subtree below this node as being in a cache merge path /// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;


/// \brief Identifies the subtree below this node as being cached /// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Hooks up any identified eoe nodes under this repeat. /// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;


/// \brief Hooks up any identified eoe nodes under this repeat. /// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;


/// \brief CacheOp removes previous leaf ops and replaces them with itself /// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Turns of the tracking for operations under merge op /// \brief Turns of the tracking for operations under merge op
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;


/// \brief Saves the lookup up in case it needs to be referenced by a repeat /// \brief Saves the lookup up in case it needs to be referenced by a repeat
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;


/// \brief Set the epoch count for DeviceQueue /// \brief Set the epoch count for DeviceQueue
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;


/// \brief Special case for GeneratorOp /// \brief Special case for GeneratorOp
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;


/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it. /// for use with a controlling repeat above it.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;


private: private:
/// \brief Adds an operator to the eoe operator stack save area /// \brief Adds an operator to the eoe operator stack save area
/// \param op - The dataset op to work add to eoe stack /// \param op - The dataset op to work add to eoe stack
/// \return Status - The error code return
/// \return Status The status code returned
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);


/// \brief Pops an operator from the eoe operator stack save area /// \brief Pops an operator from the eoe operator stack save area
@@ -127,7 +127,7 @@ class RepeatPass : public NodePass {


/// \brief Adds an operator to the cached operator stack save area /// \brief Adds an operator to the cached operator stack save area
/// \param op - The dataset op to work add to cached stack /// \param op - The dataset op to work add to cached stack
/// \return Status - The error code return
/// \return Status The status code returned
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);


/// \brief Pops an operator from the cached operator stack save area /// \brief Pops an operator from the cached operator stack save area


+ 20
- 20
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h View File

@@ -38,123 +38,123 @@ class CacheErrorPass : public NodePass {
/// \brief Identifies the subtree below this node as being cached /// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Returns an error if ZipOp exists under a cache /// \brief Returns an error if ZipOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;


/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;


/// \brief Returns an error if ConcatOp exists under a cache /// \brief Returns an error if ConcatOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;


/// \brief Returns an error if TakeOp exists under a cache /// \brief Returns an error if TakeOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;


/// \brief Returns an error if SkipOp exists under a cache /// \brief Returns an error if SkipOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;


/// \brief Returns an error if SkipOp exists under a cache /// \brief Returns an error if SkipOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
/// \brief Returns an error if FilterOp exists under a cache /// \brief Returns an error if FilterOp exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
#endif #endif


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;


/// \brief Identifies the leaf dataset as being mappable /// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;


/// \brief Identifies the subtree above this node as not being cached /// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Identifies and block repeat under cache scenarios /// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;


private: private:


+ 21
- 21
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h View File

@@ -48,14 +48,14 @@ class CacheTransformPass : public TreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree. /// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Resets the tracking of the cache within the tree and assigns the operators that /// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation /// will be involved in a cache transformation
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
@@ -63,95 +63,95 @@ class CacheTransformPass : public TreePass {
/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif #endif


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;


#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
#endif #endif


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;


/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache transform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
#endif #endif


@@ -161,12 +161,12 @@ class CacheTransformPass : public TreePass {
private: private:
/// \brief Common code for mappable leaf setup. /// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work. /// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
/// \return Status The status code returned
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);


/// \brief Common code for non-mappable leaf setup. /// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work. /// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
/// \return Status The status code returned
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);


/// \brief Assigns the leaf and cache operators that are involved in a cache transformation /// \brief Assigns the leaf and cache operators that are involved in a cache transformation
@@ -191,7 +191,7 @@ class CacheTransformPass : public TreePass {
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override; Status RunOnTree(ExecutionTree *tree, bool *modified) override;


private: private:
@@ -212,7 +212,7 @@ class CacheTransformPass : public TreePass {
/// \param[in] leaf_op The leaf node in the transform /// \param[in] leaf_op The leaf node in the transform
/// \param[in] cache_op The cache op in the transform (will get removed) /// \param[in] cache_op The cache op in the transform (will get removed)
/// \param[in] cache_client The cache client /// \param[in] cache_client The cache client
/// \return Status The error code return
/// \return Status The status code returned
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
}; };


+ 10
- 10
mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_validation_pass.h View File

@@ -38,61 +38,61 @@ class CacheValidationPass : public IRNodePass {
/// \brief Returns an error if BatchNode exists under a cache /// \brief Returns an error if BatchNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;


/// \brief Returns an error if ConcatNode exists under a cache /// \brief Returns an error if ConcatNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override; Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override;


/// \brief Returns an error if FilterNode exists under a cache /// \brief Returns an error if FilterNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override; Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override;


/// \brief Returns an error if SkipNode exists under a cache /// \brief Returns an error if SkipNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;


/// \brief Returns an error if TakeNode exists under a cache /// \brief Returns an error if TakeNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;


/// \brief Returns an error if ZipNode exists under a cache /// \brief Returns an error if ZipNode exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override; Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override;


/// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;


/// \brief Returns an error if there is a cache over another cache /// \brief Returns an error if there is a cache over another cache
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;


/// \brief Identifies and block repeat under cache scenarios /// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override; Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override;


/// \brief Identifies the subtree above this node as not being cached /// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;


private: private:


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h View File

@@ -45,27 +45,27 @@ class EpochCtrlPass : public IRTreePass {
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<RootNode> node, bool *modified) override; Status Visit(std::shared_ptr<RootNode> node, bool *modified) override;


/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override; Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override;


#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override; Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override;
#endif #endif


/// \brief Register the TransferNode for further action. /// \brief Register the TransferNode for further action.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override; Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override;


/// \brief Getter /// \brief Getter
@@ -89,7 +89,7 @@ class EpochCtrlPass : public IRTreePass {
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_injection_pass.h View File

@@ -46,20 +46,20 @@ class EpochInjectionPass : public TreePass {
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;


/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
#endif #endif


/// \brief Register the DeviceQueueOp for further action. /// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;


/// \brief Getter /// \brief Getter
@@ -79,7 +79,7 @@ class EpochInjectionPass : public TreePass {
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override; Status RunOnTree(ExecutionTree *tree, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset


+ 31
- 1
mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc View File

@@ -17,7 +17,10 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h" #include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -47,7 +50,16 @@ Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> no
return Status::OK(); return Status::OK();
} }


// Perform ShuffleOp removal check.
// Perform RepeatNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
*modified = false;
if (node->Count() == 1) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}

// Perform ShuffleNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
*modified = false; *modified = false;
#if 0 #if 0
@@ -60,6 +72,24 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, b
return Status::OK(); return Status::OK();
} }


// Perform SkipNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
*modified = false;
if (node->Count() == 0) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}

// Perform TakeNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
*modified = false;
if (node->Count() == -1) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}

// constructor // constructor
NodeRemovalPass::NodeRemovalPass() {} NodeRemovalPass::NodeRemovalPass() {}




+ 22
- 4
mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h View File

@@ -45,21 +45,39 @@ class NodeRemovalPass : public IRTreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree. /// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;


/// \brief Resets the tracking of the cache within the tree /// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;


/// \brief Perform RepeatNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<RepeatNode> node, bool *modified) override;

/// \brief Perform ShuffleNode removal check /// \brief Perform ShuffleNode removal check
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override; Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override;


/// \brief Perform SkipNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;

/// \brief Perform TakeNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;

/// \brief Getter /// \brief Getter
/// \return All the nodes to be removed /// \return All the nodes to be removed
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }
@@ -79,7 +97,7 @@ class NodeRemovalPass : public IRTreePass {
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset


+ 4
- 4
mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h View File

@@ -46,20 +46,20 @@ class RemovalPass : public TreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree. /// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;


/// \brief Resets the tracking of the cache within the tree /// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
#endif #endif


/// \brief Perform ShuffleOp removal check /// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;


/// \brief Getter /// \brief Getter
@@ -81,7 +81,7 @@ class RemovalPass : public TreePass {
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on. /// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified. /// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override; Status RunOnTree(ExecutionTree *tree, bool *modified) override;
}; };
} // namespace dataset } // namespace dataset


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h View File

@@ -53,7 +53,7 @@ class ConnectorSize : public Sampling {
std::string Name() const override { return kConnectorSizeSamplingName; } std::string Name() const override { return kConnectorSizeSamplingName; }


// Save sampling data to file // Save sampling data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override; Status SaveToFile() override;


Status Init(const std::string &dir_path, const std::string &device_id) override; Status Init(const std::string &dir_path, const std::string &device_id) override;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h View File

@@ -65,7 +65,7 @@ class ConnectorThroughput : public Sampling {
std::string Name() const override { return name_; }; std::string Name() const override { return name_; };


// Save sampling data to file // Save sampling data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override; Status SaveToFile() override;


Status Init(const std::string &dir_path, const std::string &device_id); Status Init(const std::string &dir_path, const std::string &device_id);


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h View File

@@ -32,13 +32,13 @@ class DatasetIteratorTracing : public Tracing {
~DatasetIteratorTracing() override = default; ~DatasetIteratorTracing() override = default;


// Record tracing data // Record tracing data
// @return Status - The error code return
// @return Status The status code returned
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);


std::string Name() const override { return kDatasetIteratorTracingName; }; std::string Name() const override { return kDatasetIteratorTracingName; };


// Save tracing data to file // Save tracing data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override; Status SaveToFile() override;


Status Init(const std::string &dir_path, const std::string &device_id) override; Status Init(const std::string &dir_path, const std::string &device_id) override;


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h View File

@@ -32,13 +32,13 @@ class DeviceQueueTracing : public Tracing {
~DeviceQueueTracing() override = default; ~DeviceQueueTracing() override = default;


// Record tracing data // Record tracing data
// @return Status - The error code return
// @return Status The status code returned
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);


std::string Name() const override { return kDeviceQueueTracingName; }; std::string Name() const override { return kDeviceQueueTracingName; };


// Save tracing data to file // Save tracing data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override; Status SaveToFile() override;


Status Init(const std::string &dir_path, const std::string &device_id) override; Status Init(const std::string &dir_path, const std::string &device_id) override;


+ 5
- 5
mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h View File

@@ -87,19 +87,19 @@ class ProfilingManager {
Status Initialize(); Status Initialize();


// Save profile data to file // Save profile data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveProfilingData(); Status SaveProfilingData();


// Sampling node getter // Sampling node getter
// @param name - The name of the requested node // @param name - The name of the requested node
// @param node - Pointer to the shared pointer for the Sampling node // @param node - Pointer to the shared pointer for the Sampling node
// @return Status - The error code return
// @return Status The status code returned
Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node); Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node);


// Tracing node getter // Tracing node getter
// @param name - The name of the requested node // @param name - The name of the requested node
// @param node - Pointer to the shared pointer for the Tracing node // @param node - Pointer to the shared pointer for the Tracing node
// @return Status - The error code return
// @return Status The status code returned
Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node); Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node);


// If profiling is enabled. // If profiling is enabled.
@@ -120,12 +120,12 @@ class ProfilingManager {


// Register profile node to tree // Register profile node to tree
// @param node - Profiling node // @param node - Profiling node
// @return Status - The error code return
// @return Status The status code returned
Status RegisterTracingNode(std::shared_ptr<Tracing> node); Status RegisterTracingNode(std::shared_ptr<Tracing> node);


// Register profile node to tree // Register profile node to tree
// @param node - Profiling node // @param node - Profiling node
// @return Status - The error code return
// @return Status The status code returned
Status RegisterSamplingNode(std::shared_ptr<Sampling> node); Status RegisterSamplingNode(std::shared_ptr<Sampling> node);


ExecutionTree *tree_ = nullptr; // ExecutionTree pointer ExecutionTree *tree_ = nullptr; // ExecutionTree pointer


+ 1
- 1
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -442,7 +442,7 @@ class SchemaObj {
Status parse_column(nlohmann::json columns); Status parse_column(nlohmann::json columns);


/// \brief Get schema file from JSON file /// \brief Get schema file from JSON file
/// \param[in] json_obj Object of JSON parsed.
/// \param[in] json_obj parsed JSON object
/// \return Status code /// \return Status code
Status from_json(nlohmann::json json_obj); Status from_json(nlohmann::json json_obj);




+ 5
- 5
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h View File

@@ -77,7 +77,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
// @param std::shared_ptr<Tensor> *dst - return tensor padded // @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to // @param std::vector<dsize_t> pad_shape - shape to pad to
// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format, // @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format,
// @return - The error code return
// @return Status The status code returned
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
const std::shared_ptr<Tensor> &pad_val); const std::shared_ptr<Tensor> &pad_val);


@@ -86,7 +86,7 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
// @param std::shared_ptr<Tensor> *dst - return tensor padded // @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to // @param std::vector<dsize_t> pad_shape - shape to pad to
// @param float pad_val - value to pad with // @param float pad_val - value to pad with
// @return - The error code return
// @return Status The status code returned
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
const std::vector<dsize_t> &pad_shape, float pad_val); const std::vector<dsize_t> &pad_shape, float pad_val);


@@ -98,7 +98,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
// @param std::vector<dsize_t> cur_ind - recursion helper // @param std::vector<dsize_t> cur_ind - recursion helper
// @param T pad_val - value to pad tensor with // @param T pad_val - value to pad tensor with
// @param size_t cur_dim - recursion helper // @param size_t cur_dim - recursion helper
// @return Status - The error code return
// @return Status The status code returned
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
std::vector<dsize_t> cur_ind, size_t cur_dim = 0); std::vector<dsize_t> cur_ind, size_t cur_dim = 0);


@@ -107,7 +107,7 @@ Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<T
// @param std::shared_ptr<Tensor> *dst - return tensor padded // @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to // @param std::vector<dsize_t> pad_shape - shape to pad to
// @param std::string pad_val - value to pad with // @param std::string pad_val - value to pad with
// @return - The error code return
// @return Status The status code returned
Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
const std::vector<dsize_t> &pad_shape, const std::string &pad_val); const std::vector<dsize_t> &pad_shape, const std::string &pad_val);


@@ -119,7 +119,7 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
// @param std::vector<dsize_t> cur_ind - recursion helperas text // @param std::vector<dsize_t> cur_ind - recursion helperas text
// @param std::string pad_val - value to pad tensor with // @param std::string pad_val - value to pad tensor with
// @param size_t cur_dim - recursion helper // @param size_t cur_dim - recursion helper
// @return Status - The error code return
// @return Status The status code returned
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst, Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim, const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
const std::string &pad_value); const std::string &pad_value);


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h View File

@@ -36,7 +36,7 @@ class ToFloat16Op : public TensorOp {
// Overrides the base class compute function // Overrides the base class compute function
// Calls the ToFloat16 function in ImageUtils, this function takes an input tensor // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor
// and transforms its data to float16, the output memory is manipulated to contain the result // and transforms its data to float16, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;


Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h View File

@@ -58,7 +58,7 @@ class CutOutOp : public TensorOp {
// Overrides the base class compute function // Overrides the base class compute function
// Calls the erase function in ImageUtils, this function takes an input tensor // Calls the erase function in ImageUtils, this function takes an input tensor
// and overwrites some of its data using openCV, the output memory is manipulated to contain the result // and overwrites some of its data using openCV, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;


std::string Name() const override { return kCutOutOp; } std::string Name() const override { return kCutOutOp; }


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h View File

@@ -50,7 +50,7 @@ class RandomColorAdjustOp : public TensorOp {
// Overrides the base class compute function. // Overrides the base class compute function.
// Calls multiple transform functions in ImageUtils, this function takes an input tensor. // Calls multiple transform functions in ImageUtils, this function takes an input tensor.
// and transforms its data using openCV, the output memory is manipulated to contain the result. // and transforms its data using openCV, the output memory is manipulated to contain the result.
// @return Status - The error code return.
// @return Status The status code returned.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;


std::string Name() const override { return kRandomColorAdjustOp; } std::string Name() const override { return kRandomColorAdjustOp; }


+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h View File

@@ -61,7 +61,7 @@ class RandomRotationOp : public TensorOp {
// Overrides the base class compute function // Overrides the base class compute function
// Calls the rotate function in ImageUtils, this function takes an input tensor // Calls the rotate function in ImageUtils, this function takes an input tensor
// and transforms its data using openCV, the output memory is manipulated to contain the result // and transforms its data using openCV, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;




+ 1
- 1
mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h View File

@@ -43,7 +43,7 @@ class UniformAugOp : public TensorOp {
void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; } void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; }


// Overrides the base class compute function // Overrides the base class compute function
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const TensorRow &input, TensorRow *output) override; Status Compute(const TensorRow &input, TensorRow *output) override;


std::string Name() const override { return kUniformAugOp; } std::string Name() const override { return kUniformAugOp; }


+ 6
- 6
mindspore/ccsrc/minddata/dataset/util/data_helper.h View File

@@ -53,7 +53,7 @@ class DataHelper {
/// \param key Key of field to write to /// \param key Key of field to write to
/// \param value Value array to write to file /// \param value Value array to write to file
/// \param out_file Optional input for output file path, will write to input file if not specified /// \param out_file Optional input for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value, Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value,
const std::string &out_file = ""); const std::string &out_file = "");


@@ -62,7 +62,7 @@ class DataHelper {
/// \param key Key of field to write to /// \param key Key of field to write to
/// \param value Value array to write to file /// \param value Value array to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified /// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
template <typename T> template <typename T>
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value, Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value,
const std::string &out_file = "") { const std::string &out_file = "") {
@@ -99,7 +99,7 @@ class DataHelper {
/// \param key Key of field to write to /// \param key Key of field to write to
/// \param value Value to write to file /// \param value Value to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified /// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
template <typename T> template <typename T>
Status UpdateValue(const std::string &in_file, const std::string &key, const T &value, Status UpdateValue(const std::string &in_file, const std::string &key, const T &value,
const std::string &out_file = "") { const std::string &out_file = "") {
@@ -134,7 +134,7 @@ class DataHelper {
/// \brief Template function to write tensor to file /// \brief Template function to write tensor to file
/// \param[in] in_file File to write to /// \param[in] in_file File to write to
/// \param[in] data Array of type T values /// \param[in] data Array of type T values
/// \return Status The error code return
/// \return Status The status code returned
template <typename T> template <typename T>
Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) { Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) {
try { try {
@@ -157,7 +157,7 @@ class DataHelper {
/// \param[in] in_file File name to write to /// \param[in] in_file File name to write to
/// \param[in] data Pointer to data /// \param[in] data Pointer to data
/// \param[in] length Length of values to write from pointer /// \param[in] length Length of values to write from pointer
/// \return Status The error code return
/// \return Status The status code returned
template <typename T> template <typename T>
Status WriteBinFile(const std::string &in_file, T *data, size_t length) { Status WriteBinFile(const std::string &in_file, T *data, size_t length) {
try { try {
@@ -188,7 +188,7 @@ class DataHelper {
/// note This function will return okay even if key not found /// note This function will return okay even if key not found
/// \param[in] in_file Json file to remove key from /// \param[in] in_file Json file to remove key from
/// \param[in] key The key to remove /// \param[in] key The key to remove
/// \return Status The error code return
/// \return Status The status code returned
Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = ""); Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = "");


/// \brief A print method typically used for debugging /// \brief A print method typically used for debugging


+ 0
- 4
mindspore/dataset/engine/datasets.py View File

@@ -669,8 +669,6 @@ class Dataset:
>>> repeat_and_shuffle = data.repeat(50) >>> repeat_and_shuffle = data.repeat(50)
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
""" """
if count == 1:
return self
return RepeatDataset(self, count) return RepeatDataset(self, count)


@check_skip @check_skip
@@ -717,8 +715,6 @@ class Dataset:
>>> # Create a dataset where the dataset includes 50 elements. >>> # Create a dataset where the dataset includes 50 elements.
>>> data = data.take(50) >>> data = data.take(50)
""" """
if count == -1:
return self
return TakeDataset(self, count) return TakeDataset(self, count)


def _get_absolute_split_sizes(self, sizes): def _get_absolute_split_sizes(self, sizes):


+ 45
- 0
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -1311,6 +1311,51 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) {
iter->Stop(); iter->Stop();
} }


TEST_F(MindDataTestPipeline, TestSkipTakeRepeat) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipTakeRepeat.";

// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 6));

// Create a Skip operation on ds
int32_t count = 0;
ds = ds->Skip(count);

// Create a Project operation on ds
std::vector<std::string> column_project = {"image"};
ds = ds->Project(column_project);

// Add a Take(-1)
ds = ds->Take(-1);

// Add a Repeat(1)
ds = ds->Repeat(1);

// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();

// iterate over the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);

uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
MS_LOG(INFO) << "Number of rows: " << i;

// Expect 6 rows
EXPECT_EQ(i, 6);

// Manually terminate the pipeline
iter->Stop();
}

TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) { TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize.";




+ 16
- 3
tests/ut/python/dataset/test_take.py View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger


@@ -163,7 +163,7 @@ def test_take_08():


def test_take_09(): def test_take_09():
""" """
Test take: repeat count is -1, and read the whole dataset, take after repeat
Test take: take count is -1, and read the whole dataset, take after repeat
""" """
logger.info("test_take_09") logger.info("test_take_09")
data1 = ds.GeneratorDataset(generator, ["data"]) data1 = ds.GeneratorDataset(generator, ["data"])
@@ -180,7 +180,7 @@ def test_take_09():


def test_take_10(): def test_take_10():
""" """
Test take: repeat count is -1, and read the whole dataset, take before repeat
Test take: take count is -1, and read the whole dataset, take before repeat
""" """
logger.info("test_take_10") logger.info("test_take_10")
data1 = ds.GeneratorDataset(generator, ["data"]) data1 = ds.GeneratorDataset(generator, ["data"])
@@ -341,6 +341,18 @@ def test_take_18():
assert sum([1 for _ in data1]) == 2 assert sum([1 for _ in data1]) == 2




def test_take_19():
"""
Test take: take is after batch, that mean take(N), N refer to batches num
"""
logger.info("test_take_19")
with pytest.raises(ValueError) as info:
data1 = ds.GeneratorDataset(generator, ["data"])

data1 = data1.batch(2)
data1 = data1.take(0)
assert "positive integer" in str(info.value)

if __name__ == '__main__': if __name__ == '__main__':
test_take_01() test_take_01()
test_take_02() test_take_02()
@@ -360,4 +372,5 @@ if __name__ == '__main__':
test_take_16() test_take_16()
test_take_17() test_take_17()
test_take_18() test_take_18()
test_take_19()
logger.info('== test take operation finished ==') logger.info('== test take operation finished ==')

Loading…
Cancel
Save