From: @nsyca Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -490,12 +490,6 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s | |||||
| #endif | #endif | ||||
| RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) { | RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) { | ||||
| // Workaround for repeat == 1, do not inject repeat. | |||||
| if (count == 1) { | |||||
| ir_node_ = input->IRNode(); | |||||
| return; | |||||
| } | |||||
| auto ds = std::make_shared<RepeatNode>(input->IRNode(), count); | auto ds = std::make_shared<RepeatNode>(input->IRNode(), count); | ||||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ||||
| @@ -516,13 +510,6 @@ SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) { | |||||
| } | } | ||||
| TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) { | TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) { | ||||
| // If count is greater than the number of element in dataset or equal to -1, | |||||
| // all the element in dataset will be taken | |||||
| if (count == -1) { | |||||
| ir_node_ = input->IRNode(); | |||||
| return; | |||||
| } | |||||
| auto ds = std::make_shared<TakeNode>(input->IRNode(), count); | auto ds = std::make_shared<TakeNode>(input->IRNode(), count); | ||||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ||||
| @@ -66,7 +66,7 @@ class ColDescriptor { | |||||
| /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. | /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. | ||||
| /// \param[in] num_elements - The number of elements in the data for a Tensor | /// \param[in] num_elements - The number of elements in the data for a Tensor | ||||
| /// \param[inout] out_shape - The materialized output Tensor shape | /// \param[inout] out_shape - The materialized output Tensor shape | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; | Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; | ||||
| /// \brief << Stream output operator overload | /// \brief << Stream output operator overload | ||||
| @@ -124,13 +124,13 @@ class DataSchema { | |||||
| /// \brief Parses a schema json file and populates the columns and meta info. | /// \brief Parses a schema json file and populates the columns and meta info. | ||||
| /// \param[in] schema_file_path - the schema file that has the column's info to load | /// \param[in] schema_file_path - the schema file that has the column's info to load | ||||
| /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load); | Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load); | ||||
| /// \brief Parses a schema JSON string and populates the columns and meta info. | /// \brief Parses a schema JSON string and populates the columns and meta info. | ||||
| /// \param[in] schema_json_string - the schema JSON string that has the column's info to load | /// \param[in] schema_json_string - the schema JSON string that has the column's info to load | ||||
| /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load); | Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load); | ||||
| /// \brief A print method typically used for debugging | /// \brief A print method typically used for debugging | ||||
| @@ -148,7 +148,7 @@ class DataSchema { | |||||
| /// \brief Adds a column descriptor to the schema | /// \brief Adds a column descriptor to the schema | ||||
| /// \param[in] cd - The ColDescriptor to add | /// \param[in] cd - The ColDescriptor to add | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status AddColumn(const ColDescriptor &cd); | Status AddColumn(const ColDescriptor &cd); | ||||
| /// \brief getter | /// \brief getter | ||||
| @@ -169,7 +169,7 @@ class DataSchema { | |||||
| /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. | /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. | ||||
| /// \param[inout] out_column_name_map - The output map of columns names to column index | /// \param[inout] out_column_name_map - The output map of columns names to column index | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map); | Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map); | ||||
| private: | private: | ||||
| @@ -177,7 +177,7 @@ class DataSchema { | |||||
| /// does not follow any particular order (json standard does not enforce any ordering protocol). | /// does not follow any particular order (json standard does not enforce any ordering protocol). | ||||
| /// This one produces a schema that contains all of the columns from the schema file. | /// This one produces a schema that contains all of the columns from the schema file. | ||||
| /// \param[in] column_tree - The nlohmann tree from the json file to parse | /// \param[in] column_tree - The nlohmann tree from the json file to parse | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status AnyOrderLoad(nlohmann::json column_tree); | Status AnyOrderLoad(nlohmann::json column_tree); | ||||
| /// \brief Internal helper function. For each input column name, perform a lookup to the json document to | /// \brief Internal helper function. For each input column name, perform a lookup to the json document to | ||||
| @@ -185,18 +185,18 @@ class DataSchema { | |||||
| /// descriptor and add to the schema in the order in which the input column names are given. | /// descriptor and add to the schema in the order in which the input column names are given. | ||||
| /// \param[in] column_tree - The nlohmann tree from the json file to parse | /// \param[in] column_tree - The nlohmann tree from the json file to parse | ||||
| /// \param[in] columns_to_load - list of strings for the columns to add to the schema | /// \param[in] columns_to_load - list of strings for the columns to add to the schema | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load); | Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load); | ||||
| /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. | /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. | ||||
| /// \param[in] column_child_tree - The nlohmann child tree for a given column to load. | /// \param[in] column_child_tree - The nlohmann child tree for a given column to load. | ||||
| /// \param[in] col_name - The string name of the column for that subtree. | /// \param[in] col_name - The string name of the column for that subtree. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); | Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); | ||||
| /// \brief Internal helper function. Performs sanity checks on the json file setup. | /// \brief Internal helper function. Performs sanity checks on the json file setup. | ||||
| /// \param[in] js - The nlohmann tree for the schema file | /// \param[in] js - The nlohmann tree for the schema file | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreLoadExceptionCheck(const nlohmann::json &js); | Status PreLoadExceptionCheck(const nlohmann::json &js); | ||||
| std::vector<ColDescriptor> col_descs_; // Vector of column descriptors | std::vector<ColDescriptor> col_descs_; // Vector of column descriptors | ||||
| @@ -53,7 +53,7 @@ class IteratorBase { | |||||
| // functionality exists in the derived versions of this function. | // functionality exists in the derived versions of this function. | ||||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | ||||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| // @note The position of a Tensor/column might be different from the initial column order | // @note The position of a Tensor/column might be different from the initial column order | ||||
| // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change | // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change | ||||
| // the column ordering. | // the column ordering. | ||||
| @@ -97,17 +97,17 @@ class DatasetIterator : public IteratorBase { | |||||
| // from the tree root node directly. | // from the tree root node directly. | ||||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | ||||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status FetchNextTensorRow(TensorRow *out_row) override; | Status FetchNextTensorRow(TensorRow *out_row) override; | ||||
| // Fetches the next tensor row into device row, and returns its shape. | // Fetches the next tensor row into device row, and returns its shape. | ||||
| // @param out_shapes - A vector of tensor shapes (one shape per column) | // @param out_shapes - A vector of tensor shapes (one shape per column) | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetOutputShapes(std::vector<TensorShape> *out_shapes); | Status GetOutputShapes(std::vector<TensorShape> *out_shapes); | ||||
| // Fetches the next tensor row into device row, and returns its column data types. | // Fetches the next tensor row into device row, and returns its column data types. | ||||
| // @param out_types - A vector of data types (one type per column) | // @param out_types - A vector of data types (one type per column) | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetOutputTypes(std::vector<DataType> *out_types); | Status GetOutputTypes(std::vector<DataType> *out_types); | ||||
| // Getter | // Getter | ||||
| @@ -140,12 +140,12 @@ class ChildIterator : public IteratorBase { | |||||
| // only from the child/worker id as given from the constructor. | // only from the child/worker id as given from the constructor. | ||||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data | ||||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status FetchNextTensorRow(TensorRow *out_row) override; | Status FetchNextTensorRow(TensorRow *out_row) override; | ||||
| // This function drains buffer until next eoe has been received. | // This function drains buffer until next eoe has been received. | ||||
| // It will be a no-op if the previous row returned is empty. | // It will be a no-op if the previous row returned is empty. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Drain(); | Status Drain(); | ||||
| // Getter | // Getter | ||||
| @@ -134,7 +134,7 @@ class BarrierOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Handles preprocessing of the main loop, used when starting new epoch | // Handles preprocessing of the main loop, used when starting new epoch | ||||
| @@ -112,12 +112,12 @@ class BatchOp : public ParallelOp { | |||||
| #endif | #endif | ||||
| // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg | // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<BatchOp> *); | Status Build(std::shared_ptr<BatchOp> *); | ||||
| private: | private: | ||||
| // Sanity check for builder class args | // Sanity check for builder class args | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| bool builder_drop_; | bool builder_drop_; | ||||
| @@ -167,11 +167,11 @@ class BatchOp : public ParallelOp { | |||||
| ~BatchOp() {} | ~BatchOp() {} | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status EofReceived(int32_t) override; | Status EofReceived(int32_t) override; | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status EoeReceived(int32_t) override; | Status EoeReceived(int32_t) override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -190,7 +190,7 @@ class BatchOp : public ParallelOp { | |||||
| } | } | ||||
| // Main loop of batch | // Main loop of batch | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -214,14 +214,14 @@ class BatchOp : public ParallelOp { | |||||
| // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | ||||
| // @param int32_t size - batch_size | // @param int32_t size - batch_size | ||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | ||||
| dsize_t batch_size); | dsize_t batch_size); | ||||
| // @param table | // @param table | ||||
| // @param const PadInfo &pad_info pad info | // @param const PadInfo &pad_info pad info | ||||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | ||||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | const std::unordered_map<std::string, int32_t> &column_name_id_map); | ||||
| @@ -233,18 +233,18 @@ class BatchOp : public ParallelOp { | |||||
| private: | private: | ||||
| // Worker thread for doing the memcpy of batch | // Worker thread for doing the memcpy of batch | ||||
| // @param int32_t param workerId | // @param int32_t param workerId | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Generate buffer with batched tensors | // Generate buffer with batched tensors | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | ||||
| std::unique_ptr<DataBuffer> *db); | std::unique_ptr<DataBuffer> *db); | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // Function that calls pyfunc to perform map on batch | // Function that calls pyfunc to perform map on batch | ||||
| // @param std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair - contains un-batched tensor | // @param std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair - contains un-batched tensor | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair); | Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair); | ||||
| #endif | #endif | ||||
| @@ -253,7 +253,7 @@ class BatchOp : public ParallelOp { | |||||
| // @param std::set<int32_t> *cols, col ids to perform pad on | // @param std::set<int32_t> *cols, col ids to perform pad on | ||||
| // @param std::vector<float> *vals, default padding value for each column | // @param std::vector<float> *vals, default padding value for each column | ||||
| // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user | // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| static Status UnpackPadInfo(const PadInfo &pad_info, | static Status UnpackPadInfo(const PadInfo &pad_info, | ||||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, | const std::unordered_map<std::string, int32_t> &column_name_id_map, | ||||
| std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | ||||
| @@ -264,20 +264,20 @@ class BatchOp : public ParallelOp { | |||||
| int32_t num_consumers() const override { return 1; } | int32_t num_consumers() const override { return 1; } | ||||
| // get the batch size for next batch | // get the batch size for next batch | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetBatchSize(int32_t *batch_size, CBatchInfo info); | Status GetBatchSize(int32_t *batch_size, CBatchInfo info); | ||||
| // Do the initialization of all queues then start all worker threads | // Do the initialization of all queues then start all worker threads | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // Invoke batch size function with current BatchInfo to generate batch size. | // Invoke batch size function with current BatchInfo to generate batch size. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); | Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); | ||||
| // Invoke batch map function with current BatchInfo to generate tensors to batch. | // Invoke batch map function with current BatchInfo to generate tensors to batch. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); | Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); | ||||
| #endif | #endif | ||||
| @@ -107,7 +107,7 @@ class BucketBatchByLengthOp : public PipelineOp { | |||||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | // Might need to batch remaining buckets after receiving eoe, so override this method. | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return Status - The error code returned | |||||
| // @return Status The status code returned | |||||
| Status EoeReceived(int32_t) override; | Status EoeReceived(int32_t) override; | ||||
| std::string Name() const override { return kBucketBatchByLengthOp; } | std::string Name() const override { return kBucketBatchByLengthOp; } | ||||
| @@ -123,7 +123,7 @@ class BucketBatchByLengthOp : public PipelineOp { | |||||
| } | } | ||||
| // Main loop of batch | // Main loop of batch | ||||
| // @return Status - The error code returned | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| private: | private: | ||||
| @@ -104,7 +104,7 @@ class BuildSentencePieceVocabOp : public PipelineOp { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<BuildSentencePieceVocabOp> *op - DatasetOp | // @param std::shared_ptr<BuildSentencePieceVocabOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op); | Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op); | ||||
| private: | private: | ||||
| @@ -110,7 +110,7 @@ class BuildVocabOp : public ParallelOp { | |||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp | // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<BuildVocabOp> *op); | Status Build(std::shared_ptr<BuildVocabOp> *op); | ||||
| private: | private: | ||||
| @@ -53,7 +53,7 @@ class CacheBase : public ParallelOp { | |||||
| /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state | /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state | ||||
| /// info from its previous execution and then initializes itself so that it can be executed | /// info from its previous execution and then initializes itself so that it can be executed | ||||
| /// again. | /// again. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| /// \brief A print method typically used for debugging | /// \brief A print method typically used for debugging | ||||
| @@ -80,7 +80,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT { | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | std::shared_ptr<SamplerRT> build_sampler_; | ||||
| // Check if the required parameters are set by the builder. | // Check if the required parameters are set by the builder. | ||||
| // \return Status The error code return | |||||
| // \return Status The status code returned | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| }; | }; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| @@ -136,7 +136,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | std::shared_ptr<SamplerRT> build_sampler_; | ||||
| /// Check if the required parameters are set by the builder. | /// Check if the required parameters are set by the builder. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| }; | }; | ||||
| @@ -189,7 +189,7 @@ class CacheMergeOp : public ParallelOp { | |||||
| /// \brief Base-class override for handling cases when an eof is received. | /// \brief Base-class override for handling cases when an eof is received. | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status EofReceived(int32_t worker_id) override; | Status EofReceived(int32_t worker_id) override; | ||||
| protected: | protected: | ||||
| @@ -99,7 +99,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| std::shared_ptr<SamplerRT> build_sampler_; | std::shared_ptr<SamplerRT> build_sampler_; | ||||
| /// \brief Check if the required parameters are set by the builder. | /// \brief Check if the required parameters are set by the builder. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| }; | }; | ||||
| @@ -119,7 +119,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| /// \brief Base-class override for special eoe handler. | /// \brief Base-class override for special eoe handler. | ||||
| /// CacheOp must override this because it shall not perform default handling of eoe. Instead | /// CacheOp must override this because it shall not perform default handling of eoe. Instead | ||||
| /// the CacheOp manages actions related to the end of the epoch. | /// the CacheOp manages actions related to the end of the epoch. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status EoeReceived(int32_t worker_id) override; | Status EoeReceived(int32_t worker_id) override; | ||||
| /// \brief Base-class override for NodePass pre-visit acceptor | /// \brief Base-class override for NodePass pre-visit acceptor | ||||
| /// \param[in] p The node to visit | /// \param[in] p The node to visit | ||||
| @@ -133,7 +133,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| Status Accept(NodePass *p, bool *modified) override; | Status Accept(NodePass *p, bool *modified) override; | ||||
| /// \brief Base-class override for handling cases when an eof is received. | /// \brief Base-class override for handling cases when an eof is received. | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status EofReceived(int32_t worker_id) override; | Status EofReceived(int32_t worker_id) override; | ||||
| Status operator()() override; | Status operator()() override; | ||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| @@ -159,7 +159,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||||
| Status CacheAllRows(int32_t worker_id); | Status CacheAllRows(int32_t worker_id); | ||||
| Status RegisterResources() override; | Status RegisterResources() override; | ||||
| /// \brief Private function for cache setup/init work just after construction | /// \brief Private function for cache setup/init work just after construction | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status InitCache(); | Status InitCache(); | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -94,7 +94,7 @@ class ConcatOp : public PipelineOp { | |||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Op name getter | // Op name getter | ||||
| @@ -146,14 +146,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// DatasetOps operate by launching a thread (see ExecutionTree). | /// DatasetOps operate by launching a thread (see ExecutionTree). | ||||
| /// This pure virtual version makes the requirement that derived classes must provide a functor | /// This pure virtual version makes the requirement that derived classes must provide a functor | ||||
| /// that will execute their main runtime loop code. | /// that will execute their main runtime loop code. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status operator()() = 0; | virtual Status operator()() = 0; | ||||
| /// \brief Gets the next buffer from the given child | /// \brief Gets the next buffer from the given child | ||||
| /// \notes See GetNextInput for similar function that has built-in message handling | /// \notes See GetNextInput for similar function that has built-in message handling | ||||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) { | virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) { | ||||
| return GetNextBuffer(p_buffer, worker_id, false); | return GetNextBuffer(p_buffer, worker_id, false); | ||||
| } | } | ||||
| @@ -161,7 +161,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \brief Gets the next buffer from the given child | /// \brief Gets the next buffer from the given child | ||||
| /// \notes See GetNextInput for similar function that has built-in message handling | /// \notes See GetNextInput for similar function that has built-in message handling | ||||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } | virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } | ||||
| /// \brief Gets the next buffer from the given child | /// \brief Gets the next buffer from the given child | ||||
| @@ -169,7 +169,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe); | virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe); | ||||
| /// \brief Gets the next buffer from the given child. This function also has built-in eoe and eof | /// \brief Gets the next buffer from the given child. This function also has built-in eoe and eof | ||||
| @@ -177,7 +177,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// those messages are received. | /// those messages are received. | ||||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); | Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); | ||||
| /// \brief Gets the batch size | /// \brief Gets the batch size | ||||
| @@ -200,19 +200,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// The base class implementation simply flows the eoe message to output. Derived classes | /// The base class implementation simply flows the eoe message to output. Derived classes | ||||
| /// may override if they need to perform special eoe handling. | /// may override if they need to perform special eoe handling. | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status EoeReceived(int32_t worker_id); | virtual Status EoeReceived(int32_t worker_id); | ||||
| /// \brief Performs handling for when an eof message is received. | /// \brief Performs handling for when an eof message is received. | ||||
| /// The base class implementation simply flows the eof message to output. Derived classes | /// The base class implementation simply flows the eof message to output. Derived classes | ||||
| /// may override if they need to perform special eof handling. | /// may override if they need to perform special eof handling. | ||||
| /// \param worker_id - The worker id | /// \param worker_id - The worker id | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status EofReceived(int32_t worker_id); | virtual Status EofReceived(int32_t worker_id); | ||||
| /// \brief Derived classes may implement the reset function if the operator is stateful and needs | /// \brief Derived classes may implement the reset function if the operator is stateful and needs | ||||
| /// specific reset handling that is not contained in this common code version of the reset | /// specific reset handling that is not contained in this common code version of the reset | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status Reset(); | virtual Status Reset(); | ||||
| /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on | /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on | ||||
| @@ -79,7 +79,7 @@ class FilterOp : public ParallelOp { | |||||
| private: | private: | ||||
| // Sanity check for builder class args. | // Sanity check for builder class args. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| std::vector<std::string> build_in_col_names_; | std::vector<std::string> build_in_col_names_; | ||||
| std::shared_ptr<TensorOp> builder_predicate_func_; | std::shared_ptr<TensorOp> builder_predicate_func_; | ||||
| @@ -105,15 +105,15 @@ class FilterOp : public ParallelOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work. | // provide the master loop that drives the logic for performing the work. | ||||
| // @return Status The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // @param int32_t workerId. | // @param int32_t workerId. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status EofReceived(int32_t) override; | Status EofReceived(int32_t) override; | ||||
| // @param int32_t workerId. | // @param int32_t workerId. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status EoeReceived(int32_t) override; | Status EoeReceived(int32_t) override; | ||||
| // A print method typically used for debugging. | // A print method typically used for debugging. | ||||
| @@ -151,34 +151,34 @@ class FilterOp : public ParallelOp { | |||||
| // logic of FilterOp, getting the data from previous Op, validating user specified column names, | // logic of FilterOp, getting the data from previous Op, validating user specified column names, | ||||
| // applying predicate to each of the data, filter the data when predicate result is false. | // applying predicate to each of the data, filter the data when predicate result is false. | ||||
| // @param worker_id The id assigned to this thread/worker upon creation. | // @param worker_id The id assigned to this thread/worker upon creation. | ||||
| // @return Status The error code return. | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | ||||
| // Filter the data by predicate function. | // Filter the data by predicate function. | ||||
| // @param in_buffer input data buffer. | // @param in_buffer input data buffer. | ||||
| // @param to_process_indices Indices of columns to be processed. | // @param to_process_indices Indices of columns to be processed. | ||||
| // @param out data buffer that is filtered by predicate. | // @param out data buffer that is filtered by predicate. | ||||
| // @return Status The error code return. | |||||
| // @return Status The status code returned | |||||
| Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out); | Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out); | ||||
| // Collector databuffer. | // Collector databuffer. | ||||
| // @return Status The error code return. | |||||
| // @return Status The status code returned | |||||
| Status Collector(); | Status Collector(); | ||||
| // @param input tensor vector. | // @param input tensor vector. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status CheckInput(const TensorRow &input) const; | Status CheckInput(const TensorRow &input) const; | ||||
| // Invoke python func. | // Invoke python func. | ||||
| // @param input tensor vector. | // @param input tensor vector. | ||||
| // @param the result of predicate. | // @param the result of predicate. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); | Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); | ||||
| // Private function for validating if each of the user specified input column names | // Private function for validating if each of the user specified input column names | ||||
| // exist in the DataBuffer. | // exist in the DataBuffer. | ||||
| // @param input_columns The vector of input column names used in the current thread. | // @param input_columns The vector of input column names used in the current thread. | ||||
| // @return Status The error code return. | |||||
| // @return Status The status code returned | |||||
| Status ValidateInColumns(const std::vector<std::string> *input_columns); | Status ValidateInColumns(const std::vector<std::string> *input_columns); | ||||
| // Private function for checking the column legality | // Private function for checking the column legality | ||||
| @@ -133,7 +133,7 @@ class MapOp : public ParallelOp { | |||||
| int32_t build_op_connector_size_; | int32_t build_op_connector_size_; | ||||
| // Check if the required parameters are set by the builder. | // Check if the required parameters are set by the builder. | ||||
| // @return Status The error code return | |||||
| // @return Status The status code returned | |||||
| Status sanityCheck() const; | Status sanityCheck() const; | ||||
| }; | }; | ||||
| @@ -170,7 +170,7 @@ class MapOp : public ParallelOp { | |||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // This main thread creates local queues, pulls databuffers from the previous | // This main thread creates local queues, pulls databuffers from the previous | ||||
| // op's Connector and distributes them to the local queues. Workers pull from the local queues. | // op's Connector and distributes them to the local queues. Workers pull from the local queues. | ||||
| // @return Status The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Getter | // Getter | ||||
| @@ -239,7 +239,7 @@ class MapOp : public ParallelOp { | |||||
| // applying a list of TensorOps to each of the data, process the results and then | // applying a list of TensorOps to each of the data, process the results and then | ||||
| // pushing them back to MapOp's output Connector to be fetched by the next Op. | // pushing them back to MapOp's output Connector to be fetched by the next Op. | ||||
| // @param worker_id The id assigned to this thread/worker upon creation. | // @param worker_id The id assigned to this thread/worker upon creation. | ||||
| // @return Status The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | ||||
| // Private function for worker thread to perform TensorOp's compute function and get the result. | // Private function for worker thread to perform TensorOp's compute function and get the result. | ||||
| @@ -89,7 +89,7 @@ class ParallelOp : public DatasetOp { | |||||
| } | } | ||||
| // Override base class reset to provide reset actions specific to the ParallelOp class. | // Override base class reset to provide reset actions specific to the ParallelOp class. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Getter | // Getter | ||||
| @@ -115,7 +115,7 @@ class ParallelOp : public DatasetOp { | |||||
| protected: | protected: | ||||
| // Interface for derived classes to implement. All derived classes must provide the entry | // Interface for derived classes to implement. All derived classes must provide the entry | ||||
| // function with the main execution loop for worker threads. | // function with the main execution loop for worker threads. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status WorkerEntry(int32_t workerId) = 0; | virtual Status WorkerEntry(int32_t workerId) = 0; | ||||
| /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | ||||
| @@ -75,7 +75,7 @@ class ProjectOp : public PipelineOp { | |||||
| // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the | // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the | ||||
| // functor since this op runs inlined inside another operator. The function is overloaded to | // functor since this op runs inlined inside another operator. The function is overloaded to | ||||
| // ensure that it is not called by mistake (it will generate an error). | // ensure that it is not called by mistake (it will generate an error). | ||||
| // @return Status - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. | // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. | ||||
| @@ -93,12 +93,12 @@ class ProjectOp : public PipelineOp { | |||||
| // Base-class override for special eoe handler. | // Base-class override for special eoe handler. | ||||
| // Inline operators must override this because there is no connector to push eoe onto. | // Inline operators must override this because there is no connector to push eoe onto. | ||||
| // @return Status - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status EoeReceived(int32_t worker_id) override; | Status EoeReceived(int32_t worker_id) override; | ||||
| // Base-class override for special eof handler. | // Base-class override for special eof handler. | ||||
| // Inline operators must override this because there is no connector to push eof onto. | // Inline operators must override this because there is no connector to push eof onto. | ||||
| // @return Status - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status EofReceived(int32_t worker_id) override; | Status EofReceived(int32_t worker_id) override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -107,7 +107,7 @@ class RenameOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -78,7 +78,7 @@ class RepeatOp : public PipelineOp { | |||||
| // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the | // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the | ||||
| // functor since this op runs inlined inside another operator. The function is overloaded to | // functor since this op runs inlined inside another operator. The function is overloaded to | ||||
| // ensure that it is not called by mistake (it will generate an error). | // ensure that it is not called by mistake (it will generate an error). | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // This function returns the buffer that is at the top of our output connector. The caller is | // This function returns the buffer that is at the top of our output connector. The caller is | ||||
| @@ -90,7 +90,7 @@ class RepeatOp : public PipelineOp { | |||||
| // @param p_buffer - output pointer to the buffer that it will fetch. | // @param p_buffer - output pointer to the buffer that it will fetch. | ||||
| // @param worker_id - The worker id | // @param worker_id - The worker id | ||||
| // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | ||||
| // Base-class override for handling cases when an eoe is received. | // Base-class override for handling cases when an eoe is received. | ||||
| @@ -130,7 +130,7 @@ class RepeatOp : public PipelineOp { | |||||
| int32_t num_repeats() { return num_repeats_; } | int32_t num_repeats() { return num_repeats_; } | ||||
| /// \brief reset Op | /// \brief reset Op | ||||
| /// \@return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| int64_t GetTreeRepeatCount() override; | int64_t GetTreeRepeatCount() override; | ||||
| @@ -146,13 +146,13 @@ class ShuffleOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for special eoe handler. | // Base-class override for special eoe handler. | ||||
| // ShuffleOp must override this because it shall not perform default handling of eoe. Instead | // ShuffleOp must override this because it shall not perform default handling of eoe. Instead | ||||
| // the ShuffleOp needs to manage actions related to the end of the epoch itself. | // the ShuffleOp needs to manage actions related to the end of the epoch itself. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status EoeReceived(int32_t worker_id) override; | Status EoeReceived(int32_t worker_id) override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -167,17 +167,17 @@ class ShuffleOp : public PipelineOp { | |||||
| private: | private: | ||||
| // Private function to add a new row to the shuffle buffer. | // Private function to add a new row to the shuffle buffer. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); | Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); | ||||
| // Private function to populate the shuffle buffer initially by fetching from the child output | // Private function to populate the shuffle buffer initially by fetching from the child output | ||||
| // connector until the shuffle buffer is full (or there is no more data coming). | // connector until the shuffle buffer is full (or there is no more data coming). | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitShuffleBuffer(); | Status InitShuffleBuffer(); | ||||
| // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by | // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by | ||||
| // itself rather than waiting for the reset driven from operators above it in the pipeline. | // itself rather than waiting for the reset driven from operators above it in the pipeline. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SelfReset(); | Status SelfReset(); | ||||
| int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) | int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) | ||||
| @@ -63,7 +63,7 @@ class SkipOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for handling cases when an eoe is received. | // Base-class override for handling cases when an eoe is received. | ||||
| @@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| /// \brief Check validity of input args | /// \brief Check validity of input args | ||||
| /// \return - The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| /// \brief The builder "build" method creates the final object. | /// \brief The builder "build" method creates the final object. | ||||
| /// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp | /// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp | ||||
| /// \return - The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status Build(std::shared_ptr<AlbumOp> *op); | Status Build(std::shared_ptr<AlbumOp> *op); | ||||
| private: | private: | ||||
| @@ -168,18 +168,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| ~AlbumOp() = default; | ~AlbumOp() = default; | ||||
| /// \brief Initialize AlbumOp related var, calls the function to walk all files | /// \brief Initialize AlbumOp related var, calls the function to walk all files | ||||
| /// \return - The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status PrescanEntry(); | Status PrescanEntry(); | ||||
| /// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | /// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| /// \param[in] int32_t workerId - id of each worker | /// \param[in] int32_t workerId - id of each worker | ||||
| /// \return Status - The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| /// \brief Main Loop of AlbumOp | /// \brief Main Loop of AlbumOp | ||||
| /// Master thread: Fill IOBlockQueue, then goes to sleep | /// Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| /// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | /// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| /// \return Status - The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| /// \brief A print method typically used for debugging | /// \brief A print method typically used for debugging | ||||
| @@ -204,93 +204,93 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| /// \brief Initialize Sampler, calls sampler->Init() within | /// \brief Initialize Sampler, calls sampler->Init() within | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| /// \brief Load image to tensor row | /// \brief Load image to tensor row | ||||
| /// \param[in] image_file Image name of file | /// \param[in] image_file Image name of file | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row); | Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load vector of ints to tensor, append tensor to tensor row | /// \brief Load vector of ints to tensor, append tensor to tensor row | ||||
| /// \param[in] json_obj Json object containing multi-dimensional label | /// \param[in] json_obj Json object containing multi-dimensional label | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load vector of floats to tensor, append tensor to tensor row | /// \brief Load vector of floats to tensor, append tensor to tensor row | ||||
| /// \param[in] json_obj Json object containing array data | /// \param[in] json_obj Json object containing array data | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load string array into a tensor, append tensor to tensor row | /// \brief Load string array into a tensor, append tensor to tensor row | ||||
| /// \param[in] json_obj Json object containing string tensor | /// \param[in] json_obj Json object containing string tensor | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load string into a tensor, append tensor to tensor row | /// \brief Load string into a tensor, append tensor to tensor row | ||||
| /// \param[in] json_obj Json object containing string tensor | /// \param[in] json_obj Json object containing string tensor | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load float value to tensor row | /// \brief Load float value to tensor row | ||||
| /// \param[in] json_obj Json object containing float | /// \param[in] json_obj Json object containing float | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load int value to tensor row | /// \brief Load int value to tensor row | ||||
| /// \param[in] json_obj Json object containing int | /// \param[in] json_obj Json object containing int | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load empty tensor to tensor row | /// \brief Load empty tensor to tensor row | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadEmptyTensor(uint32_t col_num, TensorRow *row); | Status LoadEmptyTensor(uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load id from file name to tensor row | /// \brief Load id from file name to tensor row | ||||
| /// \param[in] file The file name to get ID from | /// \param[in] file The file name to get ID from | ||||
| /// \param[in] col_num Column num in schema | /// \param[in] col_num Column num in schema | ||||
| /// \param[inout] row Tensor row to push to | /// \param[inout] row Tensor row to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row); | Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row); | ||||
| /// \brief Load a tensor row according to a json file | /// \brief Load a tensor row according to a json file | ||||
| /// \param[in] row_id_type row_id - id for this tensor row | /// \param[in] row_id_type row_id - id for this tensor row | ||||
| /// \param[in] ImageColumns file Json file location | /// \param[in] ImageColumns file Json file location | ||||
| /// \param[inout] TensorRow row Json content stored into a tensor row | /// \param[inout] TensorRow row Json content stored into a tensor row | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row); | Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row); | ||||
| /// \param[in] const std::vector<int64_t> &keys Keys in ioblock | /// \param[in] const std::vector<int64_t> &keys Keys in ioblock | ||||
| /// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to | /// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| /// \brief Called first when function is called | /// \brief Called first when function is called | ||||
| /// \return Status The error code returned | |||||
| /// \return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| /// \brief reset Op | /// \brief reset Op | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| // @return Status The error code returned | |||||
| // @return Status The status code returned | |||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| @@ -116,12 +116,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<CelebAOp> *op - DatasetOp | // @param std::shared_ptr<CelebAOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<CelebAOp> *op); | Status Build(std::shared_ptr<CelebAOp> *op); | ||||
| private: | private: | ||||
| @@ -151,12 +151,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| // Main Loop of CelebAOp | // Main Loop of CelebAOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t worker_id - id of each worker | // @param int32_t worker_id - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -166,7 +166,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| // Method in operator(), to fill IOBlockQueue | // Method in operator(), to fill IOBlockQueue | ||||
| // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer); | Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer); | ||||
| /// \brief Base-class override for NodePass visitor acceptor | /// \brief Base-class override for NodePass visitor acceptor | ||||
| @@ -199,14 +199,14 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| // @param const std::vector<int64_t> &keys - keys in ioblock | // @param const std::vector<int64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Load a tensor row according to a pair | // Load a tensor row according to a pair | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param std::pair - <image_file,<label>> | // @param std::pair - <image_file,<label>> | ||||
| // @param TensorRow row - image & label read into this tensor row | // @param TensorRow row - image & label read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label, | Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label, | ||||
| TensorRow *row); | TensorRow *row); | ||||
| @@ -215,7 +215,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| bool CheckDatasetTypeValid(); | bool CheckDatasetTypeValid(); | ||||
| // reset Op | // reset Op | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -109,12 +109,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<CifarOp> *op - DatasetOp | // @param std::shared_ptr<CifarOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<CifarOp> *op); | Status Build(std::shared_ptr<CifarOp> *op); | ||||
| private: | private: | ||||
| @@ -144,13 +144,13 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param uint32_t workerId - id of each worker | // @param uint32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Main Loop of CifarOp | // Main Loop of CifarOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -177,18 +177,18 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Load a tensor row according to a pair | // Load a tensor row according to a pair | ||||
| // @param uint64_t index - index need to load | // @param uint64_t index - index need to load | ||||
| // @param TensorRow row - image & label read into this tensor row | // @param TensorRow row - image & label read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(uint64_t index, TensorRow *row); | Status LoadTensorRow(uint64_t index, TensorRow *row); | ||||
| // @param const std::vector<uint64_t> &keys - keys in ioblock | // @param const std::vector<uint64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Read block data from cifar file | // Read block data from cifar file | ||||
| @@ -200,7 +200,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // reset Op | // reset Op | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Get cifar files in dir | // Get cifar files in dir | ||||
| @@ -221,7 +221,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | ||||
| // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class | // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -133,12 +133,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return = The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "Build" method creates the final object. | // The builder "Build" method creates the final object. | ||||
| // @param std::shared_ptr<CocoOp> *op - DatasetOp | // @param std::shared_ptr<CocoOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<CocoOp> *op); | Status Build(std::shared_ptr<CocoOp> *op); | ||||
| private: | private: | ||||
| @@ -173,13 +173,13 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t workerId - id of each worker | // @param int32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Main Loop of CocoOp | // Main Loop of CocoOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -214,19 +214,19 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| std::string Name() const override { return "CocoOp"; } | std::string Name() const override { return "CocoOp"; } | ||||
| /// \brief Gets the class indexing | /// \brief Gets the class indexing | ||||
| /// \return Status - The status code return | |||||
| /// \return Status The status code returned | |||||
| Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override; | Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override; | ||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Load a tensor row according to image id | // Load a tensor row according to image id | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param std::string image_id - image id | // @param std::string image_id - image id | ||||
| // @param TensorRow row - image & target read into this tensor row | // @param TensorRow row - image & target read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | ||||
| // Load a tensor row with vector which a vector to a tensor | // Load a tensor row with vector which a vector to a tensor | ||||
| @@ -235,7 +235,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::shared_ptr<Tensor> image - image tensor | // @param std::shared_ptr<Tensor> image - image tensor | ||||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | ||||
| // @param TensorRow row - image & target read into this tensor row | // @param TensorRow row - image & target read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | ||||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | std::shared_ptr<Tensor> coordinate, TensorRow *trow); | ||||
| @@ -245,7 +245,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::shared_ptr<Tensor> image - image tensor | // @param std::shared_ptr<Tensor> image - image tensor | ||||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | ||||
| // @param TensorRow row - image & target read into this tensor row | // @param TensorRow row - image & target read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | ||||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | std::shared_ptr<Tensor> coordinate, TensorRow *trow); | ||||
| @@ -255,69 +255,69 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::shared_ptr<Tensor> image - image tensor | // @param std::shared_ptr<Tensor> image - image tensor | ||||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | ||||
| // @param TensorRow row - image & target read into this tensor row | // @param TensorRow row - image & target read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | ||||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | std::shared_ptr<Tensor> coordinate, TensorRow *trow); | ||||
| // @param const std::string &path - path to the image file | // @param const std::string &path - path to the image file | ||||
| // @param const ColDescriptor &col - contains tensor implementation and datatype | // @param const ColDescriptor &col - contains tensor implementation and datatype | ||||
| // @param std::shared_ptr<Tensor> tensor - return | // @param std::shared_ptr<Tensor> tensor - return | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | ||||
| // @param const std::vector<uint64_t> &keys - keys in ioblock | // @param const std::vector<uint64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Read annotation from Annotation folder | // Read annotation from Annotation folder | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseAnnotationIds(); | Status ParseAnnotationIds(); | ||||
| // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | ||||
| // @param std::vector<int64_t> *keys - image id | // @param std::vector<int64_t> *keys - image id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | ||||
| // Called first when function is called | // Called first when function is called | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // Reset dataset state | // Reset dataset state | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // @param nlohmann::json image_tree - image tree of json | // @param nlohmann::json image_tree - image tree of json | ||||
| // @param std::vector<std::string> *image_vec - image id list of json | // @param std::vector<std::string> *image_vec - image id list of json | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec); | Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec); | ||||
| // @param nlohmann::json categories_tree - categories tree of json | // @param nlohmann::json categories_tree - categories tree of json | ||||
| // return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status CategoriesColumnLoad(const nlohmann::json &categories_tree); | Status CategoriesColumnLoad(const nlohmann::json &categories_tree); | ||||
| // @param nlohmann::json categories_tree - categories tree of json | // @param nlohmann::json categories_tree - categories tree of json | ||||
| // @param const std::string &image_file - current image name in annotation | // @param const std::string &image_file - current image name in annotation | ||||
| // @param const int32_t &id - current unique id of annotation | // @param const int32_t &id - current unique id of annotation | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | ||||
| // @param nlohmann::json categories_tree - categories tree of json | // @param nlohmann::json categories_tree - categories tree of json | ||||
| // @param const std::string &image_file - current image name in annotation | // @param const std::string &image_file - current image name in annotation | ||||
| // @param const int32_t &id - current unique id of annotation | // @param const int32_t &id - current unique id of annotation | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | ||||
| // @param nlohmann::json categories_tree - categories tree of json | // @param nlohmann::json categories_tree - categories tree of json | ||||
| // @param const std::string &image_file - current image name in annotation | // @param const std::string &image_file - current image name in annotation | ||||
| // @param const int32_t &id - current unique id of annotation | // @param const int32_t &id - current unique id of annotation | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | ||||
| // @param nlohmann::json categories_tree - categories tree of json | // @param nlohmann::json categories_tree - categories tree of json | ||||
| // @param const std::string &image_file - current image name in annotation | // @param const std::string &image_file - current image name in annotation | ||||
| // @param const int32_t &image_id - current unique id of annotation | // @param const int32_t &image_id - current unique id of annotation | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, | Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, | ||||
| const int32_t &image_id); | const int32_t &image_id); | ||||
| @@ -115,13 +115,13 @@ class GeneratorOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work. | // provide the master loop that drives the logic for performing the work. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | // Overrides base class reset method. When an operator does a reset, it cleans up any state | ||||
| // info from its previous execution and then initializes itself so that it can be executed | // info from its previous execution and then initializes itself so that it can be executed | ||||
| // again. | // again. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -135,12 +135,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<ImageFolderOp> *op - DatasetOp | // @param std::shared_ptr<ImageFolderOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<ImageFolderOp> *op); | Status Build(std::shared_ptr<ImageFolderOp> *op); | ||||
| private: | private: | ||||
| @@ -172,28 +172,28 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| // Initialize ImageFolderOp related var, calls the function to walk all files | // Initialize ImageFolderOp related var, calls the function to walk all files | ||||
| // @param - std::string dir file directory to ImageNetFolder | // @param - std::string dir file directory to ImageNetFolder | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PrescanMasterEntry(const std::string &dir); | Status PrescanMasterEntry(const std::string &dir); | ||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t workerId - id of each worker | // @param int32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t workerId - id of each worker | // @param int32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PrescanWorkerEntry(int32_t worker_id); | Status PrescanWorkerEntry(int32_t worker_id); | ||||
| // Main Loop of ImageFolderOp | // Main Loop of ImageFolderOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | ||||
| // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class | // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -224,19 +224,19 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Load a tensor row according to a pair | // Load a tensor row according to a pair | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param ImageLabelPair pair - <imagefile,label> | // @param ImageLabelPair pair - <imagefile,label> | ||||
| // @param TensorRow row - image & label read into this tensor row | // @param TensorRow row - image & label read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); | Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); | ||||
| // @param const std::vector<int64_t> &keys - keys in ioblock | // @param const std::vector<int64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // @param std::string & dir - dir to walk all images | // @param std::string & dir - dir to walk all images | ||||
| @@ -253,7 +253,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // reset Op | // reset Op | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -58,12 +58,12 @@ class IOBlock { | |||||
| // Fetches the first key from the block. | // Fetches the first key from the block. | ||||
| // @note Only useful if you know the block only has 1 key. | // @note Only useful if you know the block only has 1 key. | ||||
| // @param out_key - A copy of the first key from the block | // @param out_key - A copy of the first key from the block | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetKey(int64_t *out_key) const; | Status GetKey(int64_t *out_key) const; | ||||
| // Fetches the list of keys from this block. | // Fetches the list of keys from this block. | ||||
| // @param out_keys - A copy of the vector of keys from the block. | // @param out_keys - A copy of the vector of keys from the block. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetKeys(std::vector<int64_t> *out_keys) const; | Status GetKeys(std::vector<int64_t> *out_keys) const; | ||||
| // Does this block have the eoe flag turned on? | // Does this block have the eoe flag turned on? | ||||
| @@ -110,7 +110,7 @@ class FilenameBlock : public IOBlock { | |||||
| // Gets the filename from the block using the provided index container | // Gets the filename from the block using the provided index container | ||||
| // @param out_filename - The filename to add to the block | // @param out_filename - The filename to add to the block | ||||
| // @param index - The index to perform lookup against | // @param index - The index to perform lookup against | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const; | Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const; | ||||
| // Get the start offset of file | // Get the start offset of file | ||||
| @@ -110,12 +110,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "build" method creates the final object. | // The builder "build" method creates the final object. | ||||
| // @param std::shared_ptr<ManifestOp> *op - DatasetOp | // @param std::shared_ptr<ManifestOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<ManifestOp> *op); | Status Build(std::shared_ptr<ManifestOp> *op); | ||||
| private: | private: | ||||
| @@ -145,18 +145,18 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t worker_id - id of each worker | // @param int32_t worker_id - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Main Loop of ManifestOp | // Main Loop of ManifestOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | ||||
| // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -201,37 +201,37 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Method in operator(), to fill IOBlockQueue | // Method in operator(), to fill IOBlockQueue | ||||
| // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer); | Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer); | ||||
| // Load a tensor row according to a pair | // Load a tensor row according to a pair | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>> | // @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>> | ||||
| // @param TensorRow row - image & label read into this tensor row | // @param TensorRow row - image & label read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data, | Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data, | ||||
| TensorRow *row); | TensorRow *row); | ||||
| // @param const std::vector<int64_t> &keys - keys in ioblock | // @param const std::vector<int64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Parse manifest file to get image path and label and so on. | // Parse manifest file to get image path and label and so on. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseManifestFile(); | Status ParseManifestFile(); | ||||
| // Called first when function is called | // Called first when function is called | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // reset Op | // reset Op | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Check if image is valid. Only supports JPEG/PNG/GIF/BMP | // Check if image is valid. Only supports JPEG/PNG/GIF/BMP | ||||
| @@ -239,7 +239,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| Status CheckImageType(const std::string &file_name, bool *valid); | Status CheckImageType(const std::string &file_name, bool *valid); | ||||
| // Count label index, num rows and num samples | // Count label index, num rows and num samples | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status CountDatasetInfo(); | Status CountDatasetInfo(); | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -164,13 +164,13 @@ class MindRecordOp : public ParallelOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t workerId - id of each worker | // @param int32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work. | // provide the master loop that drives the logic for performing the work. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Called first when function is called | // Called first when function is called | ||||
| @@ -180,7 +180,7 @@ class MindRecordOp : public ParallelOp { | |||||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | // Overrides base class reset method. When an operator does a reset, it cleans up any state | ||||
| // info from its previous execution and then initializes itself so that it can be executed | // info from its previous execution and then initializes itself so that it can be executed | ||||
| // again. | // again. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Getter method | // Getter method | ||||
| @@ -99,12 +99,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "Build" method creates the final object. | // The builder "Build" method creates the final object. | ||||
| // @param std::shared_ptr<MnistOp> *op - DatasetOp | // @param std::shared_ptr<MnistOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<MnistOp> *op); | Status Build(std::shared_ptr<MnistOp> *op); | ||||
| private: | private: | ||||
| @@ -133,18 +133,18 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t worker_id - id of each worker | // @param int32_t worker_id - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Main Loop of MnistOp | // Main Loop of MnistOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | ||||
| // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -170,39 +170,39 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Load a tensor row according to a pair | // Load a tensor row according to a pair | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param ImageLabelPair pair - <imagefile,label> | // @param ImageLabelPair pair - <imagefile,label> | ||||
| // @param TensorRow row - image & label read into this tensor row | // @param TensorRow row - image & label read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); | Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); | ||||
| // @param const std::vector<int64_t> &keys - keys in ioblock | // @param const std::vector<int64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Iterate through all members in sampleIds and fill them into IOBlock. | // Iterate through all members in sampleIds and fill them into IOBlock. | ||||
| // @param std::shared_ptr<Tensor> sample_ids - | // @param std::shared_ptr<Tensor> sample_ids - | ||||
| // @param std::vector<int64_t> *keys - keys in ioblock | // @param std::vector<int64_t> *keys - keys in ioblock | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | ||||
| // Check image file stream. | // Check image file stream. | ||||
| // @param const std::string &file_name - image file name | // @param const std::string &file_name - image file name | ||||
| // @param std::ifstream *image_reader - image file stream | // @param std::ifstream *image_reader - image file stream | ||||
| // @param uint32_t num_images - returns the number of images | // @param uint32_t num_images - returns the number of images | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); | Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); | ||||
| // Check label stream. | // Check label stream. | ||||
| // @param const std::string &file_name - label file name | // @param const std::string &file_name - label file name | ||||
| // @param std::ifstream *label_reader - label file stream | // @param std::ifstream *label_reader - label file stream | ||||
| // @param uint32_t num_labels - returns the number of labels | // @param uint32_t num_labels - returns the number of labels | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); | Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); | ||||
| // Read 4 bytes of data from a file stream. | // Read 4 bytes of data from a file stream. | ||||
| @@ -219,23 +219,23 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param std::ifstream *image_reader - image file stream | // @param std::ifstream *image_reader - image file stream | ||||
| // @param std::ifstream *label_reader - label file stream | // @param std::ifstream *label_reader - label file stream | ||||
| // @param int64_t read_num - number of image to read | // @param int64_t read_num - number of image to read | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); | Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); | ||||
| // Parse all mnist dataset files | // Parse all mnist dataset files | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseMnistData(); | Status ParseMnistData(); | ||||
| // Read all files in the directory | // Read all files in the directory | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WalkAllFiles(); | Status WalkAllFiles(); | ||||
| // Called first when function is called | // Called first when function is called | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // reset Op | // reset Op | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -63,7 +63,7 @@ class RandomDataOp : public ParallelOp { | |||||
| /** | /** | ||||
| * The build method that produces the instantiated RandomDataOp as a shared pointer | * The build method that produces the instantiated RandomDataOp as a shared pointer | ||||
| * @param out_op - The output RandomDataOperator that was constructed | * @param out_op - The output RandomDataOperator that was constructed | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status Build(std::shared_ptr<RandomDataOp> *out_op); | Status Build(std::shared_ptr<RandomDataOp> *out_op); | ||||
| @@ -128,7 +128,7 @@ class RandomDataOp : public ParallelOp { | |||||
| private: | private: | ||||
| /** | /** | ||||
| * Check if the required parameters are set by the builder. | * Check if the required parameters are set by the builder. | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| @@ -182,7 +182,7 @@ class RandomDataOp : public ParallelOp { | |||||
| * Class functor operator () override. | * Class functor operator () override. | ||||
| * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | ||||
| * provide the master loop that drives the logic for performing the work. | * provide the master loop that drives the logic for performing the work. | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status operator()() override; | Status operator()() override; | ||||
| @@ -190,7 +190,7 @@ class RandomDataOp : public ParallelOp { | |||||
| * Overrides base class reset method. When an operator does a reset, it cleans up any state | * Overrides base class reset method. When an operator does a reset, it cleans up any state | ||||
| * info from its previous execution and then initializes itself so that it can be executed | * info from its previous execution and then initializes itself so that it can be executed | ||||
| * again. | * again. | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status Reset() override; | Status Reset() override; | ||||
| @@ -207,7 +207,7 @@ class RandomDataOp : public ParallelOp { | |||||
| /** | /** | ||||
| * The entry point code for when workers are launched | * The entry point code for when workers are launched | ||||
| * @param worker_id - The worker id | * @param worker_id - The worker id | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| @@ -219,7 +219,7 @@ class RandomDataOp : public ParallelOp { | |||||
| /** | /** | ||||
| * Performs a synchronization between workers at the end of an epoch | * Performs a synchronization between workers at the end of an epoch | ||||
| * @param worker_id - The worker id | * @param worker_id - The worker id | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status EpochSync(int32_t worker_id, bool *quitting); | Status EpochSync(int32_t worker_id, bool *quitting); | ||||
| @@ -227,7 +227,7 @@ class RandomDataOp : public ParallelOp { | |||||
| * A helper function to stuff the tensor table into a buffer and send it to output connector | * A helper function to stuff the tensor table into a buffer and send it to output connector | ||||
| * @param worker_id - The worker id | * @param worker_id - The worker id | ||||
| * @param in_table - The tensor table to pack and send | * @param in_table - The tensor table to pack and send | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table); | Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table); | ||||
| @@ -235,7 +235,7 @@ class RandomDataOp : public ParallelOp { | |||||
| * A helper function to create random data for the row | * A helper function to create random data for the row | ||||
| * @param worker_id - The worker id | * @param worker_id - The worker id | ||||
| * @param new_row - The output row to produce | * @param new_row - The output row to produce | ||||
| * @return Status - The error code return | |||||
| * @return Status The status code returned | |||||
| */ | */ | ||||
| Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); | Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); | ||||
| @@ -40,7 +40,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| // @param std::unique_ptr<DataBuffer> pBuffer | // @param std::unique_ptr<DataBuffer> pBuffer | ||||
| // @param int32_t workerId | // @param int32_t workerId | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| @@ -53,7 +53,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| @@ -41,13 +41,13 @@ class PythonSamplerRT : public SamplerRT { | |||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| @@ -40,14 +40,14 @@ class RandomSamplerRT : public SamplerRT { | |||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // meant to be called by base class or python | // meant to be called by base class or python | ||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | void SamplerPrint(std::ostream &out, bool show_all) const override; | ||||
| @@ -35,12 +35,12 @@ class RandomAccessOp { | |||||
| public: | public: | ||||
| // Sampler get number of rows in the dataset | // Sampler get number of rows in the dataset | ||||
| // @param int64_t num - return number of rows for this dataset | // @param int64_t num - return number of rows for this dataset | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNumRowsInDataset(int64_t *num_rows) const; | Status GetNumRowsInDataset(int64_t *num_rows) const; | ||||
| // Sampler gets labels and imageIds from the corresponding Dataset Op; this function is unique to PK | // Sampler gets labels and imageIds from the corresponding Dataset Op; this function is unique to PK | ||||
| // @param std::map<int64_t, std::vector<int64_t>> * map | // @param std::map<int64_t, std::vector<int64_t>> * map | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { | virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { | ||||
| RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); | RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); | ||||
| } | } | ||||
| @@ -71,7 +71,7 @@ class SamplerRT { | |||||
| // @note It is Sampler responsibility to make sure that the id is not out of bound. | // @note It is Sampler responsibility to make sure that the id is not out of bound. | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | ||||
| // This function only called by python layer. Not needed by Android. | // This function only called by python layer. Not needed by Android. | ||||
| @@ -81,7 +81,7 @@ class SamplerRT { | |||||
| #endif | #endif | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status ResetSampler() = 0; | virtual Status ResetSampler() = 0; | ||||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | // first handshake between leaf source op and Sampler. This func will determine the amount of data | ||||
| @@ -114,13 +114,13 @@ class SamplerRT { | |||||
| // Adds a sampler to become our child. | // Adds a sampler to become our child. | ||||
| // @param std::shared_ptr<SamplerRT> child - The sampler to add as a child. | // @param std::shared_ptr<SamplerRT> child - The sampler to add as a child. | ||||
| // @return - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status AddChild(std::shared_ptr<SamplerRT> child); | Status AddChild(std::shared_ptr<SamplerRT> child); | ||||
| // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | ||||
| // @param std::shared_ptr<Tensor>* sampleIds | // @param std::shared_ptr<Tensor>* sampleIds | ||||
| // @param int64_t numElements - must be a non 0 number | // @param int64_t numElements - must be a non 0 number | ||||
| // @return - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -146,7 +146,7 @@ class SamplerRT { | |||||
| // associated id. | // associated id. | ||||
| // @param int64_t* out_associated_id - Out parameter, contains the associated id. | // @param int64_t* out_associated_id - Out parameter, contains the associated id. | ||||
| // @param int64_t id - The id used as an index to get the associated child id. | // @param int64_t id - The id used as an index to get the associated child id. | ||||
| // @return - The error code returned. | |||||
| // @return Status The status code returned | |||||
| Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); | Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); | ||||
| protected: | protected: | ||||
| @@ -40,13 +40,13 @@ class SequentialSamplerRT : public SamplerRT { | |||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| // for next epoch of sampleIds | // for next epoch of sampleIds | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| // Op calls this to get next Buffer that contains all the sampleIds | // Op calls this to get next Buffer that contains all the sampleIds | ||||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | ||||
| // @param int32_t workerId - not meant to be used | // @param int32_t workerId - not meant to be used | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| @@ -132,12 +132,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return = The error code return | |||||
| // @return Status The status code returned | |||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| // The builder "Build" method creates the final object. | // The builder "Build" method creates the final object. | ||||
| // @param std::shared_ptr<VOCOp> *op - DatasetOp | // @param std::shared_ptr<VOCOp> *op - DatasetOp | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Build(std::shared_ptr<VOCOp> *op); | Status Build(std::shared_ptr<VOCOp> *op); | ||||
| private: | private: | ||||
| @@ -173,13 +173,13 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | ||||
| // @param int32_t workerId - id of each worker | // @param int32_t workerId - id of each worker | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status WorkerEntry(int32_t worker_id) override; | Status WorkerEntry(int32_t worker_id) override; | ||||
| // Main Loop of VOCOp | // Main Loop of VOCOp | ||||
| // Master thread: Fill IOBlockQueue, then goes to sleep | // Master thread: Fill IOBlockQueue, then goes to sleep | ||||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it, then put the buffer to mOutConnector | // Worker thread: pulls IOBlock from IOBlockQueue, work on it, then put the buffer to mOutConnector | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| @@ -222,55 +222,55 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| // Initialize Sampler, calls sampler->Init() within | // Initialize Sampler, calls sampler->Init() within | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status InitSampler(); | Status InitSampler(); | ||||
| // Load a tensor row according to image id | // Load a tensor row according to image id | ||||
| // @param row_id_type row_id - id for this tensor row | // @param row_id_type row_id - id for this tensor row | ||||
| // @param std::string image_id - image id | // @param std::string image_id - image id | ||||
| // @param TensorRow row - image & target read into this tensor row | // @param TensorRow row - image & target read into this tensor row | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | ||||
| // @param const std::string &path - path to the image file | // @param const std::string &path - path to the image file | ||||
| // @param const ColDescriptor &col - contains tensor implementation and datatype | // @param const ColDescriptor &col - contains tensor implementation and datatype | ||||
| // @param std::shared_ptr<Tensor> tensor - return | // @param std::shared_ptr<Tensor> tensor - return | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | ||||
| // @param const std::string &path - path to the image file | // @param const std::string &path - path to the image file | ||||
| // @param TensorRow *row - return | // @param TensorRow *row - return | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); | Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); | ||||
| // @param const std::vector<int64_t> &keys - keys in ioblock | // @param const std::vector<int64_t> &keys - keys in ioblock | ||||
| // @param std::unique_ptr<DataBuffer> db | // @param std::unique_ptr<DataBuffer> db | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | ||||
| // Read image list from ImageSets | // Read image list from ImageSets | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseImageIds(); | Status ParseImageIds(); | ||||
| // Read annotation from Annotation folder | // Read annotation from Annotation folder | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseAnnotationIds(); | Status ParseAnnotationIds(); | ||||
| // @param const std::string &path - path to annotation xml | // @param const std::string &path - path to annotation xml | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ParseAnnotationBbox(const std::string &path); | Status ParseAnnotationBbox(const std::string &path); | ||||
| // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | ||||
| // @param std::vector<int64_t> *keys - image id | // @param std::vector<int64_t> *keys - image id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | ||||
| // Called first when function is called | // Called first when function is called | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchThreadsAndInitOp(); | Status LaunchThreadsAndInitOp(); | ||||
| // Reset dataset state | // Reset dataset state | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Reset() override; | Status Reset() override; | ||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| @@ -75,7 +75,7 @@ class TakeOp : public PipelineOp { | |||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| // Base-class override for NodePass visitor acceptor. | // Base-class override for NodePass visitor acceptor. | ||||
| @@ -101,7 +101,7 @@ class ZipOp : public PipelineOp { | |||||
| // Class functor operator () override. | // Class functor operator () override. | ||||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | ||||
| // provide the master loop that drives the logic for performing the work | // provide the master loop that drives the logic for performing the work | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status operator()() override; | Status operator()() override; | ||||
| /// \brief Base-class override for NodePass pre-visit acceptor | /// \brief Base-class override for NodePass pre-visit acceptor | ||||
| @@ -226,7 +226,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||||
| // Compulsory transformation/action post optimization. | // Compulsory transformation/action post optimization. | ||||
| // For example, repeatOp inlining | // For example, repeatOp inlining | ||||
| // | // | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { | Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { | ||||
| num_epochs_ = num_epochs; | num_epochs_ = num_epochs; | ||||
| partially_prepare_ = partial; | partially_prepare_ = partial; | ||||
| @@ -115,16 +115,16 @@ class ExecutionTree { | |||||
| // provides it with a link to the tree. A node cannot form any relationships (parent/child) with | // provides it with a link to the tree. A node cannot form any relationships (parent/child) with | ||||
| // other nodes unless they are associated with the same tree. | // other nodes unless they are associated with the same tree. | ||||
| // @param op - The operator to associate | // @param op - The operator to associate | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AssociateNode(const std::shared_ptr<DatasetOp> &op); | Status AssociateNode(const std::shared_ptr<DatasetOp> &op); | ||||
| // Sets the root node of the tree | // Sets the root node of the tree | ||||
| // @param op - The operator to assign as root | // @param op - The operator to assign as root | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AssignRoot(const std::shared_ptr<DatasetOp> &op); | Status AssignRoot(const std::shared_ptr<DatasetOp> &op); | ||||
| // Start the execution of the tree | // Start the execution of the tree | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Launch(); | Status Launch(); | ||||
| /// A print method typically used for debugging | /// A print method typically used for debugging | ||||
| @@ -155,7 +155,7 @@ class ExecutionTree { | |||||
| // wrapper for the TaskGroup handling that is stored inside the execution tree. | // wrapper for the TaskGroup handling that is stored inside the execution tree. | ||||
| // @param num_workers - The number of workers to launch | // @param num_workers - The number of workers to launch | ||||
| // @param func - The function entry point that workers will execute | // @param func - The function entry point that workers will execute | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); | Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); | ||||
| // Getter method | // Getter method | ||||
| @@ -181,32 +181,32 @@ class ExecutionTree { | |||||
| // Compulsory transformation/action post optimization. | // Compulsory transformation/action post optimization. | ||||
| // For example, repeatOp inlining | // For example, repeatOp inlining | ||||
| // | // | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Prepare(int num_epochs = -1, bool partial = false); | Status Prepare(int num_epochs = -1, bool partial = false); | ||||
| // Compulsory transformation/action pre optimization. | // Compulsory transformation/action pre optimization. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PreAction(); | Status PreAction(); | ||||
| // Compulsory transformation/action post optimization. | // Compulsory transformation/action post optimization. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PostAction(); | Status PostAction(); | ||||
| // Optimization transformation/action, optional. | // Optimization transformation/action, optional. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Optimize(); | Status Optimize(); | ||||
| // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | ||||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | // walk the tree to perform modifications to the tree or specific nodes within the tree to get | ||||
| // it ready for execution. | // it ready for execution. | ||||
| // @param Total number of epochs that will be run on this tree | // @param Total number of epochs that will be run on this tree | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PrepareDeprecated(); | Status PrepareDeprecated(); | ||||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | // Recursive function used during prepare phase to visit a node and drive any pre- and post- | ||||
| // node actions during a tree walk. | // node actions during a tree walk. | ||||
| // @param op - The dataset op to work on | // @param op - The dataset op to work on | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | ||||
| // Return the pointer to the TaskGroup | // Return the pointer to the TaskGroup | ||||
| @@ -51,7 +51,7 @@ class Edge { | |||||
| // Get the feature of an edge | // Get the feature of an edge | ||||
| // @param FeatureType feature_type - type of feature | // @param FeatureType feature_type - type of feature | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | ||||
| // Get nodes on the edge | // Get nodes on the edge | ||||
| @@ -71,7 +71,7 @@ class Edge { | |||||
| // Update feature of edge | // Update feature of edge | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | ||||
| protected: | protected: | ||||
| @@ -47,19 +47,19 @@ class GraphData { | |||||
| // Get all nodes from the graph. | // Get all nodes from the graph. | ||||
| // @param NodeType node_type - type of node | // @param NodeType node_type - type of node | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | // @param std::shared_ptr<Tensor> *out - Returned nodes id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0; | virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0; | ||||
| // Get all edges from the graph. | // Get all edges from the graph. | ||||
| // @param EdgeType edge_type - type of edge | // @param EdgeType edge_type - type of edge | ||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | // @param std::shared_ptr<Tensor> *out - Returned edge ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0; | virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0; | ||||
| // Get the node id from the edge. | // Get the node id from the edge. | ||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | // @param std::vector<EdgeIdType> edge_list - List of edges | ||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | // @param std::shared_ptr<Tensor> *out - Returned node ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | ||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| @@ -68,7 +68,7 @@ class GraphData { | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | ||||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | ||||
| // is not enough, fill in tensor as -1. | // is not enough, fill in tensor as -1. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) = 0; | std::shared_ptr<Tensor> *out) = 0; | ||||
| @@ -77,7 +77,7 @@ class GraphData { | |||||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | ||||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | ||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | ||||
| const std::vector<NodeIdType> &neighbor_nums, | const std::vector<NodeIdType> &neighbor_nums, | ||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; | const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; | ||||
| @@ -87,7 +87,7 @@ class GraphData { | |||||
| // @param NodeIdType samples_num - Number of neighbors sampled | // @param NodeIdType samples_num - Number of neighbors sampled | ||||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | // @param NodeType neg_neighbor_type - The type of negative neighbor. | ||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | ||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0; | NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0; | ||||
| @@ -98,7 +98,7 @@ class GraphData { | |||||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | // @param float step_away_param - inout hyper parameter in node2vec algorithm | ||||
| // @param NodeIdType default_node - default node id | // @param NodeIdType default_node - default node id | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param, float step_away_param, NodeIdType default_node, | float step_home_param, float step_away_param, NodeIdType default_node, | ||||
| std::shared_ptr<Tensor> *out) = 0; | std::shared_ptr<Tensor> *out) = 0; | ||||
| @@ -108,7 +108,7 @@ class GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) = 0; | TensorRow *out) = 0; | ||||
| @@ -117,7 +117,7 @@ class GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) = 0; | TensorRow *out) = 0; | ||||
| @@ -57,19 +57,19 @@ class GraphDataClient : public GraphData { | |||||
| // Get all nodes from the graph. | // Get all nodes from the graph. | ||||
| // @param NodeType node_type - type of node | // @param NodeType node_type - type of node | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | // @param std::shared_ptr<Tensor> *out - Returned nodes id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | ||||
| // Get all edges from the graph. | // Get all edges from the graph. | ||||
| // @param EdgeType edge_type - type of edge | // @param EdgeType edge_type - type of edge | ||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | // @param std::shared_ptr<Tensor> *out - Returned edge ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | ||||
| // Get the node id from the edge. | // Get the node id from the edge. | ||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | // @param std::vector<EdgeIdType> edge_list - List of edges | ||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | // @param std::shared_ptr<Tensor> *out - Returned node ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | ||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| @@ -78,7 +78,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | ||||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | ||||
| // is not enough, fill in tensor as -1. | // is not enough, fill in tensor as -1. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) override; | std::shared_ptr<Tensor> *out) override; | ||||
| @@ -87,7 +87,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | ||||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | ||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | ||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | ||||
| @@ -96,7 +96,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param NodeIdType samples_num - Number of neighbors sampled | // @param NodeIdType samples_num - Number of neighbors sampled | ||||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | // @param NodeType neg_neighbor_type - The type of negative neighbor. | ||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | ||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | ||||
| @@ -107,7 +107,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | // @param float step_away_param - inout hyper parameter in node2vec algorithm | ||||
| // @param NodeIdType default_node - default node id | // @param NodeIdType default_node - default node id | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param, float step_away_param, NodeIdType default_node, | float step_home_param, float step_away_param, NodeIdType default_node, | ||||
| std::shared_ptr<Tensor> *out) override; | std::shared_ptr<Tensor> *out) override; | ||||
| @@ -117,7 +117,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) override; | TensorRow *out) override; | ||||
| @@ -126,7 +126,7 @@ class GraphDataClient : public GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) override; | TensorRow *out) override; | ||||
| @@ -51,19 +51,19 @@ class GraphDataImpl : public GraphData { | |||||
| // Get all nodes from the graph. | // Get all nodes from the graph. | ||||
| // @param NodeType node_type - type of node | // @param NodeType node_type - type of node | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | // @param std::shared_ptr<Tensor> *out - Returned nodes id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | ||||
| // Get all edges from the graph. | // Get all edges from the graph. | ||||
| // @param EdgeType edge_type - type of edge | // @param EdgeType edge_type - type of edge | ||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | // @param std::shared_ptr<Tensor> *out - Returned edge ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | ||||
| // Get the node id from the edge. | // Get the node id from the edge. | ||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | // @param std::vector<EdgeIdType> edge_list - List of edges | ||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | // @param std::shared_ptr<Tensor> *out - Returned node ids | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | ||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| @@ -72,7 +72,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | ||||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | ||||
| // is not enough, fill in tensor as -1. | // is not enough, fill in tensor as -1. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) override; | std::shared_ptr<Tensor> *out) override; | ||||
| @@ -81,7 +81,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | ||||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | ||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | ||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | ||||
| @@ -90,7 +90,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param NodeIdType samples_num - Number of neighbors sampled | // @param NodeIdType samples_num - Number of neighbors sampled | ||||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | // @param NodeType neg_neighbor_type - The type of negative neighbor. | ||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | ||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | ||||
| @@ -101,7 +101,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | // @param float step_away_param - inout hyper parameter in node2vec algorithm | ||||
| // @param NodeIdType default_node - default node id | // @param NodeIdType default_node - default node id | ||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param, float step_away_param, NodeIdType default_node, | float step_home_param, float step_away_param, NodeIdType default_node, | ||||
| std::shared_ptr<Tensor> *out) override; | std::shared_ptr<Tensor> *out) override; | ||||
| @@ -111,7 +111,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) override; | TensorRow *out) override; | ||||
| @@ -123,7 +123,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) override; | TensorRow *out) override; | ||||
| @@ -132,7 +132,7 @@ class GraphDataImpl : public GraphData { | |||||
| // Get meta information of graph | // Get meta information of graph | ||||
| // @param MetaInfo *meta_info - Returned meta information | // @param MetaInfo *meta_info - Returned meta information | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetMetaInfo(MetaInfo *meta_info); | Status GetMetaInfo(MetaInfo *meta_info); | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| @@ -202,14 +202,14 @@ class GraphDataImpl : public GraphData { | |||||
| }; | }; | ||||
| // Load graph data from mindrecord file | // Load graph data from mindrecord file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status LoadNodeAndEdge(); | Status LoadNodeAndEdge(); | ||||
| // Create Tensor By Vector | // Create Tensor By Vector | ||||
| // @param std::vector<std::vector<T>> &data - | // @param std::vector<std::vector<T>> &data - | ||||
| // @param DataType type - | // @param DataType type - | ||||
| // @param std::shared_ptr<Tensor> *out - | // @param std::shared_ptr<Tensor> *out - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out); | Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out); | ||||
| @@ -217,32 +217,32 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::vector<std::vector<T>> *data - To be completed vector | // @param std::vector<std::vector<T>> *data - To be completed vector | ||||
| // @param size_t max_size - The size of the completed vector | // @param size_t max_size - The size of the completed vector | ||||
| // @param T default_value - Filled default | // @param T default_value - Filled default | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value); | Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value); | ||||
| // Get the default feature of a node | // Get the default feature of a node | ||||
| // @param FeatureType feature_type - | // @param FeatureType feature_type - | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | ||||
| // Get the default feature of an edge | // Get the default feature of an edge | ||||
| // @param FeatureType feature_type - | // @param FeatureType feature_type - | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | ||||
| // Find node object using node id | // Find node object using node id | ||||
| // @param NodeIdType id - | // @param NodeIdType id - | ||||
| // @param std::shared_ptr<Node> *node - Returned node object | // @param std::shared_ptr<Node> *node - Returned node object | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | ||||
| // Find edge object using edge id | // Find edge object using edge id | ||||
| // @param EdgeIdType id - | // @param EdgeIdType id - | ||||
| // @param std::shared_ptr<Edge> *edge - Returned edge object | // @param std::shared_ptr<Edge> *edge - Returned edge object | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge); | Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge); | ||||
| // Negative sampling | // Negative sampling | ||||
| @@ -250,7 +250,7 @@ class GraphDataImpl : public GraphData { | |||||
| // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | ||||
| // @param int32_t samples_num - | // @param int32_t samples_num - | ||||
| // @param std::vector<NodeIdType> *out_samples - Sampling results returned | // @param std::vector<NodeIdType> *out_samples - Sampling results returned | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids, | Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids, | ||||
| size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, | size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, | ||||
| std::vector<NodeIdType> *out_samples); | std::vector<NodeIdType> *out_samples); | ||||
| @@ -43,12 +43,12 @@ class LocalEdge : public Edge { | |||||
| // Get the feature of an edge | // Get the feature of an edge | ||||
| // @param FeatureType feature_type - type of feature | // @param FeatureType feature_type - type of feature | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | ||||
| // Update feature of edge | // Update feature of edge | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | ||||
| private: | private: | ||||
| @@ -40,13 +40,13 @@ class LocalNode : public Node { | |||||
| // Get the feature of a node | // Get the feature of a node | ||||
| // @param FeatureType feature_type - type of feature | // @param FeatureType feature_type - type of feature | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | ||||
| // Get all the neighbors of a node | // Get all the neighbors of a node | ||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | ||||
| bool exclude_itself = false) override; | bool exclude_itself = false) override; | ||||
| @@ -54,18 +54,18 @@ class LocalNode : public Node { | |||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param int32_t samples_num - Number of neighbors to be acquired | // @param int32_t samples_num - Number of neighbors to be acquired | ||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | ||||
| std::vector<NodeIdType> *out_neighbors) override; | std::vector<NodeIdType> *out_neighbors) override; | ||||
| // Add neighbor of node | // Add neighbor of node | ||||
| // @param std::shared_ptr<Node> node - | // @param std::shared_ptr<Node> node - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status AddNeighbor(const std::shared_ptr<Node> &node) override; | Status AddNeighbor(const std::shared_ptr<Node> &node) override; | ||||
| // Update feature of node | // Update feature of node | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | ||||
| private: | private: | ||||
| @@ -49,13 +49,13 @@ class Node { | |||||
| // Get the feature of a node | // Get the feature of a node | ||||
| // @param FeatureType feature_type - type of feature | // @param FeatureType feature_type - type of feature | ||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | // @param std::shared_ptr<Feature> *out_feature - Returned feature | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | ||||
| // Get all the neighbors of a node | // Get all the neighbors of a node | ||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | ||||
| bool exclude_itself = false) = 0; | bool exclude_itself = false) = 0; | ||||
| @@ -63,18 +63,18 @@ class Node { | |||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param int32_t samples_num - Number of neighbors to be acquired | // @param int32_t samples_num - Number of neighbors to be acquired | ||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | ||||
| std::vector<NodeIdType> *out_neighbors) = 0; | std::vector<NodeIdType> *out_neighbors) = 0; | ||||
| // Add neighbor of node | // Add neighbor of node | ||||
| // @param std::shared_ptr<Node> node - | // @param std::shared_ptr<Node> node - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; | virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; | ||||
| // Update feature of node | // Update feature of node | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | ||||
| protected: | protected: | ||||
| @@ -57,6 +57,10 @@ class RepeatNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter | |||||
| /// \return Number of cycles to repeat the execution | |||||
| const int32_t Count() const { return repeat_count_; } | |||||
| /// \brief Base-class override for GetDatasetSize | /// \brief Base-class override for GetDatasetSize | ||||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | /// \param[in] size_getter Shared pointer to DatasetSizeGetter | ||||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | ||||
| @@ -55,6 +55,10 @@ class SkipNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter | |||||
| /// \return Number of rows to skip | |||||
| const int32_t Count() const { return skip_count_; } | |||||
| /// \brief Base-class override for GetDatasetSize | /// \brief Base-class override for GetDatasetSize | ||||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | /// \param[in] size_getter Shared pointer to DatasetSizeGetter | ||||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | ||||
| @@ -55,6 +55,10 @@ class TakeNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Getter | |||||
| /// \return Number of rows to output | |||||
| const int32_t Count() const { return take_count_; } | |||||
| /// \brief Base-class override for GetDatasetSize | /// \brief Base-class override for GetDatasetSize | ||||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | /// \param[in] size_getter Shared pointer to DatasetSizeGetter | ||||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | ||||
| @@ -29,7 +29,7 @@ class TensorOpFusionPass : public NodePass { | |||||
| /// \brief Identifies and fuses tensor ops within MapOp | /// \brief Identifies and fuses tensor ops within MapOp | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -135,7 +135,7 @@ class IRTreePass : public IRPass { | |||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | /// "modified" flag needs to be set to true if tree is modified during the pass execution. | ||||
| /// \param[inout] root_ir The tree to operate on. | /// \param[inout] root_ir The tree to operate on. | ||||
| /// \param[inout] modified Indicate if the tree was modified. | /// \param[inout] modified Indicate if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | ||||
| }; | }; | ||||
| @@ -170,14 +170,14 @@ class IRNodePass : public IRPass { | |||||
| /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution | /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[out] modified Indicator if the node was changed at all | /// \param[out] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | ||||
| /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation | /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation | ||||
| /// "modified" flag needs to be set to true if node is modified during the pass execution | /// "modified" flag needs to be set to true if node is modified during the pass execution | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[out] modified Indicator if the node was changed at all. | /// \param[out] modified Indicator if the node was changed at all. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | ||||
| // Visit()/VisitAfter() method to be overridden. | // Visit()/VisitAfter() method to be overridden. | ||||
| @@ -266,7 +266,7 @@ class TreePass : public Pass { | |||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | /// "modified" flag needs to be set to true if tree is modified during the pass execution. | ||||
| /// \param[inout] tree The tree to operate on. | /// \param[inout] tree The tree to operate on. | ||||
| /// \param[inout] modified Indicate if the tree was modified. | /// \param[inout] modified Indicate if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | ||||
| }; | }; | ||||
| @@ -301,14 +301,14 @@ class NodePass : public Pass { | |||||
| /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution | /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[out] modified Indicator if the node was changed at all | /// \param[out] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | ||||
| /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation | /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation | ||||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution | /// "modified" flag needs to be set to true if tree is modified during the pass execution | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[out] modified Indicator if the node was changed at all. | /// \param[out] modified Indicator if the node was changed at all. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | ||||
| // Visit methods to be overridden. | // Visit methods to be overridden. | ||||
| @@ -41,80 +41,80 @@ class RepeatPass : public NodePass { | |||||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree below this node as being in a cache merge path | /// \brief Identifies the subtree below this node as being in a cache merge path | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree below this node as being cached | /// \brief Identifies the subtree below this node as being cached | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Hooks up any identified eoe nodes under this repeat. | /// \brief Hooks up any identified eoe nodes under this repeat. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | ||||
| /// \brief Hooks up any identified eoe nodes under this repeat. | /// \brief Hooks up any identified eoe nodes under this repeat. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | ||||
| /// \brief CacheOp removes previous leaf ops and replaces them with itself | /// \brief CacheOp removes previous leaf ops and replaces them with itself | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Turns off the tracking for operations under merge op | /// \brief Turns off the tracking for operations under merge op | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | ||||
| /// \brief Saves the lookup op in case it needs to be referenced by a repeat | /// \brief Saves the lookup op in case it needs to be referenced by a repeat | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override; | ||||
| /// \brief Set the epoch count for DeviceQueue | /// \brief Set the epoch count for DeviceQueue | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | ||||
| /// \brief Special case for GeneratorOp | /// \brief Special case for GeneratorOp | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | ||||
| /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | ||||
| /// for use with a controlling repeat above it. | /// for use with a controlling repeat above it. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | ||||
| private: | private: | ||||
| /// \brief Adds an operator to the eoe operator stack save area | /// \brief Adds an operator to the eoe operator stack save area | ||||
| /// \param dataset_op - The dataset op to add to the eoe stack | /// \param dataset_op - The dataset op to add to the eoe stack | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | ||||
| /// \brief Pops an operator from the eoe operator stack save area | /// \brief Pops an operator from the eoe operator stack save area | ||||
| @@ -127,7 +127,7 @@ class RepeatPass : public NodePass { | |||||
| /// \brief Adds an operator to the cached operator stack save area | /// \brief Adds an operator to the cached operator stack save area | ||||
| /// \param dataset_op - The dataset op to add to the cached stack | /// \param dataset_op - The dataset op to add to the cached stack | ||||
| /// \return Status - The error code return | |||||
| /// \return Status The status code returned | |||||
| void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); | void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); | ||||
| /// \brief Pops an operator from the cached operator stack save area | /// \brief Pops an operator from the cached operator stack save area | ||||
| @@ -38,123 +38,123 @@ class CacheErrorPass : public NodePass { | |||||
| /// \brief Identifies the subtree below this node as being cached | /// \brief Identifies the subtree below this node as being cached | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if ZipOp exists under a cache | /// \brief Returns an error if ZipOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if ConcatOp exists under a cache | /// \brief Returns an error if ConcatOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if TakeOp exists under a cache | /// \brief Returns an error if TakeOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if SkipOp exists under a cache | /// \brief Returns an error if SkipOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | ||||
| /// \brief Returns an error if BatchOp exists under a cache | /// \brief Returns an error if BatchOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| /// \brief Returns an error if FilterOp exists under a cache | /// \brief Returns an error if FilterOp exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | ||||
| /// \brief Identifies the leaf dataset as being mappable | /// \brief Identifies the leaf dataset as being mappable | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree above this node as not being cached | /// \brief Identifies the subtree above this node as not being cached | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Identifies and blocks repeat under cache scenarios | /// \brief Identifies and blocks repeat under cache scenarios | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | ||||
| private: | private: | ||||
| @@ -48,14 +48,14 @@ class CacheTransformPass : public TreePass { | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | /// \brief Identifies the subtree below this node as a cached descendant tree. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that | /// \brief Resets the tracking of the cache within the tree and assigns the operators that | ||||
| /// will be involved in a cache transformation | /// will be involved in a cache transformation | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -63,95 +63,95 @@ class CacheTransformPass : public TreePass { | |||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Perform leaf node cache transform identifications | /// \brief Perform leaf node cache transform identifications | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| @@ -161,12 +161,12 @@ class CacheTransformPass : public TreePass { | |||||
| private: | private: | ||||
| /// \brief Common code for mappable leaf setup. | /// \brief Common code for mappable leaf setup. | ||||
| /// \param[in] leaf_op The leaf node performing setup work. | /// \param[in] leaf_op The leaf node performing setup work. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | ||||
| /// \brief Common code for non-mappable leaf setup. | /// \brief Common code for non-mappable leaf setup. | ||||
| /// \param[in] leaf_op The leaf node performing setup work. | /// \param[in] leaf_op The leaf node performing setup work. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | ||||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | ||||
| @@ -191,7 +191,7 @@ class CacheTransformPass : public TreePass { | |||||
| /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | ||||
| /// \param[inout] tree The tree to operate on. | /// \param[inout] tree The tree to operate on. | ||||
| /// \param[inout] modified Indicates if the tree was modified. | /// \param[inout] modified Indicates if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| private: | private: | ||||
| @@ -212,7 +212,7 @@ class CacheTransformPass : public TreePass { | |||||
| /// \param[in] leaf_op The leaf node in the transform | /// \param[in] leaf_op The leaf node in the transform | ||||
| /// \param[in] cache_op The cache op in the transform (will get removed) | /// \param[in] cache_op The cache op in the transform (will get removed) | ||||
| /// \param[in] cache_client The cache client | /// \param[in] cache_client The cache client | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | ||||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | ||||
| }; | }; | ||||
| @@ -38,61 +38,61 @@ class CacheValidationPass : public IRNodePass { | |||||
| /// \brief Returns an error if BatchNode exists under a cache | /// \brief Returns an error if BatchNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; | Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if ConcatNode exists under a cache | /// \brief Returns an error if ConcatNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override; | Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if FilterNode exists under a cache | /// \brief Returns an error if FilterNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override; | Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if SkipNode exists under a cache | /// \brief Returns an error if SkipNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if TakeNode exists under a cache | /// \brief Returns an error if TakeNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if ZipNode exists under a cache | /// \brief Returns an error if ZipNode exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override; | Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache | /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; | Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; | ||||
| /// \brief Returns an error if there is a cache over another cache | /// \brief Returns an error if there is a cache over another cache | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | ||||
| /// \brief Identifies and blocks repeat under cache scenarios | /// \brief Identifies and blocks repeat under cache scenarios | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override; | Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override; | ||||
| /// \brief Identifies the subtree above this node as not being cached | /// \brief Identifies the subtree above this node as not being cached | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | ||||
| private: | private: | ||||
| @@ -45,27 +45,27 @@ class EpochCtrlPass : public IRTreePass { | |||||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<RootNode> node, bool *modified) override; | Status Visit(std::shared_ptr<RootNode> node, bool *modified) override; | ||||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override; | Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. | /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override; | Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Register the TransferNode for further action. | /// \brief Register the TransferNode for further action. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override; | Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override; | ||||
| /// \brief Getter | /// \brief Getter | ||||
| @@ -89,7 +89,7 @@ class EpochCtrlPass : public IRTreePass { | |||||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | ||||
| /// \param[inout] root_ir The tree to operate on. | /// \param[inout] root_ir The tree to operate on. | ||||
| /// \param[inout] modified Indicates if the tree was modified. | /// \param[inout] modified Indicates if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -46,20 +46,20 @@ class EpochInjectionPass : public TreePass { | |||||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override; | ||||
| /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. | /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Register the DeviceQueueOp for further action. | /// \brief Register the DeviceQueueOp for further action. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | ||||
| /// \brief Getter | /// \brief Getter | ||||
| @@ -79,7 +79,7 @@ class EpochInjectionPass : public TreePass { | |||||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | ||||
| /// \param[inout] tree The tree to operate on. | /// \param[inout] tree The tree to operate on. | ||||
| /// \param[inout] modified Indicates if the tree was modified. | /// \param[inout] modified Indicates if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -17,7 +17,10 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" | #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -47,7 +50,16 @@ Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> no | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Perform ShuffleOp removal check. | |||||
| // Perform RepeatNode removal check: clears *modified and, when the repeat count is 1, appends the node to nodes_to_remove_. | |||||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | |||||
| *modified = false; | |||||
| if (node->Count() == 1) { | |||||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform ShuffleNode removal check. | |||||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | ||||
| *modified = false; | *modified = false; | ||||
| #if 0 | #if 0 | ||||
| @@ -60,6 +72,24 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, b | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Perform SkipNode removal check: clears *modified and, when the skip count is 0, appends the node to nodes_to_remove_. | |||||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||||
| *modified = false; | |||||
| if (node->Count() == 0) { | |||||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Perform TakeNode removal check: clears *modified and, when the take count is -1 (take everything), appends the node to nodes_to_remove_. | |||||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||||
| *modified = false; | |||||
| if (node->Count() == -1) { | |||||
| nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // constructor | // constructor | ||||
| NodeRemovalPass::NodeRemovalPass() {} | NodeRemovalPass::NodeRemovalPass() {} | ||||
| @@ -45,21 +45,39 @@ class NodeRemovalPass : public IRTreePass { | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | /// \brief Identifies the subtree below this node as a cached descendant tree. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | ||||
| /// \brief Resets the tracking of the cache within the tree | /// \brief Resets the tracking of the cache within the tree | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | ||||
| /// \brief Perform RepeatNode removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<RepeatNode> node, bool *modified) override; | |||||
| /// \brief Perform ShuffleNode removal check | /// \brief Perform ShuffleNode removal check | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override; | Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override; | ||||
| /// \brief Perform SkipNode removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | |||||
| /// \brief Perform TakeNode removal check | |||||
| /// \param[in] node The node being visited | |||||
| /// \param[inout] modified Indicator if the node was changed at all | |||||
| /// \return Status The status code returned | |||||
| Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | |||||
| /// \brief Getter | /// \brief Getter | ||||
| /// \return All the nodes to be removed | /// \return All the nodes to be removed | ||||
| std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } | std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } | ||||
| @@ -79,7 +97,7 @@ class NodeRemovalPass : public IRTreePass { | |||||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | ||||
| /// \param[inout] root_ir The tree to operate on. | /// \param[inout] root_ir The tree to operate on. | ||||
| /// \param[inout] modified Indicates if the tree was modified. | /// \param[inout] modified Indicates if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -46,20 +46,20 @@ class RemovalPass : public TreePass { | |||||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | /// \brief Identifies the subtree below this node as a cached descendant tree. | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| /// \brief Resets the tracking of the cache within the tree | /// \brief Resets the tracking of the cache within the tree | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Perform ShuffleOp removal check | /// \brief Perform ShuffleOp removal check | ||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | ||||
| /// \brief Getter | /// \brief Getter | ||||
| @@ -81,7 +81,7 @@ class RemovalPass : public TreePass { | |||||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | ||||
| /// \param[inout] tree The tree to operate on. | /// \param[inout] tree The tree to operate on. | ||||
| /// \param[inout] modified Indicates if the tree was modified. | /// \param[inout] modified Indicates if the tree was modified. | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | Status RunOnTree(ExecutionTree *tree, bool *modified) override; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -53,7 +53,7 @@ class ConnectorSize : public Sampling { | |||||
| std::string Name() const override { return kConnectorSizeSamplingName; } | std::string Name() const override { return kConnectorSizeSamplingName; } | ||||
| // Save sampling data to file | // Save sampling data to file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SaveToFile() override; | Status SaveToFile() override; | ||||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | Status Init(const std::string &dir_path, const std::string &device_id) override; | ||||
| @@ -65,7 +65,7 @@ class ConnectorThroughput : public Sampling { | |||||
| std::string Name() const override { return name_; }; | std::string Name() const override { return name_; }; | ||||
| // Save sampling data to file | // Save sampling data to file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SaveToFile() override; | Status SaveToFile() override; | ||||
| Status Init(const std::string &dir_path, const std::string &device_id); | Status Init(const std::string &dir_path, const std::string &device_id); | ||||
| @@ -32,13 +32,13 @@ class DatasetIteratorTracing : public Tracing { | |||||
| ~DatasetIteratorTracing() override = default; | ~DatasetIteratorTracing() override = default; | ||||
| // Record tracing data | // Record tracing data | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | ||||
| std::string Name() const override { return kDatasetIteratorTracingName; }; | std::string Name() const override { return kDatasetIteratorTracingName; }; | ||||
| // Save tracing data to file | // Save tracing data to file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SaveToFile() override; | Status SaveToFile() override; | ||||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | Status Init(const std::string &dir_path, const std::string &device_id) override; | ||||
| @@ -32,13 +32,13 @@ class DeviceQueueTracing : public Tracing { | |||||
| ~DeviceQueueTracing() override = default; | ~DeviceQueueTracing() override = default; | ||||
| // Record tracing data | // Record tracing data | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | ||||
| std::string Name() const override { return kDeviceQueueTracingName; }; | std::string Name() const override { return kDeviceQueueTracingName; }; | ||||
| // Save tracing data to file | // Save tracing data to file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SaveToFile() override; | Status SaveToFile() override; | ||||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | Status Init(const std::string &dir_path, const std::string &device_id) override; | ||||
| @@ -87,19 +87,19 @@ class ProfilingManager { | |||||
| Status Initialize(); | Status Initialize(); | ||||
| // Save profile data to file | // Save profile data to file | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status SaveProfilingData(); | Status SaveProfilingData(); | ||||
| // Sampling node getter | // Sampling node getter | ||||
| // @param name - The name of the requested node | // @param name - The name of the requested node | ||||
| // @param node - Pointer to the shared pointer for the Sampling node | // @param node - Pointer to the shared pointer for the Sampling node | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node); | Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node); | ||||
| // Tracing node getter | // Tracing node getter | ||||
| // @param name - The name of the requested node | // @param name - The name of the requested node | ||||
| // @param node - Pointer to the shared pointer for the Tracing node | // @param node - Pointer to the shared pointer for the Tracing node | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node); | Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node); | ||||
| // If profiling is enabled. | // If profiling is enabled. | ||||
| @@ -120,12 +120,12 @@ class ProfilingManager { | |||||
| // Register profile node to tree | // Register profile node to tree | ||||
| // @param node - Profiling node | // @param node - Profiling node | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status RegisterTracingNode(std::shared_ptr<Tracing> node); | Status RegisterTracingNode(std::shared_ptr<Tracing> node); | ||||
| // Register profile node to tree | // Register profile node to tree | ||||
| // @param node - Profiling node | // @param node - Profiling node | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status RegisterSamplingNode(std::shared_ptr<Sampling> node); | Status RegisterSamplingNode(std::shared_ptr<Sampling> node); | ||||
| ExecutionTree *tree_ = nullptr; // ExecutionTree pointer | ExecutionTree *tree_ = nullptr; // ExecutionTree pointer | ||||
| @@ -442,7 +442,7 @@ class SchemaObj { | |||||
| Status parse_column(nlohmann::json columns); | Status parse_column(nlohmann::json columns); | ||||
| /// \brief Get schema file from JSON file | /// \brief Get schema file from JSON file | ||||
| /// \param[in] json_obj Object of JSON parsed. | |||||
| /// \param[in] json_obj parsed JSON object | |||||
| /// \return Status code | /// \return Status code | ||||
| Status from_json(nlohmann::json json_obj); | Status from_json(nlohmann::json json_obj); | ||||
| @@ -77,7 +77,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o | |||||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | // @param std::shared_ptr<Tensor> *dst - return tensor padded | ||||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | // @param std::vector<dsize_t> pad_shape - shape to pad to | ||||
| // @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format, | // @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format, | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | ||||
| const std::shared_ptr<Tensor> &pad_val); | const std::shared_ptr<Tensor> &pad_val); | ||||
| @@ -86,7 +86,7 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | |||||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | // @param std::shared_ptr<Tensor> *dst - return tensor padded | ||||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | // @param std::vector<dsize_t> pad_shape - shape to pad to | ||||
| // @param float pad_val - value to pad with | // @param float pad_val - value to pad with | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | ||||
| const std::vector<dsize_t> &pad_shape, float pad_val); | const std::vector<dsize_t> &pad_shape, float pad_val); | ||||
| @@ -98,7 +98,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||||
| // @param std::vector<dsize_t> cur_ind - recursion helper | // @param std::vector<dsize_t> cur_ind - recursion helper | ||||
| // @param T pad_val - value to pad tensor with | // @param T pad_val - value to pad tensor with | ||||
| // @param size_t cur_dim - recursion helper | // @param size_t cur_dim - recursion helper | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | ||||
| std::vector<dsize_t> cur_ind, size_t cur_dim = 0); | std::vector<dsize_t> cur_ind, size_t cur_dim = 0); | ||||
| @@ -107,7 +107,7 @@ Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<T | |||||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | // @param std::shared_ptr<Tensor> *dst - return tensor padded | ||||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | // @param std::vector<dsize_t> pad_shape - shape to pad to | ||||
| // @param std::string pad_val - value to pad with | // @param std::string pad_val - value to pad with | ||||
| // @return - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | ||||
| const std::vector<dsize_t> &pad_shape, const std::string &pad_val); | const std::vector<dsize_t> &pad_shape, const std::string &pad_val); | ||||
| @@ -119,7 +119,7 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||||
| // @param std::vector<dsize_t> cur_ind - recursion helper | // @param std::vector<dsize_t> cur_ind - recursion helper | ||||
| // @param std::string pad_val - value to pad tensor with | // @param std::string pad_val - value to pad tensor with | ||||
| // @param size_t cur_dim - recursion helper | // @param size_t cur_dim - recursion helper | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst, | Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst, | ||||
| const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim, | const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim, | ||||
| const std::string &pad_value); | const std::string &pad_value); | ||||
| @@ -36,7 +36,7 @@ class ToFloat16Op : public TensorOp { | |||||
| // Overrides the base class compute function | // Overrides the base class compute function | ||||
| // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor | // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor | ||||
| // and transforms its data to float16, the output memory is manipulated to contain the result | // and transforms its data to float16, the output memory is manipulated to contain the result | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | ||||
| @@ -58,7 +58,7 @@ class CutOutOp : public TensorOp { | |||||
| // Overrides the base class compute function | // Overrides the base class compute function | ||||
| // Calls the erase function in ImageUtils, this function takes an input tensor | // Calls the erase function in ImageUtils, this function takes an input tensor | ||||
| // and overwrites some of its data using openCV, the output memory is manipulated to contain the result | // and overwrites some of its data using openCV, the output memory is manipulated to contain the result | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| std::string Name() const override { return kCutOutOp; } | std::string Name() const override { return kCutOutOp; } | ||||
| @@ -50,7 +50,7 @@ class RandomColorAdjustOp : public TensorOp { | |||||
| // Overrides the base class compute function. | // Overrides the base class compute function. | ||||
| // Calls multiple transform functions in ImageUtils, this function takes an input tensor. | // Calls multiple transform functions in ImageUtils, this function takes an input tensor. | ||||
| // and transforms its data using openCV, the output memory is manipulated to contain the result. | // and transforms its data using openCV, the output memory is manipulated to contain the result. | ||||
| // @return Status - The error code return. | |||||
| // @return Status The status code returned. | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| std::string Name() const override { return kRandomColorAdjustOp; } | std::string Name() const override { return kRandomColorAdjustOp; } | ||||
| @@ -61,7 +61,7 @@ class RandomRotationOp : public TensorOp { | |||||
| // Overrides the base class compute function | // Overrides the base class compute function | ||||
| // Calls the rotate function in ImageUtils, this function takes an input tensor | // Calls the rotate function in ImageUtils, this function takes an input tensor | ||||
| // and transforms its data using openCV, the output memory is manipulated to contain the result | // and transforms its data using openCV, the output memory is manipulated to contain the result | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | ||||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | ||||
| @@ -43,7 +43,7 @@ class UniformAugOp : public TensorOp { | |||||
| void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; } | void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; } | ||||
| // Overrides the base class compute function | // Overrides the base class compute function | ||||
| // @return Status - The error code return | |||||
| // @return Status The status code returned | |||||
| Status Compute(const TensorRow &input, TensorRow *output) override; | Status Compute(const TensorRow &input, TensorRow *output) override; | ||||
| std::string Name() const override { return kUniformAugOp; } | std::string Name() const override { return kUniformAugOp; } | ||||
| @@ -53,7 +53,7 @@ class DataHelper { | |||||
| /// \param key Key of field to write to | /// \param key Key of field to write to | ||||
| /// \param value Value array to write to file | /// \param value Value array to write to file | ||||
| /// \param out_file Optional input for output file path, will write to input file if not specified | /// \param out_file Optional input for output file path, will write to input file if not specified | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value, | Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value, | ||||
| const std::string &out_file = ""); | const std::string &out_file = ""); | ||||
| @@ -62,7 +62,7 @@ class DataHelper { | |||||
| /// \param key Key of field to write to | /// \param key Key of field to write to | ||||
| /// \param value Value array to write to file | /// \param value Value array to write to file | ||||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | /// \param out_file Optional parameter for output file path, will write to input file if not specified | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value, | Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value, | ||||
| const std::string &out_file = "") { | const std::string &out_file = "") { | ||||
| @@ -99,7 +99,7 @@ class DataHelper { | |||||
| /// \param key Key of field to write to | /// \param key Key of field to write to | ||||
| /// \param value Value to write to file | /// \param value Value to write to file | ||||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | /// \param out_file Optional parameter for output file path, will write to input file if not specified | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status UpdateValue(const std::string &in_file, const std::string &key, const T &value, | Status UpdateValue(const std::string &in_file, const std::string &key, const T &value, | ||||
| const std::string &out_file = "") { | const std::string &out_file = "") { | ||||
| @@ -134,7 +134,7 @@ class DataHelper { | |||||
| /// \brief Template function to write tensor to file | /// \brief Template function to write tensor to file | ||||
| /// \param[in] in_file File to write to | /// \param[in] in_file File to write to | ||||
| /// \param[in] data Array of type T values | /// \param[in] data Array of type T values | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) { | Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) { | ||||
| try { | try { | ||||
| @@ -157,7 +157,7 @@ class DataHelper { | |||||
| /// \param[in] in_file File name to write to | /// \param[in] in_file File name to write to | ||||
| /// \param[in] data Pointer to data | /// \param[in] data Pointer to data | ||||
| /// \param[in] length Length of values to write from pointer | /// \param[in] length Length of values to write from pointer | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| template <typename T> | template <typename T> | ||||
| Status WriteBinFile(const std::string &in_file, T *data, size_t length) { | Status WriteBinFile(const std::string &in_file, T *data, size_t length) { | ||||
| try { | try { | ||||
| @@ -188,7 +188,7 @@ class DataHelper { | |||||
| /// \note This function will return okay even if the key is not found | /// \note This function will return okay even if the key is not found | ||||
| /// \param[in] in_file Json file to remove key from | /// \param[in] in_file Json file to remove key from | ||||
| /// \param[in] key The key to remove | /// \param[in] key The key to remove | ||||
| /// \return Status The error code return | |||||
| /// \return Status The status code returned | |||||
| Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = ""); | Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = ""); | ||||
| /// \brief A print method typically used for debugging | /// \brief A print method typically used for debugging | ||||
| @@ -669,8 +669,6 @@ class Dataset: | |||||
| >>> repeat_and_shuffle = data.repeat(50) | >>> repeat_and_shuffle = data.repeat(50) | ||||
| >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) | >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) | ||||
| """ | """ | ||||
| if count == 1: | |||||
| return self | |||||
| return RepeatDataset(self, count) | return RepeatDataset(self, count) | ||||
| @check_skip | @check_skip | ||||
| @@ -717,8 +715,6 @@ class Dataset: | |||||
| >>> # Create a dataset where the dataset includes 50 elements. | >>> # Create a dataset where the dataset includes 50 elements. | ||||
| >>> data = data.take(50) | >>> data = data.take(50) | ||||
| """ | """ | ||||
| if count == -1: | |||||
| return self | |||||
| return TakeDataset(self, count) | return TakeDataset(self, count) | ||||
| def _get_absolute_split_sizes(self, sizes): | def _get_absolute_split_sizes(self, sizes): | ||||
| @@ -1311,6 +1311,51 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestSkipTakeRepeat) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipTakeRepeat."; | |||||
| // Check that Skip(0), Take(-1) and Repeat(1) in one pipeline still yield every row. Start from an ImageFolder Dataset | |||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 6)); | |||||
| // Create a Skip operation on ds with a count of 0 | |||||
| int32_t count = 0; | |||||
| ds = ds->Skip(count); | |||||
| // Create a Project operation on ds that keeps only the "image" column | |||||
| std::vector<std::string> column_project = {"image"}; | |||||
| ds = ds->Project(column_project); | |||||
| // Add a Take(-1); the row-count check below expects it to leave all rows in place | |||||
| ds = ds->Take(-1); | |||||
| // Add a Repeat(1); the row-count check below expects a single pass over the data | |||||
| ds = ds->Repeat(1); | |||||
| // Create an iterator over the result of the above dataset | |||||
| // This will trigger the creation of the Execution Tree and launch it. | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| // Iterate over the dataset and get each row; an empty row ends the loop | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | |||||
| uint64_t i = 0; | |||||
| while (row.size() != 0) { | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | |||||
| MS_LOG(INFO) << "Number of rows: " << i; | |||||
| // Expect 6 rows, matching the 6 passed to RandomSampler above (presumably its sample count) | |||||
| EXPECT_EQ(i, 6); | |||||
| // Manually terminate the pipeline | |||||
| iter->Stop(); | |||||
| } | |||||
| TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) { | TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize."; | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -163,7 +163,7 @@ def test_take_08(): | |||||
| def test_take_09(): | def test_take_09(): | ||||
| """ | """ | ||||
| Test take: repeat count is -1, and read the whole dataset, take after repeat | |||||
| Test take: take count is -1, and read the whole dataset, take after repeat | |||||
| """ | """ | ||||
| logger.info("test_take_09") | logger.info("test_take_09") | ||||
| data1 = ds.GeneratorDataset(generator, ["data"]) | data1 = ds.GeneratorDataset(generator, ["data"]) | ||||
| @@ -180,7 +180,7 @@ def test_take_09(): | |||||
| def test_take_10(): | def test_take_10(): | ||||
| """ | """ | ||||
| Test take: repeat count is -1, and read the whole dataset, take before repeat | |||||
| Test take: take count is -1, and read the whole dataset, take before repeat | |||||
| """ | """ | ||||
| logger.info("test_take_10") | logger.info("test_take_10") | ||||
| data1 = ds.GeneratorDataset(generator, ["data"]) | data1 = ds.GeneratorDataset(generator, ["data"]) | ||||
| @@ -341,6 +341,18 @@ def test_take_18(): | |||||
| assert sum([1 for _ in data1]) == 2 | assert sum([1 for _ in data1]) == 2 | ||||
| def test_take_19(): | |||||
| """ | |||||
| Test take: take count is 0 after batch; this is invalid and must raise a ValueError mentioning "positive integer" | |||||
| """ | |||||
| logger.info("test_take_19") | |||||
| with pytest.raises(ValueError) as info: | |||||
| data1 = ds.GeneratorDataset(generator, ["data"]) | |||||
| data1 = data1.batch(2) | |||||
| data1 = data1.take(0) | |||||
| assert "positive integer" in str(info.value) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_take_01() | test_take_01() | ||||
| test_take_02() | test_take_02() | ||||
| @@ -360,4 +372,5 @@ if __name__ == '__main__': | |||||
| test_take_16() | test_take_16() | ||||
| test_take_17() | test_take_17() | ||||
| test_take_18() | test_take_18() | ||||
| test_take_19() | |||||
| logger.info('== test take operation finished ==') | logger.info('== test take operation finished ==') | ||||