From: @nsyca Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -490,12 +490,6 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s | |||
| #endif | |||
| RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) { | |||
| // Workaround for repeat == 1, do not inject repeat. | |||
| if (count == 1) { | |||
| ir_node_ = input->IRNode(); | |||
| return; | |||
| } | |||
| auto ds = std::make_shared<RepeatNode>(input->IRNode(), count); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| @@ -516,13 +510,6 @@ SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) { | |||
| } | |||
| TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) { | |||
| // If count is greater than the number of element in dataset or equal to -1, | |||
| // all the element in dataset will be taken | |||
| if (count == -1) { | |||
| ir_node_ = input->IRNode(); | |||
| return; | |||
| } | |||
| auto ds = std::make_shared<TakeNode>(input->IRNode(), count); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| @@ -66,7 +66,7 @@ class ColDescriptor { | |||
| /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. | |||
| /// \param[in] num_elements - The number of elements in the data for a Tensor | |||
| /// \param[inout] out_shape - The materialized output Tensor shape | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; | |||
| /// \brief << Stream output operator overload | |||
| @@ -124,13 +124,13 @@ class DataSchema { | |||
| /// \brief Parses a schema json file and populates the columns and meta info. | |||
| /// \param[in] schema_file_path - the schema file that has the column's info to load | |||
| /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load); | |||
| /// \brief Parses a schema JSON string and populates the columns and meta info. | |||
| /// \param[in] schema_json_string - the schema file that has the column's info to load | |||
| /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load); | |||
| /// \brief A print method typically used for debugging | |||
| @@ -148,7 +148,7 @@ class DataSchema { | |||
| /// \brief Adds a column descriptor to the schema | |||
| /// \param[in] cd - The ColDescriptor to add | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status AddColumn(const ColDescriptor &cd); | |||
| /// \brief getter | |||
| @@ -169,7 +169,7 @@ class DataSchema { | |||
| /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. | |||
| /// \param[inout] out_column_name_map - The output map of columns names to column index | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map); | |||
| private: | |||
| @@ -177,7 +177,7 @@ class DataSchema { | |||
| /// does not follow any particular order (json standard does not enforce any ordering protocol). | |||
| /// This one produces a schema that contains all of the columns from the schema file. | |||
| /// \param[in] column_tree - The nlohmann tree from the json file to parse | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status AnyOrderLoad(nlohmann::json column_tree); | |||
| /// \brief Internal helper function. For each input column name, perform a lookup to the json document to | |||
| @@ -185,18 +185,18 @@ class DataSchema { | |||
| /// descriptor and add to the schema in the order in which the input column names are given. | |||
| /// \param[in] column_tree - The nlohmann tree from the json file to parse | |||
| /// \param[in] columns_to_load - list of strings for the columns to add to the schema | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load); | |||
| /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. | |||
| /// \param[in] columnTree - The nlohmann child tree for a given column to load. | |||
| /// \param[in] col_name - The string name of the column for that subtree. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); | |||
| /// \brief Internal helper function. Performs sanity checks on the json file setup. | |||
| /// \param[in] js - The nlohmann tree for the schema file | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status PreLoadExceptionCheck(const nlohmann::json &js); | |||
| std::vector<ColDescriptor> col_descs_; // Vector of column descriptors | |||
| @@ -53,7 +53,7 @@ class IteratorBase { | |||
| // functionality exists in the derived versions of this function. | |||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data | |||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| // @note The position of a Tensor/column might be different from the initial column order | |||
| // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change | |||
| // the column ordering. | |||
| @@ -97,17 +97,17 @@ class DatasetIterator : public IteratorBase { | |||
| // from the tree root node directly. | |||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data | |||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status FetchNextTensorRow(TensorRow *out_row) override; | |||
| // Fetches the next tensor row into device row, and returns its shape. | |||
| // @param out_shapes - A vector of tensor shapes (one shape per column) | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetOutputShapes(std::vector<TensorShape> *out_shapes); | |||
| // Fetches the next tensor row into device row, and returns its type. | |||
| // @param out_types - A vector of tensor types (one type per column) | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetOutputTypes(std::vector<DataType> *out_types); | |||
| // Getter | |||
| @@ -140,12 +140,12 @@ class ChildIterator : public IteratorBase { | |||
| // only from the child/worker id as given from the constructor. | |||
| // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data | |||
| // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status FetchNextTensorRow(TensorRow *out_row) override; | |||
| // This function drains buffer until next eoe has been received. | |||
| // It will be a no-op if the previous row returned is empty. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Drain(); | |||
| // Getter | |||
| @@ -134,7 +134,7 @@ class BarrierOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Handles preprocessing of the main loop, used when starting new epoch | |||
| @@ -112,12 +112,12 @@ class BatchOp : public ParallelOp { | |||
| #endif | |||
| // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<BatchOp> *); | |||
| private: | |||
| // Sanity check for builder class args | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| bool builder_drop_; | |||
| @@ -167,11 +167,11 @@ class BatchOp : public ParallelOp { | |||
| ~BatchOp() {} | |||
| // @param int32_t workerId | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status EofReceived(int32_t) override; | |||
| // @param int32_t workerId | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status EoeReceived(int32_t) override; | |||
| // A print method typically used for debugging | |||
| @@ -190,7 +190,7 @@ class BatchOp : public ParallelOp { | |||
| } | |||
| // Main loop of batch | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -214,14 +214,14 @@ class BatchOp : public ParallelOp { | |||
| // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows | |||
| // @param int32_t size - batch_size | |||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, | |||
| dsize_t batch_size); | |||
| // @param table | |||
| // @param const PadInfo &pad_info pad info | |||
| // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map); | |||
| @@ -233,18 +233,18 @@ class BatchOp : public ParallelOp { | |||
| private: | |||
| // Worker thread for doing the memcpy of batch | |||
| // @param int32_t param workerId | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Generate buffer with batched tensors | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, | |||
| std::unique_ptr<DataBuffer> *db); | |||
| #ifdef ENABLE_PYTHON | |||
| // Function that calls pyfunc to perform map on batch | |||
| // @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair); | |||
| #endif | |||
| @@ -253,7 +253,7 @@ class BatchOp : public ParallelOp { | |||
| // @param std::set<int32_t> *cols, col ids to perform pad on | |||
| // @param std::vector<float> *vals, default padding value for each column | |||
| // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| static Status UnpackPadInfo(const PadInfo &pad_info, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||
| std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, | |||
| @@ -264,20 +264,20 @@ class BatchOp : public ParallelOp { | |||
| int32_t num_consumers() const override { return 1; } | |||
| // get the batch size for next batch | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetBatchSize(int32_t *batch_size, CBatchInfo info); | |||
| // Do the initialization of all queues then start all worker threads | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| #ifdef ENABLE_PYTHON | |||
| // Invoke batch size function with current BatchInfo to generate batch size. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); | |||
| // Invoke batch map function with current BatchInfo to generate tensors to batch. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); | |||
| #endif | |||
| @@ -107,7 +107,7 @@ class BucketBatchByLengthOp : public PipelineOp { | |||
| // Might need to batch remaining buckets after receiving eoe, so override this method. | |||
| // @param int32_t workerId | |||
| // @return Status - The error code returned | |||
| // @return Status The status code returned | |||
| Status EoeReceived(int32_t) override; | |||
| std::string Name() const override { return kBucketBatchByLengthOp; } | |||
| @@ -123,7 +123,7 @@ class BucketBatchByLengthOp : public PipelineOp { | |||
| } | |||
| // Main loop of batch | |||
| // @return Status - The error code returned | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| private: | |||
| @@ -104,7 +104,7 @@ class BuildSentencePieceVocabOp : public PipelineOp { | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<BuildSentencePieceVocabOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op); | |||
| private: | |||
| @@ -110,7 +110,7 @@ class BuildVocabOp : public ParallelOp { | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<BuildVocabOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<BuildVocabOp> *op); | |||
| private: | |||
| @@ -53,7 +53,7 @@ class CacheBase : public ParallelOp { | |||
| /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| /// info from its previous execution and then initializes itself so that it can be executed | |||
| /// again. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status Reset() override; | |||
| /// \brief A print method typically used for debugging | |||
| @@ -80,7 +80,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT { | |||
| std::shared_ptr<SamplerRT> build_sampler_; | |||
| // Check if the required parameters are set by the builder. | |||
| // \return Status The error code return | |||
| // \return Status The status code returned | |||
| Status SanityCheck() const; | |||
| }; | |||
| /// \brief Constructor | |||
| @@ -136,7 +136,7 @@ class CacheMergeOp : public ParallelOp { | |||
| std::shared_ptr<SamplerRT> build_sampler_; | |||
| /// Check if the required parameters are set by the builder. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status SanityCheck() const; | |||
| }; | |||
| @@ -189,7 +189,7 @@ class CacheMergeOp : public ParallelOp { | |||
| /// \brief Base-class override for handling cases when an eof is received. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status EofReceived(int32_t worker_id) override; | |||
| protected: | |||
| @@ -99,7 +99,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||
| std::shared_ptr<SamplerRT> build_sampler_; | |||
| /// \brief Check if the required parameters are set by the builder. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status SanityCheck() const; | |||
| }; | |||
| @@ -119,7 +119,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||
| /// \brief Base-class override for special eoe handler. | |||
| /// CacheOp must override this because it shall not perform default handling of eoe. Instead | |||
| /// the CacheOp manages actions related to the end of the epoch. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| /// \param[in] p The node to visit | |||
| @@ -133,7 +133,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||
| Status Accept(NodePass *p, bool *modified) override; | |||
| /// \brief Base-class override for handling cases when an eof is received. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status EofReceived(int32_t worker_id) override; | |||
| Status operator()() override; | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| @@ -159,7 +159,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||
| Status CacheAllRows(int32_t worker_id); | |||
| Status RegisterResources() override; | |||
| /// \brief Private function for cache setup/init work just after construction | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status InitCache(); | |||
| }; | |||
| } // namespace dataset | |||
| @@ -94,7 +94,7 @@ class ConcatOp : public PipelineOp { | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Op name getter | |||
| @@ -146,14 +146,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// DatasetOps operate by launching a thread (see ExecutionTree). | |||
| /// This pure virtual version makes the requirement that derived classes must provide a functor | |||
| /// that will execute their main runtime loop code. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status operator()() = 0; | |||
| /// \brief Gets the next buffer from the given child | |||
| /// \notes See GetNextInput for similar function that has built-in message handling | |||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) { | |||
| return GetNextBuffer(p_buffer, worker_id, false); | |||
| } | |||
| @@ -161,7 +161,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \brief Gets the next buffer from the given child | |||
| /// \notes See GetNextInput for similar function that has built-in message handling | |||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } | |||
| /// \brief Gets the next buffer from the given child | |||
| @@ -169,7 +169,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | |||
| /// \param worker_id - The worker id | |||
| /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe); | |||
| /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof | |||
| @@ -177,7 +177,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// those messages are received. | |||
| /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); | |||
| /// \brief Gets the batch size | |||
| @@ -200,19 +200,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// The base class implementation simply flows the eoe message to output. Derived classes | |||
| /// may override if they need to perform special eoe handling. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status EoeReceived(int32_t worker_id); | |||
| /// \brief Performs handling for when an eof message is received. | |||
| /// The base class implementation simply flows the eof message to output. Derived classes | |||
| /// may override if they need to perform special eof handling. | |||
| /// \param worker_id - The worker id | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status EofReceived(int32_t worker_id); | |||
| /// \brief Derived classes may implement the reset function if the operator is stateful and needs | |||
| /// specific reset handling that is not contained in this common code version of the reset | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status Reset(); | |||
| /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on | |||
| @@ -79,7 +79,7 @@ class FilterOp : public ParallelOp { | |||
| private: | |||
| // Sanity check for builder class args. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status SanityCheck(); | |||
| std::vector<std::string> build_in_col_names_; | |||
| std::shared_ptr<TensorOp> builder_predicate_func_; | |||
| @@ -105,15 +105,15 @@ class FilterOp : public ParallelOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work. | |||
| // @return Status The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // @param int32_t workerId. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status EofReceived(int32_t) override; | |||
| // @param int32_t workerId. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status EoeReceived(int32_t) override; | |||
| // A print method typically used for debugging. | |||
| @@ -151,34 +151,34 @@ class FilterOp : public ParallelOp { | |||
| // logic of FilterOp, getting the data from previous Op, validating user specified column names, | |||
| // applying predicate to each of the data, filter the data when predicate result is false. | |||
| // @param worker_id The id assigned to this thread/worker upon creation. | |||
| // @return Status The error code return. | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | |||
| // Filter the data by predicate function . | |||
| // @param in_buffer input data buffer. | |||
| // @param to_proess_indices Indices of columns to be processed. | |||
| // @param out data buffer that are filtered by predicate. | |||
| // @return Status The error code return. | |||
| // @return Status The status code returned | |||
| Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out); | |||
| // Collector databuffer. | |||
| // @return Status The error code return. | |||
| // @return Status The status code returned | |||
| Status Collector(); | |||
| // @param input tensor vector. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status CheckInput(const TensorRow &input) const; | |||
| // Invoke python func. | |||
| // @param input tensor vector. | |||
| // @param the result of predicate. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); | |||
| // Private function for validating if each of the user specified input column names | |||
| // exist in the DataBuffer. | |||
| // @param input_columns The vector of input column names used in the current thread. | |||
| // @return Status The error code return. | |||
| // @return Status The status code returned | |||
| Status ValidateInColumns(const std::vector<std::string> *input_columns); | |||
| // Private function for checking the column legality | |||
| @@ -133,7 +133,7 @@ class MapOp : public ParallelOp { | |||
| int32_t build_op_connector_size_; | |||
| // Check if the required parameters are set by the builder. | |||
| // @return Status The error code return | |||
| // @return Status The status code returned | |||
| Status sanityCheck() const; | |||
| }; | |||
| @@ -170,7 +170,7 @@ class MapOp : public ParallelOp { | |||
| // provide the master loop that drives the logic for performing the work | |||
| // This main thread creates local queues, pulls databuffers from the previous | |||
| // op's Connector and distributes them to the local queues. Workers pull from the local queues. | |||
| // @return Status The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Getter | |||
| @@ -239,7 +239,7 @@ class MapOp : public ParallelOp { | |||
| // applying a list of TensorOps to each of the data, process the results and then | |||
| // pushing them back to MapOp's output Connector to be fetched by the next Op. | |||
| // @param worker_id The id assigned to this thread/worker upon creation. | |||
| // @return Status The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ | |||
| // Private function for worker thread to perform TensorOp's compute function and get the result. | |||
| @@ -89,7 +89,7 @@ class ParallelOp : public DatasetOp { | |||
| } | |||
| // Override base class reset to provide reset actions specific to the ParallelOp class. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Getter | |||
| @@ -115,7 +115,7 @@ class ParallelOp : public DatasetOp { | |||
| protected: | |||
| // Interface for derived classes to implement. All derived classes must provide the entry | |||
| // function with the main execution loop for worker threads. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status WorkerEntry(int32_t workerId) = 0; | |||
| /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | |||
| @@ -75,7 +75,7 @@ class ProjectOp : public PipelineOp { | |||
| // However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the | |||
| // functor since this op runs inlined inside another operator. The function is overloaded to | |||
| // ensure that it is not called by mistake (it will generate an error). | |||
| // @return Status - The error code returned. | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. | |||
| @@ -93,12 +93,12 @@ class ProjectOp : public PipelineOp { | |||
| // Base-class override for special eoe handler. | |||
| // Inline operators must override this because there is no connector to push eoe onto. | |||
| // @return Status - The error code returned. | |||
| // @return Status The status code returned | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| // Base-class override for special eof handler. | |||
| // Inline operators must override this because there is no connector to push eof onto. | |||
| // @return Status - The error code returned. | |||
| // @return Status The status code returned | |||
| Status EofReceived(int32_t worker_id) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -107,7 +107,7 @@ class RenameOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -78,7 +78,7 @@ class RepeatOp : public PipelineOp { | |||
| // However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the | |||
| // functor since this op runs inlined inside another operator. The function is overloaded to | |||
| // ensure that it is not called by mistake (it will generate an error). | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // This function returns the buffer that is at the top of our output connector. The caller is | |||
| @@ -90,7 +90,7 @@ class RepeatOp : public PipelineOp { | |||
| // @param p_buffer - output pointer to the buffer that it will fetch. | |||
| // @param worker_id - The worker id | |||
| // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override; | |||
| // Base-class override for handling cases when an eoe is received. | |||
| @@ -130,7 +130,7 @@ class RepeatOp : public PipelineOp { | |||
| int32_t num_repeats() { return num_repeats_; } | |||
| /// \brief reset Op | |||
| /// \@return Status - The error code return | |||
| /// \return Status The status code returned | |||
| Status Reset() override; | |||
| int64_t GetTreeRepeatCount() override; | |||
| @@ -146,13 +146,13 @@ class ShuffleOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Base-class override for special eoe handler. | |||
| // ShuffleOp must override this because it shall not perform default handling of eoe. Instead | |||
| // the ShuffleOp needs to manage actions related to the end of the epoch itself. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status EoeReceived(int32_t worker_id) override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -167,17 +167,17 @@ class ShuffleOp : public PipelineOp { | |||
| private: | |||
| // Private function to add a new row to the shuffle buffer. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); | |||
| // Private function to populate the shuffle buffer initially by fetching from the child output | |||
| // connector until the shuffle buffer is full (or there is no more data coming). | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitShuffleBuffer(); | |||
| // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by | |||
| // itself rather than waiting for the reset driven from operators above it in the pipeline. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SelfReset(); | |||
| int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) | |||
| @@ -63,7 +63,7 @@ class SkipOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Base-class override for handling cases when an eoe is received. | |||
| @@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| /// \brief Check validity of input args | |||
| /// \return - The error code returned | |||
| /// \return Status The status code returned | |||
| Status SanityCheck(); | |||
| /// \brief The builder "build" method creates the final object. | |||
| /// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp | |||
| /// \return - The error code returned | |||
| /// \return Status The status code returned | |||
| Status Build(std::shared_ptr<AlbumOp> *op); | |||
| private: | |||
| @@ -168,18 +168,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| ~AlbumOp() = default; | |||
| /// \brief Initialize AlbumOp related var, calls the function to walk all files | |||
| /// \return - The error code returned | |||
| /// \return Status The status code returned | |||
| Status PrescanEntry(); | |||
| /// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| /// \param[in] int32_t workerId - id of each worker | |||
| /// \return Status - The error code returned | |||
| /// \return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| /// \brief Main Loop of AlbumOp | |||
| /// Master thread: Fill IOBlockQueue, then goes to sleep | |||
| /// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| /// \return Status - The error code returned | |||
| /// \return Status The status code returned | |||
| Status operator()() override; | |||
| /// \brief A print method typically used for debugging | |||
| @@ -204,93 +204,93 @@ class AlbumOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| /// \brief Initialize Sampler, calls sampler->Init() within | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status InitSampler(); | |||
| /// \brief Load image to tensor row | |||
| /// \param[in] image_file Image name of file | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load vector of ints to tensor, append tensor to tensor row | |||
| /// \param[in] json_obj Json object containing multi-dimensional label | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load vector of floats to tensor, append tensor to tensor row | |||
| /// \param[in] json_obj Json object containing array data | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load string array into a tensor, append tensor to tensor row | |||
| /// \param[in] json_obj Json object containing string tensor | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load string into a tensor, append tensor to tensor row | |||
| /// \param[in] json_obj Json object containing string tensor | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load float value to tensor row | |||
| /// \param[in] json_obj Json object containing float | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load int value to tensor row | |||
| /// \param[in] json_obj Json object containing int | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load empty tensor to tensor row | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadEmptyTensor(uint32_t col_num, TensorRow *row); | |||
| /// \brief Load id from file name to tensor row | |||
| /// \param[in] file The file name to get ID from | |||
| /// \param[in] col_num Column num in schema | |||
| /// \param[inout] row Tensor row to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row); | |||
| /// \brief Load a tensor row according to a json file | |||
| /// \param[in] row_id_type row_id - id for this tensor row | |||
| /// \param[in] ImageColumns file Json file location | |||
| /// \param[inout] TensorRow row Json content stored into a tensor row | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row); | |||
| /// \param[in] const std::vector<int64_t> &keys Keys in ioblock | |||
| /// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| /// \brief Called first when function is called | |||
| /// \return Status The error code returned | |||
| /// \return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| /// \brief reset Op | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Reset() override; | |||
| // Private function for computing the assignment of the column name map. | |||
| // @return Status The error code returned | |||
| // @return Status The status code returned | |||
| Status ComputeColMap() override; | |||
| int32_t rows_per_buffer_; | |||
| @@ -116,12 +116,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Check validity of input args | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<CelebAOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<CelebAOp> *op); | |||
| private: | |||
| @@ -151,12 +151,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // Main Loop of CelebAOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t worker_id - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // A print method typically used for debugging | |||
| @@ -166,7 +166,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // Method in operator(), to fill IOBlockQueue | |||
| // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer); | |||
| /// \brief Base-class override for NodePass visitor acceptor | |||
| @@ -199,14 +199,14 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Load a tensor row according to a pair | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param std::pair - <image_file,<label>> | |||
| // @param TensorRow row - image & label read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label, | |||
| TensorRow *row); | |||
| @@ -215,7 +215,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| bool CheckDatasetTypeValid(); | |||
| // reset Op | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -109,12 +109,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| // Check validity of input args | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<CifarOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<CifarOp> *op); | |||
| private: | |||
| @@ -144,13 +144,13 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param uint32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Main Loop of CifarOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // A print method typically used for debugging | |||
| @@ -177,18 +177,18 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Load a tensor row according to a pair | |||
| // @param uint64_t index - index need to load | |||
| // @param TensorRow row - image & label read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(uint64_t index, TensorRow *row); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Read block data from cifar file | |||
| @@ -200,7 +200,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| Status LaunchThreadsAndInitOp(); | |||
| // reset Op | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Get cifar files in dir | |||
| @@ -221,7 +221,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -133,12 +133,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| // Check validity of input args | |||
| // @return = The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "Build" method creates the final object. | |||
| // @param std::shared_ptr<CocoOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<CocoOp> *op); | |||
| private: | |||
| @@ -173,13 +173,13 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Main Loop of CocoOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // A print method typically used for debugging | |||
| @@ -214,19 +214,19 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| std::string Name() const override { return "CocoOp"; } | |||
| /// \brief Gets the class indexing | |||
| /// \return Status - The status code return | |||
| /// \return Status The status code returned | |||
| Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override; | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Load a tensor row according to image id | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param std::string image_id - image id | |||
| // @param TensorRow row - image & target read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | |||
| // Load a tensor row with vector which a vector to a tensor | |||
| @@ -235,7 +235,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::shared_ptr<Tensor> image - image tensor | |||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | |||
| // @param TensorRow row - image & target read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | |||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | |||
| @@ -245,7 +245,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::shared_ptr<Tensor> image - image tensor | |||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | |||
| // @param TensorRow row - image & target read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | |||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | |||
| @@ -255,69 +255,69 @@ class CocoOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::shared_ptr<Tensor> image - image tensor | |||
| // @param std::shared_ptr<Tensor> coordinate - coordinate tensor | |||
| // @param TensorRow row - image & target read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image, | |||
| std::shared_ptr<Tensor> coordinate, TensorRow *trow); | |||
| // @param const std::string &path - path to the image file | |||
| // @param const ColDescriptor &col - contains tensor implementation and datatype | |||
| // @param std::shared_ptr<Tensor> tensor - return | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Read annotation from Annotation folder | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseAnnotationIds(); | |||
| // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | |||
| // @param std::vector<int64_t> *keys - image id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | |||
| // Called first when function is called | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| // Reset dataset state | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // @param nlohmann::json image_tree - image tree of json | |||
| // @param std::vector<std::string> *image_vec - image id list of json | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec); | |||
| // @param nlohmann::json categories_tree - categories tree of json | |||
| // return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status CategoriesColumnLoad(const nlohmann::json &categories_tree); | |||
| // @param nlohmann::json categories_tree - categories tree of json | |||
| // @param const std::string &image_file - current image name in annotation | |||
| // @param const int32_t &id - current unique id of annotation | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | |||
| // @param nlohmann::json categories_tree - categories tree of json | |||
| // @param const std::string &image_file - current image name in annotation | |||
| // @param const int32_t &id - current unique id of annotation | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | |||
| // @param nlohmann::json categories_tree - categories tree of json | |||
| // @param const std::string &image_file - current image name in annotation | |||
| // @param const int32_t &id - current unique id of annotation | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id); | |||
| // @param nlohmann::json categories_tree - categories tree of json | |||
| // @param const std::string &image_file - current image name in annotation | |||
| // @param const int32_t &image_id - current unique id of annotation | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, | |||
| const int32_t &image_id); | |||
| @@ -115,13 +115,13 @@ class GeneratorOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| // info from its previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -135,12 +135,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| // Check validity of input args | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<ImageFolderOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<ImageFolderOp> *op); | |||
| private: | |||
| @@ -172,28 +172,28 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| // Initialize ImageFolderOp related var, calls the function to walk all files | |||
| // @param - std::string dir file directory to ImageNetFolder | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status PrescanMasterEntry(const std::string &dir); | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PrescanWorkerEntry(int32_t worker_id); | |||
| // Main Loop of ImageFolderOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | |||
| // A print method typically used for debugging | |||
| @@ -224,19 +224,19 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Load a tensor row according to a pair | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param ImageLabelPair pair - <imagefile,label> | |||
| // @param TensorRow row - image & label read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // @param std::string & dir - dir to walk all images | |||
| @@ -253,7 +253,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| Status LaunchThreadsAndInitOp(); | |||
| // reset Op | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -58,12 +58,12 @@ class IOBlock { | |||
| // Fetches the first key from the block. | |||
| // @note Only useful if you know the block only has 1 key. | |||
| // @param out_key - A copy of the first key from the block | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetKey(int64_t *out_key) const; | |||
| // Fetches the list of keys from this block. | |||
| // @param out_keys - A copy of the vector of keys from the block. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetKeys(std::vector<int64_t> *out_keys) const; | |||
| // Does this block have the eoe flag turned on? | |||
| @@ -110,7 +110,7 @@ class FilenameBlock : public IOBlock { | |||
| // Gets the filename from the block using the provided index container | |||
| // @param out_filename - The filename to add to the block | |||
| // @param index - The index to perform lookup against | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const; | |||
| // Get the start offset of file | |||
| @@ -110,12 +110,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| // Check validity of input args | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "build" method creates the final object. | |||
| // @param std::shared_ptr<ManifestOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<ManifestOp> *op); | |||
| private: | |||
| @@ -145,18 +145,18 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t worker_id - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Main Loop of ManifestOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | |||
| // A print method typically used for debugging | |||
| @@ -201,37 +201,37 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Method in operator(), to fill IOBlockQueue | |||
| // @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer); | |||
| // Load a tensor row according to a pair | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>> | |||
| // @param TensorRow row - image & label read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data, | |||
| TensorRow *row); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Parse manifest file to get image path and label and so on. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseManifestFile(); | |||
| // Called first when function is called | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| // reset Op | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Check if image is valid. Only supports JPEG/PNG/GIF/BMP | |||
| @@ -239,7 +239,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| Status CheckImageType(const std::string &file_name, bool *valid); | |||
| // Count label index,num rows and num samples | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status CountDatasetInfo(); | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -164,13 +164,13 @@ class MindRecordOp : public ParallelOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Class functor operator () override. | |||
| // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Called first when function is called | |||
| @@ -180,7 +180,7 @@ class MindRecordOp : public ParallelOp { | |||
| // Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| // info from its previous execution and then initializes itself so that it can be executed | |||
| // again. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Getter method | |||
| @@ -99,12 +99,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Check validity of input args | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "Build" method creates the final object. | |||
| // @param std::shared_ptr<MnistOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<MnistOp> *op); | |||
| private: | |||
| @@ -133,18 +133,18 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t worker_id - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Main Loop of MnistOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override; | |||
| // A print method typically used for debugging | |||
| @@ -170,39 +170,39 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Load a tensor row according to a pair | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param ImageLabelPair pair - <imagefile,label> | |||
| // @param TensorRow row - image & label read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Iterate through all members in sampleIds and fill them into IOBlock. | |||
| // @param std::shared_ptr<Tensor> sample_ids - | |||
| // @param std::vector<int64_t> *keys - keys in ioblock | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | |||
| // Check image file stream. | |||
| // @param const std::string &file_name - image file name | |||
| // @param std::ifstream *image_reader - image file stream | |||
| // @param uint32_t num_images - returns the number of images | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); | |||
| // Check label stream. | |||
| // @param const std::string &file_name - label file name | |||
| // @param std::ifstream *label_reader - label file stream | |||
| // @param uint32_t num_labels - returns the number of labels | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); | |||
| // Read 4 bytes of data from a file stream. | |||
| @@ -219,23 +219,23 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::ifstream *image_reader - image file stream | |||
| // @param std::ifstream *label_reader - label file stream | |||
| // @param int64_t read_num - number of image to read | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); | |||
| // Parse all mnist dataset files | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseMnistData(); | |||
| // Read all files in the directory | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WalkAllFiles(); | |||
| // Called first when function is called | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| // reset Op | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -63,7 +63,7 @@ class RandomDataOp : public ParallelOp { | |||
| /** | |||
| * The build method that produces the instantiated RandomDataOp as a shared pointer | |||
| * @param out_op - The output RandomDataOperator that was constructed | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status Build(std::shared_ptr<RandomDataOp> *out_op); | |||
| @@ -128,7 +128,7 @@ class RandomDataOp : public ParallelOp { | |||
| private: | |||
| /** | |||
| * Check if the required parameters are set by the builder. | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status SanityCheck() const; | |||
| @@ -182,7 +182,7 @@ class RandomDataOp : public ParallelOp { | |||
| * Class functor operator () override. | |||
| * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will | |||
| * provide the master loop that drives the logic for performing the work. | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status operator()() override; | |||
| @@ -190,7 +190,7 @@ class RandomDataOp : public ParallelOp { | |||
| * Overrides base class reset method. When an operator does a reset, it cleans up any state | |||
| * info from its previous execution and then initializes itself so that it can be executed | |||
| * again. | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status Reset() override; | |||
| @@ -207,7 +207,7 @@ class RandomDataOp : public ParallelOp { | |||
| /** | |||
| * The entry point code for when workers are launched | |||
| * @param worker_id - The worker id | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| @@ -219,7 +219,7 @@ class RandomDataOp : public ParallelOp { | |||
| /** | |||
| * Performs a synchronization between workers at the end of an epoch | |||
| * @param worker_id - The worker id | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status EpochSync(int32_t worker_id, bool *quitting); | |||
| @@ -227,7 +227,7 @@ class RandomDataOp : public ParallelOp { | |||
| * A helper function to stuff the tensor table into a buffer and send it to output connector | |||
| * @param worker_id - The worker id | |||
| * @param in_table - The tensor table to pack and send | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table); | |||
| @@ -235,7 +235,7 @@ class RandomDataOp : public ParallelOp { | |||
| * A helper function to create random data for the row | |||
| * @param worker_id - The worker id | |||
| * @param new_row - The output row to produce | |||
| * @return Status - The error code return | |||
| * @return Status The status code returned | |||
| */ | |||
| Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); | |||
| @@ -40,7 +40,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||
| // @param std::unique_ptr<DataBuffer> pBuffer | |||
| // @param int32_t workerId | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| @@ -53,7 +53,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status ResetSampler() override; | |||
| // Printer for debugging purposes. | |||
| @@ -41,13 +41,13 @@ class PythonSamplerRT : public SamplerRT { | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status ResetSampler() override; | |||
| // Op calls this to get next Buffer that contains all the sampleIds | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // Printer for debugging purposes. | |||
| @@ -40,14 +40,14 @@ class RandomSamplerRT : public SamplerRT { | |||
| // Op calls this to get next Buffer that contains all the sampleIds | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // meant to be called by base class or python | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status ResetSampler() override; | |||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||
| @@ -35,12 +35,12 @@ class RandomAccessOp { | |||
| public: | |||
| // Sampler get number of rows in the dataset | |||
| // @param int64_t num - return number of rows for this dataset | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNumRowsInDataset(int64_t *num_rows) const; | |||
| // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK | |||
| // @param std::map<int32_t, std::vector<int64_t>> *map | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const { | |||
| RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); | |||
| } | |||
| @@ -71,7 +71,7 @@ class SamplerRT { | |||
| // @note It is Sampler responsibility to make sure that the id is not out of bound. | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0; | |||
| // This function only called by python layer. Not needed by Android. | |||
| @@ -81,7 +81,7 @@ class SamplerRT { | |||
| #endif | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status ResetSampler() = 0; | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| @@ -114,13 +114,13 @@ class SamplerRT { | |||
| // Adds a sampler to become our child. | |||
| // @param std::shared_ptr<SamplerRT> child - The sampler to add as a child. | |||
| // @return - The error code returned. | |||
| // @return Status The status code returned | |||
| Status AddChild(std::shared_ptr<SamplerRT> child); | |||
| // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler | |||
| // @param std::shared_ptr<Tensor>* sampleIds | |||
| // @param int64_t numElements - must be a non 0 number | |||
| // @return - The error code returned. | |||
| // @return Status The status code returned | |||
| Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements); | |||
| // A print method typically used for debugging | |||
| @@ -146,7 +146,7 @@ class SamplerRT { | |||
| // associated id. | |||
| // @param int64_t* out_associated_id - Out parameter, contains the associated id. | |||
| // @param int64_t id - The id used as an index to get the associated child id. | |||
| // @return - The error code returned. | |||
| // @return Status The status code returned | |||
| Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); | |||
| protected: | |||
| @@ -40,13 +40,13 @@ class SequentialSamplerRT : public SamplerRT { | |||
| Status InitSampler() override; | |||
| // for next epoch of sampleIds | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status ResetSampler() override; | |||
| // Op calls this to get next Buffer that contains all the sampleIds | |||
| // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op | |||
| // @param int32_t workerId - not meant to be used | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // Printer for debugging purposes. | |||
| @@ -132,12 +132,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| } | |||
| // Check validity of input args | |||
| // @return = The error code return | |||
| // @return Status The status code returned | |||
| Status SanityCheck(); | |||
| // The builder "Build" method creates the final object. | |||
| // @param std::shared_ptr<VOCOp> *op - DatasetOp | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status Build(std::shared_ptr<VOCOp> *op); | |||
| private: | |||
| @@ -173,13 +173,13 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t workerId - id of each worker | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status WorkerEntry(int32_t worker_id) override; | |||
| // Main Loop of VOCOp | |||
| // Master thread: Fill IOBlockQueue, then goes to sleep | |||
| // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // A print method typically used for debugging | |||
| @@ -222,55 +222,55 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status InitSampler(); | |||
| // Load a tensor row according to image id | |||
| // @param row_id_type row_id - id for this tensor row | |||
| // @param std::string image_id - image id | |||
| // @param TensorRow row - image & target read into this tensor row | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); | |||
| // @param const std::string &path - path to the image file | |||
| // @param const ColDescriptor &col - contains tensor implementation and datatype | |||
| // @param std::shared_ptr<Tensor> tensor - return | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | |||
| // @param const std::string &path - path to the image file | |||
| // @param TensorRow *row - return | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); | |||
| // @param const std::vector<int64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db); | |||
| // Read image list from ImageSets | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseImageIds(); | |||
| // Read annotation from Annotation folder | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseAnnotationIds(); | |||
| // @param const std::string &path - path to annotation xml | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ParseAnnotationBbox(const std::string &path); | |||
| // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor | |||
| // @param std::vector<int64_t> *keys - image id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys); | |||
| // Called first when function is called | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchThreadsAndInitOp(); | |||
| // Reset dataset state | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Reset() override; | |||
| // Private function for computing the assignment of the column name map. | |||
| @@ -75,7 +75,7 @@ class TakeOp : public PipelineOp { | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| // Base-class override for NodePass visitor acceptor. | |||
| @@ -101,7 +101,7 @@ class ZipOp : public PipelineOp { | |||
| // Class functor operator () override. | |||
| // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will | |||
| // provide the master loop that drives the logic for performing the work | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status operator()() override; | |||
| /// \brief Base-class override for NodePass pre-visit acceptor | |||
| @@ -226,7 +226,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) { | |||
| num_epochs_ = num_epochs; | |||
| partially_prepare_ = partial; | |||
| @@ -115,16 +115,16 @@ class ExecutionTree { | |||
| // provides it with a link to the tree. A node cannot form any relationships (parent/child) with | |||
| // other nodes unless they are associated with the same tree. | |||
| // @param op - The operator to associate | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AssociateNode(const std::shared_ptr<DatasetOp> &op); | |||
| // Sets the root node of the tree | |||
| // @param op - The operator to assign as root | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AssignRoot(const std::shared_ptr<DatasetOp> &op); | |||
| // Start the execution of the tree | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Launch(); | |||
| /// A print method typically used for debugging | |||
| @@ -155,7 +155,7 @@ class ExecutionTree { | |||
| // wrapper for the TaskGroup handling that is stored inside the execution tree. | |||
| // @param num_workers - The number of workers to launch | |||
| // @param func - The function entry point that workers will execute | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); | |||
| // Getter method | |||
| @@ -181,32 +181,32 @@ class ExecutionTree { | |||
| // Compulsory transformation/action post optimization. | |||
| // For example, repeatOp inlining | |||
| // | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Prepare(int num_epochs = -1, bool partial = false); | |||
| // Compulsory transformation/action pre optimization. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PreAction(); | |||
| // Compulsory transformation/action post optimization. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PostAction(); | |||
| // Optimization transformation/action, optional. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Optimize(); | |||
| // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively | |||
| // walk the tree to perform modifications to the tree or specific nodes within the tree to get | |||
| // it ready for execution. | |||
| // @param Total number of epochs that will be run on this tree | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PrepareDeprecated(); | |||
| // Recursive function used during prepare phase to visit a node and drive any pre- and post- | |||
| // node actions during a tree walk. | |||
| // @param op - The dataset op to work on | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | |||
| // Return the pointer to the TaskGroup | |||
| @@ -51,7 +51,7 @@ class Edge { | |||
| // Get the feature of a edge | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | |||
| // Get nodes on the edge | |||
| @@ -71,7 +71,7 @@ class Edge { | |||
| // Update feature of edge | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | |||
| protected: | |||
| @@ -47,19 +47,19 @@ class GraphData { | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get the node id from the edge. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | |||
| // All neighbors of the acquisition node. | |||
| @@ -68,7 +68,7 @@ class GraphData { | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | |||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | |||
| // is not enough, fill in tensor as -1. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| @@ -77,7 +77,7 @@ class GraphData { | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; | |||
| @@ -87,7 +87,7 @@ class GraphData { | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0; | |||
| @@ -98,7 +98,7 @@ class GraphData { | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| @@ -108,7 +108,7 @@ class GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) = 0; | |||
| @@ -117,7 +117,7 @@ class GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) = 0; | |||
| @@ -57,19 +57,19 @@ class GraphDataClient : public GraphData { | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get the node id from the edge. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get all neighbors of the given nodes. | |||
| @@ -78,7 +78,7 @@ class GraphDataClient : public GraphData { | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | |||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | |||
| // is not enough, fill in tensor as -1. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| @@ -87,7 +87,7 @@ class GraphDataClient : public GraphData { | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| @@ -96,7 +96,7 @@ class GraphDataClient : public GraphData { | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||
| @@ -107,7 +107,7 @@ class GraphDataClient : public GraphData { | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| @@ -117,7 +117,7 @@ class GraphDataClient : public GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| @@ -126,7 +126,7 @@ class GraphDataClient : public GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| @@ -51,19 +51,19 @@ class GraphDataImpl : public GraphData { | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get the node id from the edge. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get all neighbors of the given nodes. | |||
| @@ -72,7 +72,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | |||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | |||
| // is not enough, fill in tensor as -1. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| @@ -81,7 +81,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| @@ -90,7 +90,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||
| @@ -101,7 +101,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| @@ -111,7 +111,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| @@ -123,7 +123,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| @@ -132,7 +132,7 @@ class GraphDataImpl : public GraphData { | |||
| // Get meta information of graph | |||
| // @param MetaInfo *meta_info - Returned meta information | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetMetaInfo(MetaInfo *meta_info); | |||
| #ifdef ENABLE_PYTHON | |||
| @@ -202,14 +202,14 @@ class GraphDataImpl : public GraphData { | |||
| }; | |||
| // Load graph data from mindrecord file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status LoadNodeAndEdge(); | |||
| // Create Tensor By Vector | |||
| // @param std::vector<std::vector<T>> &data - | |||
| // @param DataType type - | |||
| // @param std::shared_ptr<Tensor> *out - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| template <typename T> | |||
| Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out); | |||
| @@ -217,32 +217,32 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<std::vector<T>> *data - To be completed vector | |||
| // @param size_t max_size - The size of the completed vector | |||
| // @param T default_value - Filled default | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| template <typename T> | |||
| Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value); | |||
| // Get the default feature of a node | |||
| // @param FeatureType feature_type - | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | |||
| // Get the default feature of an edge | |||
| // @param FeatureType feature_type - | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | |||
| // Find node object using node id | |||
| // @param NodeIdType id - | |||
| // @param std::shared_ptr<Node> *node - Returned node object | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | |||
| // Find edge object using edge id | |||
| // @param EdgeIdType id - | |||
| // @param std::shared_ptr<Edge> *edge - Returned edge object | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge); | |||
| // Negative sampling | |||
| @@ -250,7 +250,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | |||
| // @param int32_t samples_num - | |||
| // @param std::vector<NodeIdType> *out_samples - Sampling results returned | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids, | |||
| size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_samples); | |||
| @@ -43,12 +43,12 @@ class LocalEdge : public Edge { | |||
| // Get the feature of an edge | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | |||
| // Update feature of edge | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | |||
| private: | |||
| @@ -40,13 +40,13 @@ class LocalNode : public Node { | |||
| // Get the feature of a node | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | |||
| // Get all the neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | |||
| bool exclude_itself = false) override; | |||
| @@ -54,18 +54,18 @@ class LocalNode : public Node { | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_neighbors) override; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status AddNeighbor(const std::shared_ptr<Node> &node) override; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | |||
| private: | |||
| @@ -49,13 +49,13 @@ class Node { | |||
| // Get the feature of a node | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | |||
| // Get all the neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | |||
| bool exclude_itself = false) = 0; | |||
| @@ -63,18 +63,18 @@ class Node { | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_neighbors) = 0; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0; | |||
| protected: | |||
| @@ -57,6 +57,10 @@ class RepeatNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter for the repeat count held by this node | |||
| /// \return Number of cycles to repeat the execution | |||
| int32_t Count() const { return repeat_count_; } | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| @@ -55,6 +55,10 @@ class SkipNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter for the skip count held by this node | |||
| /// \return Number of rows to skip | |||
| int32_t Count() const { return skip_count_; } | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| @@ -55,6 +55,10 @@ class TakeNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Getter for the take count held by this node | |||
| /// \return Number of rows to output | |||
| int32_t Count() const { return take_count_; } | |||
| /// \brief Base-class override for GetDatasetSize | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| @@ -29,7 +29,7 @@ class TensorOpFusionPass : public NodePass { | |||
| /// \brief Identifies and fuses tensor ops within MapOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] *modified indicates whether the node has been visited | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -135,7 +135,7 @@ class IRTreePass : public IRPass { | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] root_ir The tree to operate on. | |||
| /// \param[inout] modified Indicate if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); } | |||
| }; | |||
| @@ -170,14 +170,14 @@ class IRNodePass : public IRPass { | |||
| /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | |||
| /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation | |||
| /// "modified" flag needs to be set to true if node is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); } | |||
| // Visit()/VisitAfter() method to be overridden. | |||
| @@ -266,7 +266,7 @@ class TreePass : public Pass { | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicate if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } | |||
| }; | |||
| @@ -301,14 +301,14 @@ class NodePass : public Pass { | |||
| /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||
| /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution | |||
| /// \param[in] node The node being visited | |||
| /// \param[out] modified Indicator if the node was changed at all. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); } | |||
| // Visit methods to be overridden. | |||
| @@ -41,80 +41,80 @@ class RepeatPass : public NodePass { | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a repeated path of the tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being in a cache merge path | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree below this node as being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| /// \brief Hooks up any identified eoe nodes under this repeat. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override; | |||
| /// \brief CacheOp removes previous leaf ops and replaces them with itself | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Turns off the tracking for operations under merge op | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override; | |||
| /// \brief Saves the lookup op in case it needs to be referenced by a repeat | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override; | |||
| /// \brief Set the epoch count for DeviceQueue | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||
| /// \brief Special case for GeneratorOp | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up | |||
| /// for use with a controlling repeat above it. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override; | |||
| private: | |||
| /// \brief Adds an operator to the eoe operator stack save area | |||
| /// \param dataset_op - The dataset op to add to the eoe stack | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the eoe operator stack save area | |||
| @@ -127,7 +127,7 @@ class RepeatPass : public NodePass { | |||
| /// \brief Adds an operator to the cached operator stack save area | |||
| /// \param dataset_op - The dataset op to add to the cached stack | |||
| /// \return Status - The error code return | |||
| /// \return Status The status code returned | |||
| void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||
| /// \brief Pops an operator from the cached operator stack save area | |||
| @@ -38,123 +38,123 @@ class CacheErrorPass : public NodePass { | |||
| /// \brief Identifies the subtree below this node as being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Returns an error if ZipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override; | |||
| /// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||
| /// \brief Returns an error if ConcatOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | |||
| /// \brief Returns an error if TakeOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | |||
| /// \brief Returns an error if SkipOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | |||
| /// \brief Returns an error if BatchOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Returns an error if FilterOp exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| /// \brief Identifies the leaf dataset as being mappable | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Identifies the subtree above this node as not being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Identifies and blocks repeat under cache scenarios | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| private: | |||
| @@ -48,14 +48,14 @@ class CacheTransformPass : public TreePass { | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that | |||
| /// will be involved in a cache transformation | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -63,95 +63,95 @@ class CacheTransformPass : public TreePass { | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override; | |||
| #endif | |||
| @@ -161,12 +161,12 @@ class CacheTransformPass : public TreePass { | |||
| private: | |||
| /// \brief Common code for mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Common code for non-mappable leaf setup. | |||
| /// \param[in] leaf_op The leaf node performing setup work. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||
| @@ -191,7 +191,7 @@ class CacheTransformPass : public TreePass { | |||
| /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicates if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| private: | |||
| @@ -212,7 +212,7 @@ class CacheTransformPass : public TreePass { | |||
| /// \param[in] leaf_op The leaf node in the transform | |||
| /// \param[in] cache_op The cache op in the transform (will get removed) | |||
| /// \param[in] cache_client The cache client | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | |||
| }; | |||
| @@ -38,61 +38,61 @@ class CacheValidationPass : public IRNodePass { | |||
| /// \brief Returns an error if BatchNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override; | |||
| /// \brief Returns an error if ConcatNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override; | |||
| /// \brief Returns an error if FilterNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override; | |||
| /// \brief Returns an error if SkipNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | |||
| /// \brief Returns an error if TakeNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | |||
| /// \brief Returns an error if ZipNode exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override; | |||
| /// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<MapNode> node, bool *modified) override; | |||
| /// \brief Returns an error if there is a cache over another cache | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Identifies and blocks repeat under cache scenarios | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override; | |||
| /// \brief Identifies the subtree above this node as not being cached | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| private: | |||
| @@ -45,27 +45,27 @@ class EpochCtrlPass : public IRTreePass { | |||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<RootNode> node, bool *modified) override; | |||
| /// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Register the TransferNode for further action. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override; | |||
| /// \brief Getter | |||
| @@ -89,7 +89,7 @@ class EpochCtrlPass : public IRTreePass { | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] root_ir The tree to operate on. | |||
| /// \param[inout] modified Indicates if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -46,20 +46,20 @@ class EpochInjectionPass : public TreePass { | |||
| /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override; | |||
| /// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Register the DeviceQueueOp for further action. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override; | |||
| /// \brief Getter | |||
| @@ -79,7 +79,7 @@ class EpochInjectionPass : public TreePass { | |||
| /// \brief Runs an injection pass to inject in operators needed at the pre pass stage | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicates if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -17,7 +17,10 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "minddata/dataset/engine/opt/pre/node_removal_pass.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -47,7 +50,16 @@ Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> no | |||
| return Status::OK(); | |||
| } | |||
| // Perform ShuffleOp removal check. | |||
| // RepeatNode removal check: a RepeatNode whose count is 1 is recorded as a removal candidate. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<RepeatNode> node, bool *modified) { | |||
|   // This visit only records candidates, so it always reports "not modified". | |||
|   *modified = false; | |||
|   // Any count other than 1 leaves the node in place. | |||
|   if (node->Count() != 1) { | |||
|     return Status::OK(); | |||
|   } | |||
|   nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||
|   return Status::OK(); | |||
| } | |||
| // Perform ShuffleNode removal check. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) { | |||
| *modified = false; | |||
| #if 0 | |||
| @@ -60,6 +72,24 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, b | |||
| return Status::OK(); | |||
| } | |||
| // SkipNode removal check: a SkipNode whose count is 0 is recorded as a removal candidate. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<SkipNode> node, bool *modified) { | |||
|   // This visit only records candidates, so it always reports "not modified". | |||
|   *modified = false; | |||
|   const bool zero_count = (node->Count() == 0); | |||
|   if (zero_count) { | |||
|     // Store the node through its base-class pointer, matching the element type of the list. | |||
|     auto as_base = std::static_pointer_cast<DatasetNode>(node); | |||
|     nodes_to_remove_.push_back(as_base); | |||
|   } | |||
|   return Status::OK(); | |||
| } | |||
| // TakeNode removal check: a TakeNode whose count is -1 is recorded as a removal candidate. | |||
| Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<TakeNode> node, bool *modified) { | |||
|   // This visit only records candidates, so it always reports "not modified". | |||
|   *modified = false; | |||
|   // Any count other than -1 leaves the node in place. | |||
|   if (node->Count() != -1) { | |||
|     return Status::OK(); | |||
|   } | |||
|   nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node)); | |||
|   return Status::OK(); | |||
| } | |||
| // constructor | |||
| NodeRemovalPass::NodeRemovalPass() {} | |||
| @@ -45,21 +45,39 @@ class NodeRemovalPass : public IRTreePass { | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override; | |||
| /// \brief Perform RepeatNode removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<RepeatNode> node, bool *modified) override; | |||
| /// \brief Perform ShuffleNode removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override; | |||
| /// \brief Perform SkipNode removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override; | |||
| /// \brief Perform TakeNode removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override; | |||
| /// \brief Getter | |||
| /// \return All the nodes to be removed | |||
| std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; } | |||
| @@ -79,7 +97,7 @@ class NodeRemovalPass : public IRTreePass { | |||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| /// \param[inout] root_ir The tree to operate on. | |||
| /// \param[inout] modified Indicates if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -46,20 +46,20 @@ class RemovalPass : public TreePass { | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; | |||
| #endif | |||
| /// \brief Perform ShuffleOp removal check | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| /// \brief Getter | |||
| @@ -81,7 +81,7 @@ class RemovalPass : public TreePass { | |||
| /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] modified Indicates if the tree was modified. | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *modified) override; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -53,7 +53,7 @@ class ConnectorSize : public Sampling { | |||
| std::string Name() const override { return kConnectorSizeSamplingName; } | |||
| // Save sampling data to file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SaveToFile() override; | |||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | |||
| @@ -65,7 +65,7 @@ class ConnectorThroughput : public Sampling { | |||
| std::string Name() const override { return name_; }; | |||
| // Save sampling data to file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SaveToFile() override; | |||
| Status Init(const std::string &dir_path, const std::string &device_id); | |||
| @@ -32,13 +32,13 @@ class DatasetIteratorTracing : public Tracing { | |||
| ~DatasetIteratorTracing() override = default; | |||
| // Record tracing data | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | |||
| std::string Name() const override { return kDatasetIteratorTracingName; }; | |||
| // Save tracing data to file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SaveToFile() override; | |||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | |||
| @@ -32,13 +32,13 @@ class DeviceQueueTracing : public Tracing { | |||
| ~DeviceQueueTracing() override = default; | |||
| // Record tracing data | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); | |||
| std::string Name() const override { return kDeviceQueueTracingName; }; | |||
| // Save tracing data to file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SaveToFile() override; | |||
| Status Init(const std::string &dir_path, const std::string &device_id) override; | |||
| @@ -87,19 +87,19 @@ class ProfilingManager { | |||
| Status Initialize(); | |||
| // Save profile data to file | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status SaveProfilingData(); | |||
| // Sampling node getter | |||
| // @param name - The name of the requested node | |||
| // @param node - Pointer to the shared pointer for the Sampling node | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node); | |||
| // Tracing node getter | |||
| // @param name - The name of the requested node | |||
| // @param node - Pointer to the shared pointer for the Tracing node | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node); | |||
| // If profiling is enabled. | |||
| @@ -120,12 +120,12 @@ class ProfilingManager { | |||
| // Register profile node to tree | |||
| // @param node - Profiling node | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status RegisterTracingNode(std::shared_ptr<Tracing> node); | |||
| // Register profile node to tree | |||
| // @param node - Profiling node | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status RegisterSamplingNode(std::shared_ptr<Sampling> node); | |||
| ExecutionTree *tree_ = nullptr; // ExecutionTree pointer | |||
| @@ -442,7 +442,7 @@ class SchemaObj { | |||
| Status parse_column(nlohmann::json columns); | |||
| /// \brief Get schema file from JSON file | |||
| /// \param[in] json_obj Object of JSON parsed. | |||
| /// \param[in] json_obj parsed JSON object | |||
| /// \return Status code | |||
| Status from_json(nlohmann::json json_obj); | |||
| @@ -77,7 +77,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o | |||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | |||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | |||
| // @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape, | |||
| const std::shared_ptr<Tensor> &pad_val); | |||
| @@ -86,7 +86,7 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | |||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | |||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | |||
| // @param float pad_val - value to pad with | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | |||
| const std::vector<dsize_t> &pad_shape, float pad_val); | |||
| @@ -98,7 +98,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||
| // @param std::vector<dsize_t> cur_ind - recursion helper | |||
| // @param T pad_val - value to pad tensor with | |||
| // @param size_t cur_dim - recursion helper | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst, | |||
| std::vector<dsize_t> cur_ind, size_t cur_dim = 0); | |||
| @@ -107,7 +107,7 @@ Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<T | |||
| // @param std::shared_ptr<Tensor> *dst - return tensor padded | |||
| // @param std::vector<dsize_t> pad_shape - shape to pad to | |||
| // @param std::string pad_val - value to pad with | |||
| // @return - The error code return | |||
| // @return Status The status code returned | |||
| Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, | |||
| const std::vector<dsize_t> &pad_shape, const std::string &pad_val); | |||
| @@ -119,7 +119,7 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> | |||
| // @param std::vector<dsize_t> cur_ind - recursion helper | |||
| // @param std::string pad_val - value to pad tensor with | |||
| // @param size_t cur_dim - recursion helper | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst, | |||
| const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim, | |||
| const std::string &pad_value); | |||
| @@ -36,7 +36,7 @@ class ToFloat16Op : public TensorOp { | |||
| // Overrides the base class compute function | |||
| // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor | |||
| // and transforms its data to float16, the output memory is manipulated to contain the result | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| @@ -58,7 +58,7 @@ class CutOutOp : public TensorOp { | |||
| // Overrides the base class compute function | |||
| // Calls the erase function in ImageUtils, this function takes an input tensor | |||
| // and overwrites some of its data using openCV, the output memory is manipulated to contain the result | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kCutOutOp; } | |||
| @@ -50,7 +50,7 @@ class RandomColorAdjustOp : public TensorOp { | |||
| // Overrides the base class compute function. | |||
| // Calls multiple transform functions in ImageUtils, this function takes an input tensor | |||
| // and transforms its data using openCV, the output memory is manipulated to contain the result. | |||
| // @return Status - The error code return. | |||
| // @return Status The status code returned. | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kRandomColorAdjustOp; } | |||
| @@ -61,7 +61,7 @@ class RandomRotationOp : public TensorOp { | |||
| // Overrides the base class compute function | |||
| // Calls the rotate function in ImageUtils, this function takes an input tensor | |||
| // and transforms its data using openCV, the output memory is manipulated to contain the result | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| @@ -43,7 +43,7 @@ class UniformAugOp : public TensorOp { | |||
| void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; } | |||
| // Overrides the base class compute function | |||
| // @return Status - The error code return | |||
| // @return Status The status code returned | |||
| Status Compute(const TensorRow &input, TensorRow *output) override; | |||
| std::string Name() const override { return kUniformAugOp; } | |||
| @@ -53,7 +53,7 @@ class DataHelper { | |||
| /// \param key Key of field to write to | |||
| /// \param value Value array to write to file | |||
| /// \param out_file Optional input for output file path, will write to input file if not specified | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value, | |||
| const std::string &out_file = ""); | |||
| @@ -62,7 +62,7 @@ class DataHelper { | |||
| /// \param key Key of field to write to | |||
| /// \param value Value array to write to file | |||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| template <typename T> | |||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value, | |||
| const std::string &out_file = "") { | |||
| @@ -99,7 +99,7 @@ class DataHelper { | |||
| /// \param key Key of field to write to | |||
| /// \param value Value to write to file | |||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| template <typename T> | |||
| Status UpdateValue(const std::string &in_file, const std::string &key, const T &value, | |||
| const std::string &out_file = "") { | |||
| @@ -134,7 +134,7 @@ class DataHelper { | |||
| /// \brief Template function to write tensor to file | |||
| /// \param[in] in_file File to write to | |||
| /// \param[in] data Array of type T values | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| template <typename T> | |||
| Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) { | |||
| try { | |||
| @@ -157,7 +157,7 @@ class DataHelper { | |||
| /// \param[in] in_file File name to write to | |||
| /// \param[in] data Pointer to data | |||
| /// \param[in] length Length of values to write from pointer | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| template <typename T> | |||
| Status WriteBinFile(const std::string &in_file, T *data, size_t length) { | |||
| try { | |||
| @@ -188,7 +188,7 @@ class DataHelper { | |||
| /// \note This function will return OK even if the key is not found | |||
| /// \param[in] in_file Json file to remove key from | |||
| /// \param[in] key The key to remove | |||
| /// \return Status The error code return | |||
| /// \return Status The status code returned | |||
| Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = ""); | |||
| /// \brief A print method typically used for debugging | |||
| @@ -669,8 +669,6 @@ class Dataset: | |||
| >>> repeat_and_shuffle = data.repeat(50) | |||
| >>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10) | |||
| """ | |||
| if count == 1: | |||
| return self | |||
| return RepeatDataset(self, count) | |||
| @check_skip | |||
| @@ -717,8 +715,6 @@ class Dataset: | |||
| >>> # Create a dataset where the dataset includes 50 elements. | |||
| >>> data = data.take(50) | |||
| """ | |||
| if count == -1: | |||
| return self | |||
| return TakeDataset(self, count) | |||
| def _get_absolute_split_sizes(self, sizes): | |||
| @@ -1311,6 +1311,51 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) { | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSkipTakeRepeat) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipTakeRepeat."; | |||
| // Create an ImageFolder Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | |||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 6)); | |||
| // Create a Skip operation on ds | |||
| int32_t count = 0; | |||
| ds = ds->Skip(count); | |||
| // Create a Project operation on ds | |||
| std::vector<std::string> column_project = {"image"}; | |||
| ds = ds->Project(column_project); | |||
| // Add a Take(-1) | |||
| ds = ds->Take(-1); | |||
| // Add a Repeat(1) | |||
| ds = ds->Repeat(1); | |||
| // Create an iterator over the result of the above dataset | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // iterate over the dataset and get each row | |||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||
| iter->GetNextRow(&row); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto image = row["image"]; | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| MS_LOG(INFO) << "Number of rows: " << i; | |||
| // Expect 6 rows | |||
| EXPECT_EQ(i, 6); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize."; | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| @@ -163,7 +163,7 @@ def test_take_08(): | |||
| def test_take_09(): | |||
| """ | |||
| Test take: repeat count is -1, and read the whole dataset, take after repeat | |||
| Test take: take count is -1, and read the whole dataset, take after repeat | |||
| """ | |||
| logger.info("test_take_09") | |||
| data1 = ds.GeneratorDataset(generator, ["data"]) | |||
| @@ -180,7 +180,7 @@ def test_take_09(): | |||
| def test_take_10(): | |||
| """ | |||
| Test take: repeat count is -1, and read the whole dataset, take before repeat | |||
| Test take: take count is -1, and read the whole dataset, take before repeat | |||
| """ | |||
| logger.info("test_take_10") | |||
| data1 = ds.GeneratorDataset(generator, ["data"]) | |||
| @@ -341,6 +341,18 @@ def test_take_18(): | |||
| assert sum([1 for _ in data1]) == 2 | |||
| def test_take_19(): | |||
| """ | |||
| Test take: take is applied after batch, which means that in take(N), N refers to the number of batches | |||
| """ | |||
| logger.info("test_take_19") | |||
| with pytest.raises(ValueError) as info: | |||
| data1 = ds.GeneratorDataset(generator, ["data"]) | |||
| data1 = data1.batch(2) | |||
| data1 = data1.take(0) | |||
| assert "positive integer" in str(info.value) | |||
| if __name__ == '__main__': | |||
| test_take_01() | |||
| test_take_02() | |||
| @@ -360,4 +372,5 @@ if __name__ == '__main__': | |||
| test_take_16() | |||
| test_take_17() | |||
| test_take_18() | |||
| test_take_19() | |||
| logger.info('== test take operation finished ==') | |||