From: @lixiachen Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -237,8 +237,11 @@ Status CacheOp::Accept(NodePass *p, bool *const modified) { | |||
| return p->RunOnNode(shared_from_base<CacheOp>(), modified); | |||
| } | |||
| // A public wrapper for creating the cache through the client | |||
| Status CacheOp::CreateCache(uint32_t cache_crc) { | |||
| Status CacheOp::PrepareNodePostAction() { | |||
| // Run any common code from super class first before adding our own | |||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create its cache | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this()); | |||
| // This is a non-mappable cache op so the id's need to be generated. | |||
| // Construct the cache | |||
| const bool generate_ids = true; | |||
| @@ -141,11 +141,7 @@ class CacheOp : public CacheBase, public RandomAccessOp { | |||
| bool AllowCacheMiss() override { return false; } | |||
| /// \brief Base-class override for the name of this operator | |||
| std::string Name() const override { return kCacheOp; } | |||
| /// \brief A public wrapper for creating the cache through the client | |||
| /// \param[in] cache_crc The crc that identifies the cache | |||
| /// \see cache_pass.cc | |||
| /// \return Status return code | |||
| Status CreateCache(uint32_t cache_crc); | |||
| Status PrepareNodePostAction() override; | |||
| private: | |||
| WaitPost rows_cache_done_; | |||
| @@ -33,11 +33,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| ClueOp::Builder::Builder() | |||
| : builder_device_id_(0), | |||
| builder_num_devices_(1), | |||
| builder_num_samples_(0), | |||
| builder_shuffle_files_(false), | |||
| builder_sampler_(nullptr) { | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| @@ -74,7 +70,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) { | |||
| std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, | |||
| builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, | |||
| builder_device_id_, std::move(builder_sampler_)); | |||
| builder_device_id_); | |||
| RETURN_IF_NOT_OK(clue_op->Init()); | |||
| *op = std::move(clue_op); | |||
| @@ -94,8 +90,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim | |||
| ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| bool shuffle_files, int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_rows_per_shard_(0), | |||
| all_num_rows_(0), | |||
| @@ -552,16 +548,6 @@ Status ClueOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| // Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so | |||
| // that this clue op will produce the full set of data into the cache. | |||
| void ClueOp::MakeSimpleProducer() { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| shuffle_files_ = false; | |||
| num_samples_ = 0; | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status ClueOp::Accept(NodePass *p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -122,14 +122,6 @@ class ClueOp : public ParallelOp { | |||
| // @return - a string vector | |||
| std::vector<std::string> split(const std::string &s, char delim); | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| @@ -141,13 +133,12 @@ class ClueOp : public ParallelOp { | |||
| std::vector<std::string> builder_clue_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::map<std::string, std::string> builder_cols_to_keyword_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of ClueOp | |||
| ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~ClueOp() = default; | |||
| @@ -182,11 +173,6 @@ class ClueOp : public ParallelOp { | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return clue_files_list_; } | |||
| /// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so | |||
| /// that this clue op will produce the full set of data into the cache. | |||
| void MakeSimpleProducer(); | |||
| // Op name getter | |||
| // @return Name of the current Op | |||
| std::string Name() const override { return "ClueOp"; } | |||
| @@ -29,11 +29,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CsvOp::Builder::Builder() | |||
| : builder_device_id_(0), | |||
| builder_num_devices_(1), | |||
| builder_num_samples_(0), | |||
| builder_shuffle_files_(false), | |||
| builder_sampler_(nullptr) { | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| @@ -65,8 +61,7 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) { | |||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | |||
| builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | |||
| builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_, | |||
| std::move(builder_sampler_)); | |||
| builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(csv_op->Init()); | |||
| *op = std::move(csv_op); | |||
| @@ -77,8 +72,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | |||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | |||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | |||
| int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| csv_files_list_(std::move(csv_files_list)), | |||
| field_delim_(field_delim), | |||
| column_default_list_(column_default), | |||
| @@ -920,16 +915,6 @@ Status CsvOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| // Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so | |||
| // that this csv op will produce the full set of data into the cache. | |||
| void CsvOp::MakeSimpleProducer() { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| shuffle_files_ = false; | |||
| num_samples_ = 0; | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status CsvOp::Accept(NodePass *p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -241,14 +241,6 @@ class CsvOp : public ParallelOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| @@ -262,7 +254,6 @@ class CsvOp : public ParallelOp { | |||
| char builder_field_delim_; | |||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_; | |||
| std::vector<std::string> builder_column_name_list_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of CsvOp | |||
| @@ -271,8 +262,7 @@ class CsvOp : public ParallelOp { | |||
| CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name, | |||
| int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id, | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~CsvOp() = default; | |||
| @@ -308,11 +298,6 @@ class CsvOp : public ParallelOp { | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return csv_files_list_; } | |||
| /// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so | |||
| /// that this csv op will produce the full set of data into the cache. | |||
| void MakeSimpleProducer(); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| @@ -34,8 +34,7 @@ RandomDataOp::Builder::Builder() | |||
| builder_num_workers_(0), | |||
| builder_op_connector_size_(0), | |||
| builder_rows_per_buffer_(0), | |||
| builder_total_rows_(0), | |||
| builder_sampler_(nullptr) { | |||
| builder_total_rows_(0) { | |||
| // Some arguments to the RandomDataOp have a default argument that is taken from the config. | |||
| // The user may override these defaults by using the builder set methods. | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| @@ -48,9 +47,8 @@ RandomDataOp::Builder::Builder() | |||
| Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| *out_op = | |||
| std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, | |||
| builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); | |||
| *out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, | |||
| builder_total_rows_, std::move(builder_data_schema_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -65,8 +63,8 @@ Status RandomDataOp::Builder::SanityCheck() const { | |||
| // Constructor for RandomDataOp | |||
| RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| std::unique_ptr<DataSchema> data_schema) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| buffer_id_(0), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| total_rows_(total_rows), | |||
| @@ -80,8 +78,7 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64 | |||
| if (total_rows_ == 0) { | |||
| total_rows_ = GenRandomInt(1, kMaxTotalRows); | |||
| } | |||
| // If the user did not provide a schema, then we will ask the op to generate a pseudo-random | |||
| // schema. | |||
| // If the user did not provide a schema, then we will ask the op to generate a pseudo-random schema. | |||
| // See details of generateSchema function to learn what type of schema it will create. | |||
| if (data_schema_ == nullptr) { | |||
| GenerateSchema(); | |||
| @@ -117,14 +117,6 @@ class RandomDataOp : public ParallelOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| /** | |||
| * Check if the required parameters are set by the builder. | |||
| @@ -133,7 +125,6 @@ class RandomDataOp : public ParallelOp { | |||
| Status SanityCheck() const; | |||
| std::unique_ptr<DataSchema> builder_data_schema_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| @@ -148,11 +139,10 @@ class RandomDataOp : public ParallelOp { | |||
| * @param rows_per_buffer - The number of rows in each DataBuffer | |||
| * @param data_schema - A user-provided schema | |||
| * @param total_rows - The total number of rows in the dataset | |||
| * @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | |||
| * @return Builder - The modified builder by reference | |||
| */ | |||
| RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| std::unique_ptr<DataSchema> data_schema); | |||
| /** | |||
| * Destructor | |||
| @@ -34,11 +34,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TextFileOp::Builder::Builder() | |||
| : builder_device_id_(0), | |||
| builder_num_devices_(1), | |||
| builder_total_rows_(0), | |||
| builder_shuffle_files_(false), | |||
| builder_sampler_(nullptr) { | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| @@ -74,7 +70,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, | |||
| std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | |||
| builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); | |||
| builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| *op = std::move(text_file_op); | |||
| @@ -83,9 +79,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, | |||
| std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| @@ -504,16 +499,6 @@ Status TextFileOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| // Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so | |||
| // that this text file op will produce the full set of data into the cache. | |||
| void TextFileOp::MakeSimpleProducer() { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| shuffle_files_ = false; | |||
| total_rows_ = 0; | |||
| } | |||
| // Visitor accept method for NodePass | |||
| Status TextFileOp::Accept(NodePass *p, bool *const modified) { | |||
| // Downcast shared pointer then call visitor | |||
| @@ -112,14 +112,6 @@ class TextFileOp : public ParallelOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| @@ -131,7 +123,6 @@ class TextFileOp : public ParallelOp { | |||
| std::vector<std::string> builder_text_files_list_; | |||
| bool builder_shuffle_files_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| }; | |||
| // Constructor of TextFileOp | |||
| @@ -145,10 +136,9 @@ class TextFileOp : public ParallelOp { | |||
| // @param columns_to_load - the names of the columns to load data from. | |||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | |||
| // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler); | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| // Default destructor | |||
| ~TextFileOp() = default; | |||
| @@ -187,11 +177,6 @@ class TextFileOp : public ParallelOp { | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return text_files_list_; } | |||
| /// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so | |||
| /// that this text file op will produce the full set of data into the cache. | |||
| void MakeSimpleProducer(); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| // @param modified - Whether this node visit modified the pipeline. | |||
| @@ -44,11 +44,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TFReaderOp::Builder::Builder() | |||
| : builder_device_id_(0), | |||
| builder_num_devices_(1), | |||
| builder_total_rows_(0), | |||
| builder_equal_rows_per_shard_(false), | |||
| builder_sampler_(nullptr) { | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_worker_connector_size_ = config_manager->worker_connector_size(); | |||
| @@ -122,8 +118,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op) | |||
| std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>( | |||
| builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, | |||
| builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, | |||
| builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, | |||
| std::move(builder_sampler_)); | |||
| builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); | |||
| RETURN_IF_NOT_OK(new_tf_reader_op->Init()); | |||
| *out_tf_reader_op = std::move(new_tf_reader_op); | |||
| @@ -134,8 +129,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 | |||
| int64_t total_num_rows, std::vector<std::string> dataset_files_list, | |||
| std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | |||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | |||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler) | |||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||
| int32_t device_id, bool equal_rows_per_shard) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| @@ -1043,17 +1038,6 @@ Status TFReaderOp::ComputeColMap() { | |||
| return Status::OK(); | |||
| } | |||
| // Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so | |||
| // that this tf reader will produce the full set of data into the cache. | |||
| void TFReaderOp::MakeSimpleProducer() { | |||
| device_id_ = 0; | |||
| num_devices_ = 1; | |||
| total_rows_ = 0; | |||
| shuffle_files_ = false; | |||
| equal_rows_per_shard_ = false; | |||
| } | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| Status TFReaderOp::PrepareNodePostAction() { | |||
| @@ -153,17 +153,8 @@ class TFReaderOp : public ParallelOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) { | |||
| builder_sampler_ = std::move(sampler); | |||
| return *this; | |||
| } | |||
| private: | |||
| std::unique_ptr<DataSchema> builder_data_schema_; | |||
| std::shared_ptr<SamplerRT> builder_sampler_; | |||
| int32_t builder_device_id_; | |||
| int32_t builder_num_devices_; | |||
| int32_t builder_num_workers_; | |||
| @@ -189,11 +180,10 @@ class TFReaderOp : public ParallelOp { | |||
| // @param columns_to_load - the names of the columns to load data from. | |||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | |||
| // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | |||
| TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | |||
| std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | |||
| int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | |||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler); | |||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); | |||
| // Default destructor | |||
| ~TFReaderOp() = default; | |||
| @@ -246,11 +236,6 @@ class TFReaderOp : public ParallelOp { | |||
| // @return Vector of the input file names | |||
| std::vector<std::string> FileNames() { return dataset_files_list_; } | |||
| /// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so | |||
| /// that this tf reader will produce the full set of data into the cache. | |||
| void MakeSimpleProducer(); | |||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||
| // their role. | |||
| // @notes Derived versions of this function should always call its superclass version first | |||
| @@ -387,7 +372,7 @@ class TFReaderOp : public ParallelOp { | |||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||
| const int64_t &pre_count); | |||
| // Caculate number of rows in each shard. | |||
| // Calculate number of rows in each shard. | |||
| // @return Status - the error code returned. | |||
| Status CalculateNumRowsPerShard(); | |||
| @@ -320,7 +320,6 @@ Status ExecutionTree::PostAction() { | |||
| // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API. | |||
| // This is because Python API binding to TensorOperation is still in progress. | |||
| post_actions.push_back(std::make_unique<CacheErrorPass>()); | |||
| post_actions.push_back(std::make_unique<CacheTransformPass>()); | |||
| post_actions.push_back(std::make_unique<RepeatPass>()); | |||
| #endif | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/datasetops/dataset_op.h" | |||
| #include "minddata/dataset/include/samplers.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore::dataset { | |||
| @@ -29,6 +30,9 @@ class DatasetCache { | |||
| virtual Status ValidateParams() = 0; | |||
| virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; | |||
| virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } | |||
| virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, | |||
| std::shared_ptr<SamplerObj> sampler) = 0; | |||
| virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0; | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| @@ -16,6 +16,8 @@ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore { | |||
| @@ -44,5 +46,28 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, | |||
| std::shared_ptr<SamplerObj> sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheLookupOp> lookup_op = nullptr; | |||
| RETURN_IF_NOT_OK(CacheLookupOp::Builder() | |||
| .SetNumWorkers(num_workers) | |||
| .SetClient(cache_client_) | |||
| .SetSampler(sampler->SamplerBuild()) | |||
| .Build(&lookup_op)); | |||
| *ds = lookup_op; | |||
| return Status::OK(); | |||
| } | |||
| Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheMergeOp> merge_op = nullptr; | |||
| RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); | |||
| *ds = merge_op; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -56,6 +56,11 @@ class DatasetCacheImpl : public DatasetCache { | |||
| Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; | |||
| Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, | |||
| std::shared_ptr<SamplerObj> sampler) override; | |||
| Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; | |||
| Status ValidateParams() override { return Status::OK(); } | |||
| ~DatasetCacheImpl() = default; | |||
| @@ -16,6 +16,8 @@ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore { | |||
| @@ -46,5 +48,29 @@ Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, | |||
| std::shared_ptr<SamplerObj> sampler) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheLookupOp> lookup_op = nullptr; | |||
| RETURN_IF_NOT_OK(CacheLookupOp::Builder() | |||
| .SetNumWorkers(num_workers) | |||
| .SetClient(cache_client_) | |||
| .SetSampler(sampler->SamplerBuild()) | |||
| .Build(&lookup_op)); | |||
| *ds = lookup_op; | |||
| return Status::OK(); | |||
| } | |||
| Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheMergeOp> merge_op = nullptr; | |||
| RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); | |||
| *ds = merge_op; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -40,6 +40,11 @@ class PreBuiltDatasetCache : public DatasetCache { | |||
| Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) override; | |||
| Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, | |||
| std::shared_ptr<SamplerObj> sampler) override; | |||
| Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; | |||
| Status ValidateParams() override { return Status::OK(); } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| @@ -8,6 +8,9 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES | |||
| bucket_batch_by_length_node.cc | |||
| build_sentence_piece_vocab_node.cc | |||
| build_vocab_node.cc | |||
| cache_lookup_node.cc | |||
| cache_merge_node.cc | |||
| cache_node.cc | |||
| concat_node.cc | |||
| epoch_ctrl_node.cc | |||
| filter_node.cc | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: hands the cache to DatasetNode, stores the sampler, starts with no runtime op and no | |||
| // remembered copy, then attaches the child. | |||
| CacheLookupNode::CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler, | |||
|                                  std::shared_ptr<DatasetCache> cache) | |||
|     : DatasetNode(std::move(cache)), | |||
|       sampler_(sampler), | |||
|       lookup_op_(nullptr), | |||
|       lookup_node_copy_(nullptr) { | |||
|   AddChild(child); | |||
| } | |||
| // Write the node name to the given stream. | |||
| void CacheLookupNode::Print(std::ostream &out) const { | |||
|   out << Name(); | |||
| } | |||
| // Create a childless copy of this node that shares the cache and owns a copy of the sampler (if any). | |||
| std::shared_ptr<DatasetNode> CacheLookupNode::Copy() { | |||
|   std::shared_ptr<SamplerObj> sampler_copy = nullptr; | |||
|   if (sampler_ != nullptr) { | |||
|     sampler_copy = sampler_->SamplerCopy(); | |||
|   } | |||
|   auto copied_node = std::make_shared<CacheLookupNode>(nullptr, sampler_copy, cache_); | |||
|   // Remember the copy: SamplerCopy() returns this very object later on. | |||
|   lookup_node_copy_ = copied_node; | |||
|   return copied_node; | |||
| } | |||
| // Validate the sampler attached to this node. | |||
| // The name passed to ValidateDatasetSampler now identifies this class; it used to say "CacheNode" | |||
| // (copied from CacheNode::ValidateParams), which would attribute a failure to the wrong node. | |||
| // \return Status::OK() if the sampler is valid, otherwise the error from ValidateDatasetSampler | |||
| Status CacheLookupNode::ValidateParams() { | |||
|   RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheLookupNode", sampler_)); | |||
|   return Status::OK(); | |||
| } | |||
| // Build the runtime lookup op for this node and append it to node_ops. | |||
| // The created op is also kept in lookup_op_, because SamplerBuild() returns it afterwards. | |||
| Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, | |||
|                                "Internal error. Attempt to create a cache lookup node without cache client."); | |||
|   // Build the cache first, then ask it for the lookup op. | |||
|   RETURN_IF_NOT_OK(cache_->Build()); | |||
|   RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); | |||
|   node_ops->emplace_back(lookup_op_); | |||
|   return Status::OK(); | |||
| } | |||
| // Return the node produced by the last Copy() call, viewed as a SamplerObj. | |||
| std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() { | |||
|   // The CacheLookupNode should already have been copied, so just hand that copy back. | |||
|   std::shared_ptr<SamplerObj> copied_sampler = lookup_node_copy_; | |||
|   return copied_sampler; | |||
| } | |||
| // Return the op stored in lookup_op_, viewed as a runtime sampler. | |||
| std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() { | |||
|   // The runtime cache lookup op should already have been built, so just hand it back. | |||
|   // The result is null if lookup_op_ is null or is not a CacheLookupOp. | |||
|   return std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief IR node for a cache lookup. It derives from both DatasetNode and SamplerObj, so the node | |||
| ///     itself can be handed out wherever a sampler object is expected. | |||
| class CacheLookupNode : public DatasetNode, public SamplerObj { | |||
|  public: | |||
|   /// \brief Constructor | |||
|   /// \param[in] child The child node to attach | |||
|   /// \param[in] sampler The sampler stored in this node | |||
|   /// \param[in] cache The dataset cache forwarded to DatasetNode | |||
|   CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler, | |||
|                   std::shared_ptr<DatasetCache> cache); | |||
|  | |||
|   /// \brief Destructor | |||
|   ~CacheLookupNode() = default; | |||
|  | |||
|   /// \brief Node name getter | |||
|   /// \return Name of the current node | |||
|   std::string Name() const override { | |||
|     return kCacheLookupNode; | |||
|   } | |||
|  | |||
|   /// \brief Print the description | |||
|   /// \param out - The output stream to write output to | |||
|   void Print(std::ostream &out) const override; | |||
|  | |||
|   /// \brief Copy the node to a new object | |||
|   /// \return A shared pointer to the new copy | |||
|   std::shared_ptr<DatasetNode> Copy() override; | |||
|  | |||
|   /// \brief SamplerObj override that yields the runtime sampler for this node | |||
|   /// \return Shared pointer to the runtime sampler | |||
|   std::shared_ptr<SamplerRT> SamplerBuild() override; | |||
|  | |||
|   /// \brief SamplerObj override that yields a copy of this sampler object | |||
|   /// \return Shared pointer to the copied SamplerObj | |||
|   std::shared_ptr<SamplerObj> SamplerCopy() override; | |||
|  | |||
|   /// \brief DatasetNode override that creates the runtime dataset op objects for this class | |||
|   /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create | |||
|   /// \return Status Status::OK() if build successfully | |||
|   Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override; | |||
|  | |||
|   /// \brief Parameters validation | |||
|   /// \return Status Status::OK() if all the parameters are valid | |||
|   Status ValidateParams() override; | |||
|  | |||
|  private: | |||
|   std::shared_ptr<SamplerObj> sampler_;                // sampler given to the constructor | |||
|   std::shared_ptr<DatasetOp> lookup_op_;               // runtime op, initialised to nullptr by the constructor | |||
|   std::shared_ptr<CacheLookupNode> lookup_node_copy_;  // node copy, initialised to nullptr by the constructor | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: hands the cache to DatasetNode, flags the node as n-ary and attaches the child. | |||
| CacheMergeNode::CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache) | |||
|     : DatasetNode(std::move(cache)) { | |||
|   nary_op_ = true; | |||
|   AddChild(child); | |||
| } | |||
| // Write the node name to the given stream. | |||
| void CacheMergeNode::Print(std::ostream &out) const { | |||
|   out << Name(); | |||
| } | |||
| // Create a childless copy of this node that shares the same cache. | |||
| std::shared_ptr<DatasetNode> CacheMergeNode::Copy() { | |||
|   return std::make_shared<CacheMergeNode>(nullptr, cache_); | |||
| } | |||
| // This node performs no parameter checks; validation always succeeds. | |||
| Status CacheMergeNode::ValidateParams() { | |||
|   return Status::OK(); | |||
| } | |||
| // Build the runtime merge op for this node and append it to node_ops. | |||
| Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, | |||
|                                "Internal error. Attempt to create a cache merge node without cache client."); | |||
|   // Build the cache first, then ask it for the merge op. | |||
|   RETURN_IF_NOT_OK(cache_->Build()); | |||
|   std::shared_ptr<DatasetOp> created_op = nullptr; | |||
|   RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &created_op)); | |||
|   node_ops->emplace_back(created_op); | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief IR node for a cache merge. It adds no data members of its own to DatasetNode. | |||
| class CacheMergeNode : public DatasetNode { | |||
|  public: | |||
|   /// \brief Constructor | |||
|   /// \param[in] child The child node to attach | |||
|   /// \param[in] cache The dataset cache forwarded to DatasetNode | |||
|   CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache); | |||
|  | |||
|   /// \brief Destructor | |||
|   ~CacheMergeNode() = default; | |||
|  | |||
|   /// \brief Node name getter | |||
|   /// \return Name of the current node | |||
|   std::string Name() const override { | |||
|     return kCacheMergeNode; | |||
|   } | |||
|  | |||
|   /// \brief Print the description | |||
|   /// \param out - The output stream to write output to | |||
|   void Print(std::ostream &out) const override; | |||
|  | |||
|   /// \brief Copy the node to a new object | |||
|   /// \return A shared pointer to the new copy | |||
|   std::shared_ptr<DatasetNode> Copy() override; | |||
|  | |||
|   /// \brief DatasetNode override that creates the runtime dataset op objects for this class | |||
|   /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create | |||
|   /// \return Status Status::OK() if build successfully | |||
|   Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override; | |||
|  | |||
|   /// \brief Parameters validation | |||
|   /// \return Status Status::OK() if all the parameters are valid | |||
|   Status ValidateParams() override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_ | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor: hands the cache to DatasetNode, stores the sampler and attaches the child. | |||
| CacheNode::CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler, | |||
|                      std::shared_ptr<DatasetCache> cache) | |||
|     : DatasetNode(std::move(cache)), | |||
|       sampler_(sampler) { | |||
|   AddChild(child); | |||
| } | |||
| // Write the node name to the given stream. | |||
| void CacheNode::Print(std::ostream &out) const { | |||
|   out << Name(); | |||
| } | |||
| // Create a childless copy of this node that shares the cache and owns a copy of the sampler (if any). | |||
| std::shared_ptr<DatasetNode> CacheNode::Copy() { | |||
|   std::shared_ptr<SamplerObj> sampler_copy = nullptr; | |||
|   if (sampler_ != nullptr) { | |||
|     sampler_copy = sampler_->SamplerCopy(); | |||
|   } | |||
|   return std::make_shared<CacheNode>(nullptr, sampler_copy, cache_); | |||
| } | |||
| // Validate the sampler attached to this node. | |||
| // \return Status::OK() if the sampler is valid, otherwise the error from ValidateDatasetSampler | |||
| Status CacheNode::ValidateParams() { | |||
|   Status sampler_status = ValidateDatasetSampler("CacheNode", sampler_); | |||
|   RETURN_IF_NOT_OK(sampler_status); | |||
|   return Status::OK(); | |||
| } | |||
| // Build the runtime cache op for this node, install the runtime sampler on it and append it to node_ops. | |||
| // \param[out] node_ops Vector that receives the created op | |||
| // \return Status::OK() on success, otherwise the error from a failed check or from the cache | |||
| Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, | |||
|                                "Internal error. Attempt to create a cache node without cache client."); | |||
|   // Copy() allows a null sampler, and the sampler is dereferenced below: fail with a status instead of crashing. | |||
|   CHECK_FAIL_RETURN_UNEXPECTED(sampler_ != nullptr, | |||
|                                "Internal error. Attempt to create a cache node without sampler."); | |||
|   RETURN_IF_NOT_OK(cache_->Build()); | |||
|   std::shared_ptr<DatasetOp> cache_op = nullptr; | |||
|   RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); | |||
|   cache_op->SetSampler(sampler_->SamplerBuild()); | |||
|   node_ops->push_back(cache_op); | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class CacheNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor | |||
| ~CacheNode() = default; | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| std::string Name() const override { return kCacheNode; } | |||
| /// \brief Print the description | |||
| /// \param out - The output stream to write output to | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create | |||
| /// \return Status Status::OK() if build successfully | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override; | |||
| /// \brief Parameters validation | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_ | |||
| @@ -204,15 +204,6 @@ std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int | |||
| return SequentialSampler(0, num_samples); | |||
| } | |||
| // If this node has a cache, build it, create its cache op and append that op to node_ops. | |||
| // Without a cache this is a no-op that returns Status::OK(). | |||
| Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) { | |||
|   if (cache_ == nullptr) { | |||
|     return Status::OK(); | |||
|   } | |||
|   RETURN_IF_NOT_OK(cache_->Build()); | |||
|   std::shared_ptr<DatasetOp> created_op; | |||
|   RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &created_op)); | |||
|   node_ops->push_back(created_op); | |||
|   return Status::OK(); | |||
| } | |||
| // Constructor that delegates to the default constructor and then records the dataset cache | |||
| DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { | |||
|   cache_ = dataset_cache; | |||
| } | |||
| @@ -53,6 +53,9 @@ constexpr char kBatchNode[] = "Batch"; | |||
| constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength"; | |||
| constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab"; | |||
| constexpr char kBuildVocabNode[] = "BuildVocab"; | |||
| constexpr char kCacheLookupNode[] = "CacheLookup"; | |||
| constexpr char kCacheMergeNode[] = "CacheMerge"; | |||
| constexpr char kCacheNode[] = "Cache"; | |||
| constexpr char kConcatNode[] = "Concat"; | |||
| constexpr char kEpochCtrlNode[] = "EpochCtrl"; | |||
| constexpr char kFilterNode[] = "Filter"; | |||
| @@ -248,6 +251,9 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \brief Getter of the number of workers | |||
| /// \return The current value of num_workers_ | |||
| int32_t num_workers() { | |||
|   return num_workers_; | |||
| } | |||
| /// \brief Getter of dataset cache | |||
| /// \return Shared pointer to the cache held in cache_ | |||
| std::shared_ptr<DatasetCache> GetDatasetCache() { | |||
|   return cache_; | |||
| } | |||
| /// \brief Setter function for runtime number of workers | |||
| /// \param[in] num_workers The number of threads in this operator | |||
| /// \return Shared pointer to the original object | |||
| @@ -299,7 +305,6 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| // Used only in the constructor of the class and its derived classes. | |||
| void AddChild(std::shared_ptr<DatasetNode> child); | |||
| std::string PrintColumns(const std::vector<std::string> &columns) const; | |||
| Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops); | |||
| void PrintNode(std::ostream &out, int *level) const; | |||
| enum DataSource { kNotADataSource = 0, kNonMappableSource = 1, kMappableSource = 2 }; | |||
| enum DataSource mappable_; | |||
| @@ -360,6 +365,20 @@ class NonMappableSourceNode : public DatasetNode { | |||
| /// \brief Node name getter | |||
| /// \return Name of the current node | |||
| virtual std::string Name() const = 0; | |||
| /// \brief By default non-mappable dataset does not support sampling. However, if a cache operator | |||
| /// is injected at some other place higher in the tree, that cache can inherit this sampler | |||
| /// from the leaf, providing sampling support from the caching layer. | |||
| /// This function sets up the sampler for a leaf node that does not use sampling. | |||
| /// \param[in] sampler The sampler to setup | |||
| /// \return Status of the function | |||
| virtual Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) = 0; | |||
| /// \brief If a cache has been added into the ascendant tree over this non-mappable source node, then the cache will | |||
| /// be executing a sampler for fetching the data. As such, any options in the source node need to be reset to its | |||
| /// defaults so that this source node will produce the full set of data into the cache. | |||
| /// \return Status of the function | |||
| virtual Status MakeSimpleProducer() = 0; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -76,7 +76,6 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | |||
| node_ops->push_back(project_op); | |||
| } | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(map_op); | |||
| return Status::OK(); | |||
| @@ -72,8 +72,6 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| // Argument that is not exposed to user in the API. | |||
| std::set<std::string> extensions = {}; | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, extensions, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| @@ -67,8 +67,6 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| // label is like this:0 1 0 0 1...... | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, usage_, extensions_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| @@ -64,8 +64,6 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| @@ -62,8 +62,6 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| @@ -83,84 +83,66 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) { | |||
| return res; | |||
| } | |||
| // Function to build CLUENode | |||
| Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() { | |||
| std::map<std::string, std::string> key_map; | |||
| if (task_ == "AFQMC") { | |||
| if (usage_ == "train") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| key_map["label"] = "label"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["id"] = "id"; | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| key_map["label"] = "label"; | |||
| } | |||
| } else if (task_ == "CMNLI") { | |||
| if (usage_ == "train") { | |||
| } | |||
| if (task_ == "CMNLI") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| key_map["label"] = "label"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["id"] = "id"; | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["sentence1"] = "sentence1"; | |||
| key_map["sentence2"] = "sentence2"; | |||
| key_map["label"] = "label"; | |||
| } | |||
| } else if (task_ == "CSL") { | |||
| if (usage_ == "train") { | |||
| } | |||
| if (task_ == "CSL") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["id"] = "id"; | |||
| key_map["abst"] = "abst"; | |||
| key_map["keyword"] = "keyword"; | |||
| key_map["label"] = "label"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["id"] = "id"; | |||
| key_map["abst"] = "abst"; | |||
| key_map["keyword"] = "keyword"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["id"] = "id"; | |||
| key_map["abst"] = "abst"; | |||
| key_map["keyword"] = "keyword"; | |||
| key_map["label"] = "label"; | |||
| } | |||
| } else if (task_ == "IFLYTEK") { | |||
| if (usage_ == "train") { | |||
| } | |||
| if (task_ == "IFLYTEK") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["label"] = "label"; | |||
| key_map["label_des"] = "label_des"; | |||
| key_map["sentence"] = "sentence"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["id"] = "id"; | |||
| key_map["sentence"] = "sentence"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["label"] = "label"; | |||
| key_map["label_des"] = "label_des"; | |||
| key_map["sentence"] = "sentence"; | |||
| } | |||
| } else if (task_ == "TNEWS") { | |||
| if (usage_ == "train") { | |||
| } | |||
| if (task_ == "TNEWS") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["label"] = "label"; | |||
| key_map["label_desc"] = "label_desc"; | |||
| key_map["sentence"] = "sentence"; | |||
| key_map["keywords"] = "keywords"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["id"] = "id"; | |||
| key_map["sentence"] = "sentence"; | |||
| key_map["keywords"] = "keywords"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["label"] = "label"; | |||
| key_map["label_desc"] = "label_desc"; | |||
| key_map["sentence"] = "sentence"; | |||
| key_map["keywords"] = "keywords"; | |||
| } | |||
| } else if (task_ == "WSC") { | |||
| if (usage_ == "train") { | |||
| } | |||
| if (task_ == "WSC") { | |||
| if (usage_ == "train" || usage_ == "eval") { | |||
| key_map["span1_index"] = "target/span1_index"; | |||
| key_map["span2_index"] = "target/span2_index"; | |||
| key_map["span1_text"] = "target/span1_text"; | |||
| @@ -168,24 +150,21 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| key_map["idx"] = "idx"; | |||
| key_map["label"] = "label"; | |||
| key_map["text"] = "text"; | |||
| } else if (usage_ == "test") { | |||
| } else { // usage_ == "test" | |||
| key_map["span1_index"] = "target/span1_index"; | |||
| key_map["span2_index"] = "target/span2_index"; | |||
| key_map["span1_text"] = "target/span1_text"; | |||
| key_map["span2_text"] = "target/span2_text"; | |||
| key_map["idx"] = "idx"; | |||
| key_map["text"] = "text"; | |||
| } else if (usage_ == "eval") { | |||
| key_map["span1_index"] = "target/span1_index"; | |||
| key_map["span2_index"] = "target/span2_index"; | |||
| key_map["span1_text"] = "target/span1_text"; | |||
| key_map["span2_text"] = "target/span2_text"; | |||
| key_map["idx"] = "idx"; | |||
| key_map["label"] = "label"; | |||
| key_map["text"] = "text"; | |||
| } | |||
| } | |||
| return key_map; | |||
| } | |||
| // Function to build CLUENode | |||
| Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| auto key_map = CreateKeyMapForBuild(); | |||
| ColKeyMap ck_map; | |||
| for (auto &p : key_map) { | |||
| ck_map.insert({p.first, split(p.second, '/')}); | |||
| @@ -193,19 +172,13 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // ClueOp by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| // Sort the dataset files in a lexicographical order | |||
| std::vector<std::string> sorted_dataset_files = dataset_files_; | |||
| std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); | |||
| std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | |||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); | |||
| std::shared_ptr<ClueOp> clue_op = | |||
| std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, | |||
| sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||
| RETURN_IF_NOT_OK(clue_op->Init()); | |||
| @@ -222,7 +195,6 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| rows_per_buffer_, &shuffle_op)); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(clue_op); | |||
| @@ -270,5 +242,27 @@ Status CLUENode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // CLUE by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we setup the sampler for a leaf node that does not use sampling. | |||
| // Select a sampler from this node's num_samples/shuffle/shard settings and store it in *sampler. | |||
| Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
|   // Both the global and the per-file shuffle mode count as shuffling here. | |||
|   const bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
|   std::shared_ptr<SamplerObj> selected = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
|   *sampler = selected; | |||
|   return Status::OK(); | |||
| } | |||
| // If a cache has been added into the ascendant tree over this clue node, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the clue node need to be reset to its defaults so | |||
| // that this clue node will produce the full set of data into the cache. | |||
| // Reset sampling-related options: no sample limit, no shuffling, a single shard with id 0. | |||
| Status CLUENode::MakeSimpleProducer() { | |||
|   num_samples_ = 0; | |||
|   shuffle_ = ShuffleMode::kFalse; | |||
|   num_shards_ = 1; | |||
|   shard_id_ = 0; | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -49,6 +50,10 @@ class CLUENode : public NonMappableSourceNode { | |||
| /// \return A shared pointer to the new copy | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief Generate a key map to be used in Build() according to usage and task | |||
| /// \return The generated key map | |||
| std::map<std::string, std::string> CreateKeyMapForBuild(); | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class | |||
| /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create | |||
| /// \return Status Status::OK() if build successfully | |||
| @@ -85,6 +90,22 @@ class CLUENode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief CLUE by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we set up the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \param[out] sampler Receives the sampler selected from this node's sample count, shuffle and shard settings | |||
| /// \return Status of the function | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this clue node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the clue node need to be reset to their defaults so | |||
| /// that this clue node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \return Status of the function | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| /// \brief Split string based on a character delimiter | |||
| /// \return A string vector | |||
| @@ -122,7 +122,6 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| std::shared_ptr<CocoOp> op = | |||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(op); | |||
| @@ -95,12 +95,6 @@ Status CSVNode::ValidateParams() { | |||
| Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // CSVOp by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| // Sort the dataset files in a lexicographical order | |||
| std::vector<std::string> sorted_dataset_files = dataset_files_; | |||
| std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); | |||
| @@ -119,10 +113,9 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| } | |||
| } | |||
| std::shared_ptr<CsvOp> csv_op = | |||
| std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, | |||
| rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, | |||
| num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); | |||
| std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>( | |||
| sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, | |||
| num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||
| RETURN_IF_NOT_OK(csv_op->Init()); | |||
| @@ -140,7 +133,6 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(csv_op); | |||
| @@ -188,5 +180,27 @@ Status CSVNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // CSV is a non-mappable dataset and has no sampling support of its own. | |||
| // A cache operator injected higher up in the tree can nevertheless inherit a sampler from this leaf and | |||
| // provide sampling from the caching layer, which is why a sampler is set up for a leaf node that | |||
| // does not use sampling itself. | |||
| Status CSVNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| // Both the files and the global shuffle modes are passed on to the sampler selection as file shuffling. | |||
| const bool file_shuffle_requested = (shuffle_ == ShuffleMode::kFiles) || (shuffle_ == ShuffleMode::kGlobal); | |||
| // Select the sampler from this node's own settings and hand it back through the output pointer. | |||
| std::shared_ptr<SamplerObj> chosen = SelectSampler(num_samples_, file_shuffle_requested, num_shards_, shard_id_); | |||
| *sampler = chosen; | |||
| return Status::OK(); | |||
| } | |||
| // Once a cache sits somewhere above this CSV node in the tree, the cache executes the sampler that fetches the data. | |||
| // The sample count, shuffle mode and sharding options of this CSV node are therefore put back to their defaults | |||
| // so that the node produces the full set of data into the cache. | |||
| Status CSVNode::MakeSimpleProducer() { | |||
| // Default sample count and shuffle mode. | |||
| num_samples_ = 0; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| // Default sharding: one shard, with this node as shard 0. | |||
| num_shards_ = 1; | |||
| shard_id_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -107,6 +107,22 @@ class CSVNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief CSV by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we set up the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \param[out] sampler Receives the sampler selected from this node's sample count, shuffle and shard settings | |||
| /// \return Status of the function | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this CSV node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the CSV node need to be reset to their defaults so | |||
| /// that this CSV node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \return Status of the function | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| char field_delim_; | |||
| @@ -95,10 +95,10 @@ class GeneratorNode : public MappableSourceNode { | |||
| /// \brief Sampler getter | |||
| /// \return SamplerObj of the current node | |||
| std::shared_ptr<SamplerObj> Sampler() override { return nullptr; } | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| /// \brief Sampler setter | |||
| void SetSampler(std::shared_ptr<SamplerObj> sampler) override {} | |||
| void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; } | |||
| private: | |||
| py::function generator_function_; | |||
| @@ -70,8 +70,6 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | |||
| std::move(sampler_->SamplerBuild()))); | |||
| @@ -94,7 +94,6 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| manifest_op = | |||
| std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | |||
| class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(manifest_op); | |||
| @@ -23,8 +23,9 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -203,5 +204,16 @@ Status MindDataNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si | |||
| return Status::OK(); | |||
| } | |||
| // Visitor entry point for IRNodePass: hands this node, typed as MindDataNode, to the pass's Visit method. | |||
| Status MindDataNode::Accept(IRNodePass *const p, bool *const modified) { | |||
| // Obtain a shared pointer to this object with its concrete node type. | |||
| auto self = shared_from_base<MindDataNode>(); | |||
| // Let the pass pick its overload for that type. | |||
| return p->Visit(self, modified); | |||
| } | |||
| // Visitor entry point for IRNodePass: hands this node, typed as MindDataNode, to the pass's VisitAfter method. | |||
| Status MindDataNode::AcceptAfter(IRNodePass *p, bool *const modified) { | |||
| // Obtain a shared pointer to this object with its concrete node type. | |||
| auto self = shared_from_base<MindDataNode>(); | |||
| // Let the pass pick its overload for that type. | |||
| return p->VisitAfter(self, modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -92,6 +92,18 @@ class MindDataNode : public MappableSourceNode { | |||
| /// \brief Sampler setter | |||
| void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; } | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The pass (visitor) whose Visit method is called with this node | |||
| /// \param[out] modified Indicator of whether the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *p, bool *const modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The pass (visitor) whose VisitAfter method is called with this node | |||
| /// \param[out] modified Indicator of whether the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *p, bool *const modified) override; | |||
| private: | |||
| std::string dataset_file_; // search_for_pattern_ will be true in this mode | |||
| std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode | |||
| @@ -57,7 +57,6 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), | |||
| @@ -22,6 +22,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| @@ -105,17 +106,9 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops | |||
| } | |||
| } | |||
| // RandomOp by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| // RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential. | |||
| std::shared_ptr<SamplerObj> sampler_ = SelectSampler(total_rows_, false, 1, 0); | |||
| std::shared_ptr<RandomDataOp> op; | |||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | |||
| std::move(data_schema_), std::move(sampler_->SamplerBuild())); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| std::move(data_schema_)); | |||
| node_ops->push_back(op); | |||
| @@ -142,5 +135,27 @@ Status RandomNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| // RandomDataset is a non-mappable dataset and has no sampling support of its own. | |||
| // A cache operator injected higher up in the tree can nevertheless inherit a sampler from this leaf and | |||
| // provide sampling from the caching layer, which is why a sampler is set up for a leaf node that | |||
| // does not use sampling itself. | |||
| Status RandomNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| // RandomOp doesn't support a sampler and should not support sharding, so the selection is made with | |||
| // shuffling off and a single shard (shard 0); the selected sampler should just be sequential. | |||
| const bool shuffle_files = false; | |||
| const int num_shards = 1; | |||
| const int shard_id = 0; | |||
| *sampler = SelectSampler(total_rows_, shuffle_files, num_shards, shard_id); | |||
| return Status::OK(); | |||
| } | |||
| // Visitor entry point for IRNodePass: hands this node, typed as RandomNode, to the pass's Visit method. | |||
| Status RandomNode::Accept(IRNodePass *p, bool *const modified) { | |||
| // Obtain a shared pointer to this object with its concrete node type. | |||
| auto self = shared_from_base<RandomNode>(); | |||
| // Let the pass pick its overload for that type. | |||
| return p->Visit(self, modified); | |||
| } | |||
| // Visitor entry point for IRNodePass: hands this node, typed as RandomNode, to the pass's VisitAfter method. | |||
| Status RandomNode::AcceptAfter(IRNodePass *p, bool *const modified) { | |||
| // Obtain a shared pointer to this object with its concrete node type. | |||
| auto self = shared_from_base<RandomNode>(); | |||
| // Let the pass pick its overload for that type. | |||
| return p->VisitAfter(self, modified); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -99,6 +99,30 @@ class RandomNode : public NonMappableSourceNode { | |||
| const std::mt19937 &RandGen() const { return rand_gen_; } | |||
| const std::unique_ptr<DataSchema> &GetDataSchema() const { return data_schema_; } | |||
| /// \brief RandomDataset by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we set up the sampler for a leaf node that does not use sampling. | |||
| /// \param[out] sampler Receives the sampler selected for this node (chosen without shuffling or sharding) | |||
| /// \return Status of the function | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief Nothing is reset here: the random node will always produce the full set of data into the cache | |||
| /// \return Always Status::OK() | |||
| Status MakeSimpleProducer() override { return Status::OK(); } | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The pass (visitor) whose Visit method is called with this node | |||
| /// \param[out] modified Indicator of whether the node was modified | |||
| /// \return Status of the node visit | |||
| Status Accept(IRNodePass *p, bool *const modified) override; | |||
| /// \brief Base-class override for accepting IRNodePass visitor | |||
| /// \param[in] p The pass (visitor) whose VisitAfter method is called with this node | |||
| /// \param[out] modified Indicator of whether the node was modified | |||
| /// \return Status of the node visit | |||
| Status AcceptAfter(IRNodePass *p, bool *const modified) override; | |||
| private: | |||
| /// \brief A quick inline for producing a random number between (and including) min/max | |||
| /// \param[in] min minimum number that can be generated. | |||
| @@ -73,12 +73,6 @@ Status TextFileNode::ValidateParams() { | |||
| Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // TextFileOp by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| // Sort the dataset files in a lexicographical order | |||
| std::vector<std::string> sorted_dataset_files = dataset_files_; | |||
| std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end()); | |||
| @@ -87,10 +81,10 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| // Create and initalize TextFileOp | |||
| // Create and initialize TextFileOp | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->SamplerBuild())); | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal && !IsDescendantOfCache()) { | |||
| @@ -106,7 +100,6 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| rows_per_buffer_, &shuffle_op)); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| // Add TextFileOp | |||
| node_ops->push_back(text_file_op); | |||
| @@ -152,5 +145,27 @@ Status TextFileNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // TextFile is a non-mappable dataset and has no sampling support of its own. | |||
| // A cache operator injected higher up in the tree can nevertheless inherit a sampler from this leaf and | |||
| // provide sampling from the caching layer, which is why a sampler is set up for a leaf node that | |||
| // does not use sampling itself. | |||
| Status TextFileNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| // Both the files and the global shuffle modes are passed on to the sampler selection as file shuffling. | |||
| const bool file_shuffle_requested = (shuffle_ == ShuffleMode::kFiles) || (shuffle_ == ShuffleMode::kGlobal); | |||
| // Select the sampler from this node's own settings and hand it back through the output pointer. | |||
| std::shared_ptr<SamplerObj> chosen = SelectSampler(num_samples_, file_shuffle_requested, num_shards_, shard_id_); | |||
| *sampler = chosen; | |||
| return Status::OK(); | |||
| } | |||
| // Once a cache sits somewhere above this TextFile node in the tree, the cache executes the sampler that fetches | |||
| // the data. The sample count, shuffle mode and sharding options of this TextFile node are therefore put back to | |||
| // their defaults so that the node produces the full set of data into the cache. | |||
| Status TextFileNode::MakeSimpleProducer() { | |||
| // Default sample count and shuffle mode. | |||
| num_samples_ = 0; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| // Default sharding: one shard, with this node as shard 0. | |||
| num_shards_ = 1; | |||
| shard_id_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -83,6 +83,22 @@ class TextFileNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief TextFile by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we set up the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \param[out] sampler Receives the sampler selected from this node's sample count, shuffle and shard settings | |||
| /// \return Status of the function | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this TextFile node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the TextFile node need to be reset to their defaults | |||
| /// so that this TextFile node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \return Status of the function | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| int32_t num_samples_; | |||
| @@ -121,17 +121,10 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // TFReaderOp by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we save the sampler here in a leaf node that does not use sampling. | |||
| std::shared_ptr<SamplerObj> sampler_ = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| // Create and initialize TFReaderOp | |||
| std::shared_ptr<TFReaderOp> tf_reader_op = | |||
| std::make_shared<TFReaderOp>(num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, | |||
| std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, | |||
| shard_id_, shard_equal_rows_, std::move(sampler_->SamplerBuild())); | |||
| std::shared_ptr<TFReaderOp> tf_reader_op = std::make_shared<TFReaderOp>( | |||
| num_workers_, worker_connector_size_, rows_per_buffer_, num_samples_, sorted_dir_files, std::move(data_schema), | |||
| connector_que_size_, columns_list_, shuffle_files, num_shards_, shard_id_, shard_equal_rows_); | |||
| RETURN_IF_NOT_OK(tf_reader_op->Init()); | |||
| @@ -149,7 +142,6 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o | |||
| rows_per_buffer_, &shuffle_op)); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| // Add TFReaderOp | |||
| node_ops->push_back(tf_reader_op); | |||
| @@ -227,5 +219,29 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // TFRecord is a non-mappable dataset and has no sampling support of its own. | |||
| // A cache operator injected higher up in the tree can nevertheless inherit a sampler from this leaf and | |||
| // provide sampling from the caching layer, which is why a sampler is set up for a leaf node that | |||
| // does not use sampling itself. | |||
| Status TFRecordNode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| // Both the files and the global shuffle modes are passed on to the sampler selection as file shuffling. | |||
| const bool file_shuffle_requested = (shuffle_ == ShuffleMode::kFiles) || (shuffle_ == ShuffleMode::kGlobal); | |||
| // Select the sampler from this node's own settings and hand it back through the output pointer. | |||
| std::shared_ptr<SamplerObj> chosen = SelectSampler(num_samples_, file_shuffle_requested, num_shards_, shard_id_); | |||
| *sampler = chosen; | |||
| return Status::OK(); | |||
| } | |||
| // Once a cache sits somewhere above this TFRecord node in the tree, the cache executes the sampler that fetches | |||
| // the data. The sample count, shuffle mode and sharding options of this TFRecord node are therefore put back to | |||
| // their defaults so that the node produces the full set of data into the cache. | |||
| Status TFRecordNode::MakeSimpleProducer() { | |||
| // Default sample count and shuffle mode. | |||
| num_samples_ = 0; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| // Default sharding: one shard, with this node as shard 0, and the equal-rows-per-shard option switched off. | |||
| num_shards_ = 1; | |||
| shard_id_ = 0; | |||
| shard_equal_rows_ = false; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -124,6 +124,22 @@ class TFRecordNode : public NonMappableSourceNode { | |||
| /// \return Status of the function | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief TFRecord by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we set up the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \param[out] sampler Receives the sampler selected from this node's sample count, shuffle and shard settings | |||
| /// \return Status of the function | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this TFRecord node, then the cache will be executing | |||
| /// a sampler for fetching the data. As such, any options in the TFRecord node need to be reset to their defaults | |||
| /// so that this TFRecord node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode subclasses and should be promoted to the parent class. | |||
| /// \return Status of the function | |||
| Status MakeSimpleProducer() override; | |||
| private: | |||
| std::vector<std::string> dataset_files_; | |||
| std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string | |||
| @@ -113,7 +113,6 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| voc_op = | |||
| std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild())); | |||
| RETURN_IF_NOT_OK(AddCacheOp(node_ops)); | |||
| node_ops->push_back(voc_op); | |||
| return Status::OK(); | |||
| @@ -31,8 +31,14 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/root_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| @@ -195,10 +201,10 @@ Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modi | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||
| return Visit(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||
| return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified); | |||
| return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified); | |||
| } | |||
| #endif | |||
| Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) { | |||
| @@ -207,12 +213,26 @@ Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) { | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) { | |||
| // No MapNode-specific work here: continue with the VisitAfter overload for the DatasetNode type. | |||
| std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node); | |||
| return VisitAfter(as_base, modified); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) { | |||
| // No MindDataNode-specific work here: continue with the Visit overload for the MappableSourceNode type. | |||
| std::shared_ptr<MappableSourceNode> as_base = std::static_pointer_cast<MappableSourceNode>(node); | |||
| return Visit(as_base, modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) { | |||
| // No MindDataNode-specific work here: continue with the VisitAfter overload for the MappableSourceNode type. | |||
| std::shared_ptr<MappableSourceNode> as_base = std::static_pointer_cast<MappableSourceNode>(node); | |||
| return VisitAfter(as_base, modified); | |||
| } | |||
| #endif | |||
| Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) { | |||
| // No ProjectNode-specific work here: continue with the Visit overload for the DatasetNode type. | |||
| std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node); | |||
| return Visit(as_base, modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) { | |||
| // No ProjectNode-specific work here: continue with the VisitAfter overload for the DatasetNode type. | |||
| std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node); | |||
| return VisitAfter(as_base, modified); | |||
| } | |||
| Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) { | |||
| // No RandomNode-specific work here: continue with the Visit overload for the NonMappableSourceNode type. | |||
| std::shared_ptr<NonMappableSourceNode> as_base = std::static_pointer_cast<NonMappableSourceNode>(node); | |||
| return Visit(as_base, modified); | |||
| } | |||
| Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) { | |||
| // No RandomNode-specific work here: continue with the VisitAfter overload for the NonMappableSourceNode type. | |||
| std::shared_ptr<NonMappableSourceNode> as_base = std::static_pointer_cast<NonMappableSourceNode>(node); | |||
| return VisitAfter(as_base, modified); | |||
| } | |||
| Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) { | |||
| // No RenameNode-specific work here: continue with the Visit overload for the DatasetNode type. | |||
| std::shared_ptr<DatasetNode> as_base = std::static_pointer_cast<DatasetNode>(node); | |||
| return Visit(as_base, modified); | |||
| } | |||
| @@ -44,7 +44,6 @@ class TakeNode; | |||
| class TransferNode; | |||
| class ZipNode; | |||
| #ifdef ENABLE_PYTHON | |||
| class GeneratorNode; | |||
| class SyncWaitNode; | |||
| #endif | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -129,14 +128,14 @@ class IRPass : public std::enable_shared_from_this<IRPass> { | |||
| class IRTreePass : public IRPass { | |||
| public: | |||
| /// \brief Run the transformation pass against the IR tree. | |||
| /// \param[inout] root_ir Pointer to the IR tree to be transformed. | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| /// \param[in,out] root_ir Pointer to the IR tree to be transformed. | |||
| /// \param[in,out] modified Indicate if the tree was modified | |||
| Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final; | |||
| /// \brief Derived classes may implement the RunOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate if the tree was modified. | |||
| /// \param[in,out] root_ir The IR tree to operate on. | |||
| /// \param[in,out] modified Indicate if the tree was modified. | |||
| /// \return Status The status code returned (this default implementation returns Status::OK()) | |||
| virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); } | |||
| }; | |||
| @@ -164,8 +163,8 @@ class IRNodePass : public IRPass { | |||
| ~IRNodePass() = default; | |||
| /// \brief Run the transformation pass against the IR tree | |||
| /// \param[inout] root_ir Pointer to the IR tree to be transformed | |||
| /// \param[inout] modified Indicator if the tree was changed | |||
| /// \param[in,out] root_ir Pointer to the IR tree to be transformed | |||
| /// \param[in,out] modified Indicator if the tree was changed | |||
| Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final; | |||
| /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down | |||
| @@ -210,8 +209,14 @@ class IRNodePass : public IRPass { | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified); | |||
| #ifndef ENABLE_ANDROID | |||
| virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified); | |||
| #endif | |||
| virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified); | |||
| virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified); | |||
| virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified); | |||
| @@ -270,14 +275,14 @@ class Pass : public std::enable_shared_from_this<Pass> { | |||
| class TreePass : public Pass { | |||
| public: | |||
| /// \brief Run the transformation pass against the execution tree. | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed. | |||
| /// \param[inout] modified Indicate if the tree was modified | |||
| /// \param[in,out] tree Pointer to the execution tree to be transformed. | |||
| /// \param[in,out] modified Indicate if the tree was modified | |||
| Status Run(ExecutionTree *tree, bool *const modified) final; | |||
| /// \brief Derived classes may implement the RunOnTree function to implement tree transformation. | |||
| /// "modified" flag needs to be set to true if tree is modified during the pass execution. | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \param[in,out] tree The tree to operate on. | |||
| /// \param[in,out] modified Indicate if the tree was modified. | |||
| /// \return Status The status code returned | |||
| virtual Status RunOnTree(ExecutionTree *tree, bool *const modified) { return Status::OK(); } | |||
| }; | |||
| @@ -305,8 +310,8 @@ class NodePass : public Pass { | |||
| ~NodePass() = default; | |||
| /// \brief Run the transformation pass against the execution tree | |||
| /// \param[inout] tree Pointer to the execution tree to be transformed | |||
| /// \param[inout] modified Indicator if the tree was changed | |||
| /// \param[in,out] tree Pointer to the execution tree to be transformed | |||
| /// \param[in,out] modified Indicator if the tree was changed | |||
| Status Run(ExecutionTree *tree, bool *const modified) final; | |||
| /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down | |||
| @@ -16,207 +16,130 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_lookup_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_merge_op.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/album_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/celeba_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/cifar_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/coco_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/cache_node.h" | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/datasetops/source/mnist_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/clue_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/csv_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||
| #endif | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/datasetops/source/generator_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/voc_op.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor | |||
| CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {} | |||
| CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_node_(nullptr), sampler_(nullptr) {} | |||
| // Identifies the subtree below this node as a cached descendant tree. | |||
| Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| // Note that this function will only get called on non-leaf nodes. | |||
| // For leaf nodes, the other Visit with NonMappableSourceNode or MappableSourceNode argument will be called instead. | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) { | |||
| *modified = false; | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| if (is_caching_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations is not supported!"); | |||
| if (node->IsCached()) { | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| } | |||
| is_caching_ = true; | |||
| return Status::OK(); | |||
| } | |||
| // Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache | |||
| // Resets the tracking of the cache within the tree and assigns the nodes that will be involved in a cache | |||
| // transformation | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) { | |||
| Status CacheTransformPass::CachePass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) { | |||
| *modified = false; | |||
| is_caching_ = false; // We a no longer in a cache subtree. clear the flag. | |||
| if (leaf_op_) { | |||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||
| // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, | |||
| // using base class pointers. | |||
| AddMappableCacheOperators(std::move(leaf_op_), node); | |||
| } else { | |||
| // If there was no leaf_op set, then this is a non-mappable scenario. | |||
| if (sampler_) { | |||
| // Grab the sampler that was saved from the leaf and plug it into the cache op | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; | |||
| if (node->IsCached()) { | |||
| is_caching_ = false; // We a no longer in a cache subtree. clear the flag. | |||
| if (leaf_node_) { | |||
| MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; | |||
| // Assign the leaf node into the transform pass, using move to null our copy of it, | |||
| // and also assign the cached node, using base class pointers. | |||
| // In the cases where cache is directly injected after the leaf node, these two nodes might be the same. | |||
| cache_pairs_.push_back(std::make_pair(std::move(leaf_node_), node)); | |||
| } else { | |||
| // We're a cache op but no sampler was saved from leaf, so create a default sampler | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||
| node->SetSampler(std::move(sampler_)); | |||
| MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; | |||
| // If there was no leaf_node_ set, then this is a non-mappable scenario. | |||
| // We only assign the cached node in this case. | |||
| cached_nodes_.push_back(node); | |||
| } | |||
| // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache | |||
| uint32_t cache_crc = DatasetOp::GenerateCRC(node); | |||
| RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Common code for mappable leaf setup. | |||
| Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| #ifndef ENABLE_ANDROID | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) { | |||
| if (node->IsCached()) { | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf. | |||
| // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true | |||
| // by the other Visit() with DatasetNode argument | |||
| if (is_caching_) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||
| leaf_op_ = std::move(leaf_op); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (leaf_node_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree. | |||
| // Note that a sampler for a non-mappable dataset only works if there is a downstream cache. | |||
| RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_)); | |||
| // If we are a non-mappable source node in a caching tree, then change our config so that it becomes a basic | |||
| // source node that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| RETURN_IF_NOT_OK(node->MakeSimpleProducer()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| // Common code for non mappable leaf setup. | |||
| Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) { | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (is_caching_ && leaf_op_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) { | |||
| if (node->IsCached()) { | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| } | |||
| // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||
| // as save it for use by cache op in ascendant tree. | |||
| // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true | |||
| // by the other Visit() with DatasetNode argument | |||
| if (is_caching_) { | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; | |||
| } else { | |||
| // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can | |||
| // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) | |||
| std::shared_ptr<SamplerRT> sampler_from_leaf; | |||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (leaf_node_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // Set up a sampler here to be used by cache if we are a non-mappable leaf in a caching tree. | |||
| // Note that a sampler for a non-mappable dataset only works if there is a downstream cache. | |||
| RETURN_IF_NOT_OK(node->SetupSamplerForCache(&sampler_)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) { | |||
| if (is_caching_) { | |||
| // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic | |||
| // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) { | |||
| if (node->IsCached()) { | |||
| MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; | |||
| is_caching_ = true; | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) { | |||
| if (is_caching_) { | |||
| // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic | |||
| // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) { | |||
| // Cache might also be injected to the non-leaf node upper in the tree, so is_caching_ might also be set to true | |||
| // by the other Visit() with DatasetNode argument | |||
| if (is_caching_) { | |||
| // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic | |||
| // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) { | |||
| if (is_caching_) { | |||
| // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic | |||
| // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead. | |||
| node->MakeSimpleProducer(); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; | |||
| // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. | |||
| if (leaf_node_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for multiple leaf nodes under cache."); | |||
| } | |||
| // If we are a leaf in the caching path, then save this leaf | |||
| leaf_node_ = node; | |||
| } | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| #endif | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) { | |||
| return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| return Status::OK(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) { | |||
| if (is_caching_) { | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) { | |||
| if (node->IsCached() || is_caching_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for MindRecordOp under cache."); | |||
| } | |||
| @@ -226,102 +149,85 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> no | |||
| #ifdef ENABLE_PYTHON | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) { | |||
| if (is_caching_) { | |||
| Status CacheTransformPass::CachePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) { | |||
| if (node->IsCached() || is_caching_) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, | |||
| "There is currently no support for GeneratorOp under cache."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| // Perform leaf node cache transform identification | |||
| Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) { | |||
| return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); | |||
| } | |||
| #endif | |||
| // Assigns the leaf and cache operators that are involved in a cache transformation | |||
| void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<CacheOp> cache_op) { | |||
| cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); | |||
| } | |||
| // constructor | |||
| CacheTransformPass::CacheTransformPass() {} | |||
| // Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | |||
| Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *const modified) { | |||
| Status CacheTransformPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { | |||
| MS_LOG(INFO) << "Pre pass: Cache transform pass started."; | |||
| // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will | |||
| // use to execute a transform. | |||
| CachePass cache_pass = CachePass(); | |||
| RETURN_IF_NOT_OK(cache_pass.Run(tree, modified)); | |||
| RETURN_IF_NOT_OK(cache_pass.Run(root_ir, modified)); | |||
| // Execute the transform for non-mappable cache | |||
| for (auto cached_node : cache_pass.cached_nodes()) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Injecting a non-mappable cache node."; | |||
| RETURN_IF_NOT_OK(InjectNonMappableCacheNode(cached_node, cache_pass.sampler())); | |||
| } | |||
| // Then, execute the transform for each pair | |||
| // Execute the transform for mappable cache | |||
| for (auto cache_pair : cache_pass.cache_pairs()) { | |||
| MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | |||
| RETURN_IF_NOT_OK( | |||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client())); | |||
| MS_LOG(DEBUG) << "Cache transform pass: Injecting a mappable cache node."; | |||
| RETURN_IF_NOT_OK(InjectMappableCacheNode(cache_pair.first, cache_pair.second)); | |||
| } | |||
| MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to execute the cache transformation. | |||
| Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, | |||
| std::shared_ptr<CacheClient> cache_client) { | |||
| // Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was | |||
| // the root node. It is also possible that cache_child == leaf_op | |||
| std::shared_ptr<DatasetOp> cache_child = cache_op->child(0); | |||
| DatasetOp *cache_parent = nullptr; | |||
| cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent | |||
| // Helper function to execute mappable cache transformation. | |||
| // Input: | |||
| // Sampler | |||
| // | | |||
| // LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) | |||
| // | |||
| // Transformed: | |||
| // Sampler --> CacheLookupNode -------------------------> | |||
| // | | | |||
| // | CacheMergeNode | |||
| // | | | |||
| // LeafNode --> OtherNodes --> CachedNode | |||
| Status CacheTransformPass::InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node, | |||
| std::shared_ptr<DatasetNode> cached_node) { | |||
| // Create a cache merge node with defaults | |||
| auto cache_merge_node = std::make_shared<CacheMergeNode>(nullptr, cached_node->GetDatasetCache()); | |||
| // Insert the cache merge node to become the cached_node's parent | |||
| RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_merge_node)); | |||
| // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. | |||
| std::shared_ptr<SamplerRT> leaf_sampler = leaf_op->sampler(); | |||
| // Construct the merge op with defaults | |||
| std::shared_ptr<CacheMergeOp> merge_op; | |||
| CacheMergeOp::Builder merge_builder; | |||
| RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); | |||
| // Construct the cache lookup op with defaults | |||
| std::shared_ptr<CacheLookupOp> cache_lookup_op; | |||
| CacheLookupOp::Builder lookup_builder; | |||
| RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); | |||
| RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); | |||
| // Overwrite the old sampler in this leaf op to become the lookup op | |||
| leaf_op->SetSampler(cache_lookup_op); | |||
| // If the cache had a parent, then go into that parent to remove the cache from it's child list and then | |||
| // replace it with the merge op. | |||
| if (cache_parent != nullptr) { | |||
| RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); | |||
| RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); | |||
| } else { | |||
| // If we didn't have a parent, then the merge op is the root node | |||
| RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); | |||
| } | |||
| // Set the cache op to no longer be a parent over it's child. This will fully disconnect the old cache op. | |||
| // We maintain a local pointer to the old child though. | |||
| RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); | |||
| // Connect the merge op | |||
| RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); | |||
| RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); | |||
| // At this point, the cache op has already had it's children and parents taken away. Calling remove | |||
| // on it at this point will not do any node hookups, and instead set internal fields to invalid. | |||
| RETURN_IF_NOT_OK(cache_op->Remove()); | |||
| std::shared_ptr<SamplerObj> leaf_sampler = leaf_node->Sampler(); | |||
| // Create a cache lookup node with leaf_node's sampler | |||
| auto cache_lookup_node = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cached_node->GetDatasetCache()); | |||
| // Insert the cache lookup node as the first child of cache merge node | |||
| RETURN_IF_NOT_OK(cache_merge_node->InsertChildAt(0, cache_lookup_node)); | |||
| // Overwrite the old sampler in this leaf node to become the cache lookup node | |||
| leaf_node->SetSampler(std::static_pointer_cast<SamplerObj>(cache_lookup_node)); | |||
| return Status::OK(); | |||
| } | |||
| // Helper function to execute non-mappable cache transformation. | |||
| // Input: | |||
| // LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) | |||
| // | |||
| // Transformed: | |||
| // Sampler | |||
| // | | |||
| // LeafNode --> OtherNodes --> CachedNode --> CacheNode | |||
| Status CacheTransformPass::InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, | |||
| std::shared_ptr<SamplerObj> sampler) { | |||
| // Create a cache node using the sampler we saved from the leaf | |||
| auto cache_node = std::make_shared<CacheNode>(nullptr, sampler, cached_node->GetDatasetCache()); | |||
| // Insert the cache node to become the cached_node's parent | |||
| RETURN_IF_NOT_OK(cached_node->InsertAbove(cache_node)); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -20,6 +20,8 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/opt/pass.h" | |||
| namespace mindspore { | |||
| @@ -32,11 +34,11 @@ class CacheClient; | |||
| /// \class CacheTransformPass cache_transform_pass.h | |||
| /// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching | |||
| /// operations | |||
| class CacheTransformPass : public TreePass { | |||
| class CacheTransformPass : public IRTreePass { | |||
| /// \class CachePass | |||
| /// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache | |||
| /// transformation. It works in conjunction with the CacheTransformPass | |||
| class CachePass : public NodePass { | |||
| class CachePass : public IRNodePass { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] transform_pass Raw pointer back to controlling tree pass | |||
| @@ -47,138 +49,72 @@ class CacheTransformPass : public TreePass { | |||
| /// \brief Identifies the subtree below this node as a cached descendant tree. | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) override; | |||
| /// \brief Resets the tracking of the cache within the tree and assigns the operators that | |||
| /// will be involved in a cache transformation | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override; | |||
| Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \brief Perform non-mappable leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ClueOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CsvOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \brief Perform non-mappable leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<RandomNode> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \brief Perform mappable leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) override; | |||
| #ifdef ENABLE_PYTHON | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override; | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override; | |||
| #ifndef ENABLE_ANDROID | |||
| /// \brief Perform leaf node cache transform identifications | |||
| /// \param[in] node The node being visited | |||
| /// \param[inout] modified Indicator if the node was changed at all | |||
| /// \param[in,out] modified Indicator if the node was changed at all | |||
| /// \return Status The status code returned | |||
| Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override; | |||
| Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified) override; | |||
| #endif | |||
| /// \brief Getter | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; } | |||
| std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs() { | |||
| return cache_pairs_; | |||
| } | |||
| private: | |||
| /// \brief Common code for mappable leaf setup. | |||
| /// \param[in] node The leaf node performing setup work. | |||
| /// \return Status The status code returned | |||
| Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Common code for non-mappable leaf setup. | |||
| /// \param[in] node The leaf node performing setup work. | |||
| /// \return Status The status code returned | |||
| Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); | |||
| /// \brief Getter | |||
| std::vector<std::shared_ptr<DatasetNode>> cached_nodes() { return cached_nodes_; } | |||
| /// \brief Assigns the leaf and cache operators that are involved in a cache transformation | |||
| /// \param[in] leaf_op The leaf operator involved in the cache transform | |||
| /// \param[in] cache_op The cache operator involved in the cache transform | |||
| void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); | |||
| /// \brief Getter | |||
| std::shared_ptr<SamplerObj> sampler() { return sampler_; } | |||
| private: | |||
| bool is_caching_; | |||
| std::shared_ptr<DatasetOp> leaf_op_; | |||
| std::shared_ptr<SamplerRT> sampler_; | |||
| // The two operators that work together to establish the cache transform | |||
| std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_; | |||
| std::shared_ptr<MappableSourceNode> leaf_node_; | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| // The two nodes that work together to establish the cache transform | |||
| std::vector<std::shared_ptr<DatasetNode>> cached_nodes_; | |||
| std::vector<std::pair<std::shared_ptr<MappableSourceNode>, std::shared_ptr<DatasetNode>>> cache_pairs_; | |||
| }; | |||
| public: | |||
| @@ -189,32 +125,46 @@ class CacheTransformPass : public TreePass { | |||
| ~CacheTransformPass() = default; | |||
| /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations | |||
| /// \param[inout] tree The tree to operate on. | |||
| /// \param[inout] Indicate of the tree was modified. | |||
| /// \param[in,out] tree The tree to operate on. | |||
| /// \param[in,out] modified Indicate if the tree was modified. | |||
| /// \return Status The status code returned | |||
| Status RunOnTree(ExecutionTree *tree, bool *const modified) override; | |||
| Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override; | |||
| private: | |||
| /// \brief Helper function to execute the cache transformation. | |||
| /// \brief Helper function to execute mappable cache transformation. | |||
| /// | |||
| /// Input: | |||
| /// Sampler | |||
| /// | | |||
| /// LeafOp --> OtherOps --> CacheOp | |||
| /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) | |||
| /// | |||
| /// Transformed: | |||
| /// Sampler --> CacheLookupNode -------------------------> | |||
| /// | | | |||
| /// | CacheMergeNode | |||
| /// | | | |||
| /// LeafNode --> OtherNodes --> CachedNode | |||
| /// | |||
| /// \param[in] leaf_node The leaf node in the transform | |||
| /// \param[in] cached_node The node with cache attribute which is involved in the cache transform | |||
| /// \return Status The status code returned | |||
| Status InjectMappableCacheNode(std::shared_ptr<MappableSourceNode> leaf_node, | |||
| std::shared_ptr<DatasetNode> cached_node); | |||
| /// \brief Helper function to execute non-mappable cache transformation. | |||
| /// | |||
| /// Input: | |||
| /// LeafNode --> OtherNodes --> CachedNode (cache_ = DatasetCache) | |||
| /// | |||
| /// Transformed: | |||
| /// Sampler --> CacheLookupOp ----------------> | |||
| /// | | | |||
| /// | MergeOp | |||
| /// | | | |||
| /// LeafOp --> OtherOps --> | |||
| /// Sampler | |||
| /// | | |||
| /// LeafNode --> OtherNodes --> CachedNode --> CacheNode | |||
| /// | |||
| /// \param[in] leaf_op The leaf node in the transform | |||
| /// \param[in] cache_op The cache op in the transform (will get removed) | |||
| /// \param[in] cache_client The cache client | |||
| /// \param[in] cached_node The node with cache attribute which is involved in the cache transform | |||
| /// \param[in] sampler The sampler saved for non-mappable leaf nodes during the CachePass | |||
| /// \return Status The status code returned | |||
| Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, | |||
| std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client); | |||
| Status InjectNonMappableCacheNode(std::shared_ptr<DatasetNode> cached_node, std::shared_ptr<SamplerObj> sampler); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/engine/opt/post/generator_node_pass.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/cache_validation_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/deep_copy_pass.h" | |||
| #include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h" | |||
| @@ -53,6 +54,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) { | |||
| actions.emplace_back(std::make_unique<NodeRemovalPass>()); | |||
| actions.emplace_back(std::make_unique<EpochCtrlPass>()); | |||
| if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>()); | |||
| actions.emplace_back(std::make_unique<CacheTransformPass>()); | |||
| // Vector of flags for each action | |||
| std::vector<bool> modified(actions.size(), false); | |||
| // Apply pre-pass actions | |||
| @@ -35,7 +35,7 @@ namespace dataset { | |||
| // Internal Sampler class forward declaration | |||
| class SamplerRT; | |||
| class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||
| class SamplerObj { | |||
| public: | |||
| /// \brief Constructor | |||
| SamplerObj(); | |||
| @@ -122,7 +122,7 @@ std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_ | |||
| /// Function to create a Sequential Sampler. | |||
| /// \notes Samples the dataset elements sequentially, same as not having a sampler. | |||
| /// \param[in] start_index - Index to start sampling at (dafault to start at first id). | |||
| /// \param[in] start_index - Index to start sampling at (default to start at first id). | |||
| /// \param[in] num_samples - The number of samples to draw (default to all elements). | |||
| /// \return Shared pointer to the current Sampler. | |||
| std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0); | |||
| @@ -465,24 +465,21 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op. | |||
| // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by | |||
| // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and | |||
| // replace it with the required tree structures for cache lookup op and cache merge op. | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | |||
| std::shared_ptr<CacheLookupOp> myLookupOp; | |||
| rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp); | |||
| std::shared_ptr<CacheMergeOp> myMergeOp; | |||
| rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp); | |||
| std::shared_ptr<ImageFolderOp> so; | |||
| ImageFolderOp::Builder builder; | |||
| builder.SetSampler(std::move(seq_sampler)) | |||
| .SetOpConnectorSize(3) | |||
| builder.SetOpConnectorSize(3) | |||
| .SetNumWorkers(3) | |||
| .SetRowsPerBuffer(2) | |||
| .SetExtensions({".jpg", ".JPEG"}) | |||
| .SetRecursive(true) | |||
| .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); | |||
| rc = builder.Build(&so); | |||
| so->SetSampler(myLookupOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| @@ -495,7 +492,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = myTree->AssociateNode(so); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| rc = myTree->AssociateNode(myLookupOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myMergeOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| @@ -503,9 +502,11 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| rc = myRepeatOp->AddChild(myMergeOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myMergeOp->AddChild(myLookupOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(so); | |||
| rc = myMergeOp->AddChild(so); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->Prepare(1); | |||
| @@ -532,119 +533,3 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { | |||
| rc = myClient->DestroyCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| //// Simple test with a repeated cache op over random data producer. | |||
| //// The difference in this one is that you do not add the sampler to the cache op directly. | |||
| //// Instead, the sampler is added as part of the leaf op construction. Then, the prepare | |||
| //// phase will pull this up from the leaf and into the cache. | |||
| //// It removes the sampler from the leaf op, which doesn't make sense there anyway for | |||
| //// the RandomDataOp which doesn't support sampling without a cache. | |||
| //// | |||
| //// RepeatOp | |||
| //// | | |||
| //// CacheOp | |||
| //// | | |||
| //// RandomDataOp | |||
| //// | |||
| TEST_F(MindDataTestCacheOp, DISABLED_TestCacheInheritSampler) { | |||
| // Clear the rc of the master thread if any | |||
| (void)TaskManager::GetMasterThreadRc(); | |||
| Status rc; | |||
| int32_t rank = 0; // not used | |||
| MS_LOG(INFO) << "UT test TestCacheInheritSampler"; | |||
| session_id_type env_session; | |||
| rc = GetSessionFromEnv(&env_session); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index); | |||
| // Start with an empty execution tree | |||
| auto myTree = std::make_shared<ExecutionTree>(); | |||
| // Create a schema using the C api's | |||
| std::unique_ptr<DataSchema> testSchema = std::make_unique<DataSchema>(); | |||
| // 2 columns. First column is an "image" 640,480,3 | |||
| TensorShape c1Shape({640, 480, 3}); | |||
| ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, | |||
| rank, // not used | |||
| &c1Shape); | |||
| // Column 2 will just be a scalar label number | |||
| TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); | |||
| testSchema->AddColumn(c1); | |||
| testSchema->AddColumn(c2); | |||
| // RandomDataOp | |||
| std::shared_ptr<RandomDataOp> myRandomDataOp; | |||
| rc = RandomDataOp::Builder() | |||
| .SetRowsPerBuffer(2) | |||
| .SetNumWorkers(4) | |||
| .SetDataSchema(std::move(testSchema)) | |||
| .SetTotalRows(10) | |||
| .SetSampler(std::move(seq_sampler)) | |||
| .Build(&myRandomDataOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRandomDataOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // CacheOp | |||
| CacheClient::Builder ccbuilder; | |||
| // use arbitrary session of 1, size of 0, spilling// is true | |||
| ccbuilder.SetSessionId(env_session).SetCacheMemSz(4).SetSpill(true); | |||
| std::shared_ptr<CacheClient> myClient; | |||
| rc = ccbuilder.Build(&myClient); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::shared_ptr<CacheOp> myCacheOp; | |||
| rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myCacheOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // RepeatOp | |||
| uint32_t numRepeats = 4; | |||
| std::shared_ptr<RepeatOp> myRepeatOp; | |||
| rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssociateNode(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Assign tree relations and root | |||
| rc = myRepeatOp->AddChild(myCacheOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myCacheOp->AddChild(myRandomDataOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rc = myTree->AssignRoot(myRepeatOp); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| MS_LOG(INFO) << "Launching tree and begin iteration"; | |||
| rc = myTree->Prepare(1); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| std::cout << *myClient << std::endl; | |||
| rc = myTree->Launch(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator dI(myTree); | |||
| TensorRow tensorList; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| int rowCount = 0; | |||
| while (!tensorList.empty()) { | |||
| // Don't display these rows, just count them | |||
| MS_LOG(INFO) << "Row fetched #: " << rowCount; | |||
| rc = dI.FetchNextTensorRow(&tensorList); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| rowCount++; | |||
| } | |||
| ASSERT_EQ(rowCount, 40); | |||
| rc = myClient->DestroyCache(); | |||
| ASSERT_TRUE(rc.IsOk()); | |||
| } | |||
| @@ -315,6 +315,9 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_long_file_list" | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1 | |||
| HandleRcExit $? 0 0 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -216,7 +216,7 @@ def test_cache_map_failure1(): | |||
| | | |||
| Cache | |||
| | | |||
| ImageFolder | |||
| Coco | |||
| """ | |||
| logger.info("Test cache failure 1") | |||
| @@ -227,8 +227,9 @@ def test_cache_map_failure1(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) | |||
| # This DATA_DIR has 6 images in it | |||
| ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, | |||
| cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| @@ -302,7 +303,7 @@ def test_cache_map_failure3(): | |||
| | | |||
| Batch | |||
| | | |||
| ImageFolder | |||
| Mnist | |||
| """ | |||
| logger.info("Test cache failure 3") | |||
| if "SESSION_ID" in os.environ: | |||
| @@ -312,8 +313,7 @@ def test_cache_map_failure3(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) | |||
| ds1 = ds1.batch(2) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| @@ -342,7 +342,7 @@ def test_cache_map_failure4(): | |||
| | | |||
| Filter | |||
| | | |||
| ImageFolder | |||
| CelebA | |||
| """ | |||
| logger.info("Test cache failure 4") | |||
| @@ -353,8 +353,8 @@ def test_cache_map_failure4(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| # This dataset has 4 records | |||
| ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||
| ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) | |||
| decode_op = c_vision.Decode() | |||
| @@ -382,7 +382,7 @@ def test_cache_map_failure5(): | |||
| | | |||
| Map(decode, randomCrop) | |||
| | | |||
| ImageFolder | |||
| Manifest | |||
| """ | |||
| logger.info("Test cache failure 5") | |||
| @@ -393,8 +393,8 @@ def test_cache_map_failure5(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| data = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| # This dataset has 4 records | |||
| data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) | |||
| random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) | |||
| decode_op = c_vision.Decode() | |||
| @@ -505,7 +505,7 @@ def test_cache_map_failure8(): | |||
| | | |||
| Repeat | |||
| | | |||
| ImageFolder | |||
| Cifar10 | |||
| """ | |||
| logger.info("Test cache failure 8") | |||
| @@ -516,8 +516,7 @@ def test_cache_map_failure8(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.repeat(4) | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| @@ -545,7 +544,7 @@ def test_cache_map_failure9(): | |||
| | | |||
| Take | |||
| | | |||
| ImageFolder | |||
| Cifar100 | |||
| """ | |||
| logger.info("Test cache failure 9") | |||
| @@ -556,8 +555,7 @@ def test_cache_map_failure9(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) | |||
| ds1 = ds1.take(2) | |||
| decode_op = c_vision.Decode() | |||
| @@ -587,7 +585,7 @@ def test_cache_map_failure10(): | |||
| | | |||
| Skip | |||
| | | |||
| ImageFolder | |||
| VOC | |||
| """ | |||
| logger.info("Test cache failure 10") | |||
| @@ -598,8 +596,8 @@ def test_cache_map_failure10(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This DATA_DIR only has 2 images in it | |||
| ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) | |||
| # This dataset has 9 records | |||
| ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||
| ds1 = ds1.skip(1) | |||
| decode_op = c_vision.Decode() | |||
| @@ -1913,6 +1913,217 @@ def test_cache_nomap_long_file_list(): | |||
| logger.info("test_cache_nomap_long_file_list Ended.\n") | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure1(): | |||
| """ | |||
| Test nested cache (failure) | |||
| Repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Cache | |||
| | | |||
| TFRecord | |||
| """ | |||
| logger.info("Test cache nomap failure 1") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| # This dataset has only 3 records in it | |||
| ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| ds1.get_batch_size() | |||
| assert "Nested cache operations" in str(e.value) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert "Nested cache operations" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure1 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure2(): | |||
| """ | |||
| Test zip under cache (failure) | |||
| repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Zip | |||
| | | | |||
| Random Random | |||
| """ | |||
| logger.info("Test cache nomap failure 2") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| schema = ds.Schema() | |||
| schema.add_column('image', de_type=mstype.uint8, | |||
| shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) | |||
| schema.add_column('label', de_type=mstype.uint8, shape=[1]) | |||
| ds1 = ds.RandomDataset(schema=schema) | |||
| ds2 = ds.RandomDataset(schema=schema) | |||
| dsz = ds.zip((ds1, ds2)) | |||
| decode_op = c_vision.Decode() | |||
| dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| dsz = dsz.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in dsz.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure2 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure3(): | |||
| """ | |||
| Test batch under cache (failure) | |||
| repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(resize) | |||
| | | |||
| Batch | |||
| | | |||
| Clue | |||
| """ | |||
| logger.info("Test cache nomap failure 3") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train') | |||
| ds1 = ds1.batch(2) | |||
| resize_op = c_vision.Resize((224, 224)) | |||
| ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure3 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure4(): | |||
| """ | |||
| Test filter under cache (failure) | |||
| repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode) | |||
| | | |||
| Filter | |||
| | | |||
| CSV | |||
| """ | |||
| logger.info("Test cache nomap failure 4") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4']) | |||
| ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) | |||
| decode_op = c_vision.Decode() | |||
| ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) | |||
| ds1 = ds1.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure4 Ended.\n') | |||
| @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") | |||
| def test_cache_nomap_failure5(): | |||
| """ | |||
| Test Map with non-deterministic TensorOps under cache (failure) | |||
| repeat | |||
| | | |||
| Cache | |||
| | | |||
| Map(decode, randomCrop) | |||
| | | |||
| TextFile | |||
| """ | |||
| logger.info("Test cache nomap failure 5") | |||
| if "SESSION_ID" in os.environ: | |||
| session_id = int(os.environ['SESSION_ID']) | |||
| else: | |||
| raise RuntimeError("Testcase requires SESSION_ID environment variable") | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| data = ds.TextFileDataset(TEXT_FILE_DATA_DIR) | |||
| random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) | |||
| decode_op = c_vision.Decode() | |||
| data = data.map(input_columns=["image"], operations=decode_op) | |||
| data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache) | |||
| data = data.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value) | |||
| assert num_iter == 0 | |||
| logger.info('test_cache_nomap_failure5 Ended.\n') | |||
| if __name__ == '__main__': | |||
| test_cache_nomap_basic1() | |||
| test_cache_nomap_basic2() | |||