Merge pull request !2316 from JunhanHu/alter_node_cpp
@@ -487,6 +487,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       (void)builder->SetInColNames(in_col_names);
     } else if (key == "output_columns") {
       (void)builder->SetOutColNames(ToStringVector(value));
+    } else if (key == "columns_order") {
+      (void)builder->SetColOrder(ToStringVector(value));
     } else if (key == "num_parallel_workers") {
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "prefetch_size") {
@@ -835,6 +837,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetColumnsToLoad(columns_to_load);
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "schema_file_path" || key == "schema_json_string") {
       schema_exists = true;
     } else if (key == "num_samples") {
@@ -1225,6 +1229,8 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "num_samples") {
       (void)builder->SetTotalRows(ToInt(value));
     } else if (key == "num_shards") {
@@ -1314,6 +1320,8 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
       (void)builder->SetNumWorkers(ToInt(value));
     } else if (key == "shuffle_files") {
       (void)builder->SetShuffleFiles(ToBool(value));
+    } else if (key == "shuffle_global") {
+      (void)builder->SetShuffleGlobal(ToBool(value));
     } else if (key == "num_samples") {
       (void)builder->SetNumSamples(ToInt(value));
     } else if (key == "num_shards") {
@@ -20,6 +20,7 @@
 #include <memory>
 #include <utility>
 #include <string>
+#include <algorithm>
 #include "dataset/engine/execution_tree.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
@@ -68,8 +69,45 @@ Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
   return Status::OK();
 }
 
+Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
+  if (operator_id_ == kInvalidOperatorId) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only "
+      "be made if the node belongs to a tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // disallow relationships with other trees
+  if (tree_ != child->tree_) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
+  child->RemoveParent(this);
+  return Status::OK();
+}
+
+Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
+  // Iterate over a copy: RemoveChild() below erases entries from this->parent_ while we loop.
+  std::vector<DatasetOp *> prev_parents(this->parent_);
+  for (auto &prev_parent : prev_parents) {
+    RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
+    RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
+  }
+  RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
+  if (tree_->root()->id() == this->id()) {
+    RETURN_IF_NOT_OK(tree_->AssignRoot(to_add));
+  }
+  return Status::OK();
+}
+
 // Adds a parent operator to this operator
-void DatasetOp::AddParent(const DatasetOp *parent) { parent_.push_back(parent); }
+void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
+
+// Removes a parent operator from this operator
+void DatasetOp::RemoveParent(DatasetOp *parent) {
+  parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
+}
 
 // Getter function to get a shared pointer to our child
 std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
@@ -64,10 +64,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @param child - shared pointer to the child to add.
   Status AddChild(std::shared_ptr<DatasetOp> child);
 
+  // Remove an operator from our children.
+  // @param child - shared pointer to the child to remove.
+  Status RemoveChild(std::shared_ptr<DatasetOp> child);
+
   // Getter function to get a shared pointer to our child
   // @param child_index - An operator can have n children. Indicates which child to return.
   std::shared_ptr<DatasetOp> child(int32_t child_index) const;
 
+  // Inserts an operator as the parent of the current op.
+  // The inserted op will become the sole parent of the current op.
+  // Any existing parents of the current op will be transferred to the inserted op.
+  Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
+
   // Creates the connector within this operator
   // @param num_producers - number of threads that write into this connector
   // @param num_consumers - number of threads that read from this connector
@@ -261,7 +270,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // Adds a parent operator to this operator
   // @notes External callers do not have access to this function.
   // @param parent - The parent node to add
-  void AddParent(const DatasetOp *parent);
+  void AddParent(DatasetOp *parent);
+
+  // Removes a parent operator from this operator
+  // @notes External callers do not have access to this function.
+  // @param parent - The parent node to remove
+  void RemoveParent(DatasetOp *parent);
 
   // A helper function for providing an assignment of the column name map.
   // This grabs the map from child 0 and assigns it into this op.
@@ -270,7 +284,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   Status AssignColMapFromChild();
 
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
-  std::vector<const DatasetOp *> parent_;          // Parent nodes. No ownership and read-only
+  std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
   int32_t oc_queue_size_;                          // Capacity for each out_connector_
   int32_t operator_id_;                            // Generated id for the node
   ExecutionTree *tree_;                            // Back pointer to our tree.
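Aside (not part of the change): the rewiring that RemoveChild/InsertAsParent perform is easier to see in a self-contained sketch. Node below is a hypothetical stand-in that mirrors DatasetOp's child_/parent_ bookkeeping; it is illustrative only, not MindSpore code.

    #include <algorithm>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    // Stand-in for DatasetOp's child_/parent_ bookkeeping (illustrative only).
    struct Node : std::enable_shared_from_this<Node> {
      explicit Node(std::string n) : name(std::move(n)) {}
      std::string name;
      std::vector<std::shared_ptr<Node>> child_;  // owned children
      std::vector<Node *> parent_;                // raw back-pointers, no ownership

      void AddChild(const std::shared_ptr<Node> &c) {
        child_.push_back(c);
        c->parent_.push_back(this);
      }

      void RemoveChild(const std::shared_ptr<Node> &c) {
        child_.erase(std::remove(child_.begin(), child_.end(), c), child_.end());
        c->parent_.erase(std::remove(c->parent_.begin(), c->parent_.end(), this), c->parent_.end());
      }

      // Same shape as DatasetOp::InsertAsParent: to_add becomes the sole parent of
      // this node, and every previous parent points at to_add instead.
      void InsertAsParent(const std::shared_ptr<Node> &to_add) {
        std::vector<Node *> prev_parents(parent_);  // copy: RemoveChild mutates parent_
        for (Node *p : prev_parents) {
          p->RemoveChild(shared_from_this());
          p->AddChild(to_add);
        }
        to_add->AddChild(shared_from_this());
      }
    };

    int main() {
      auto batch = std::make_shared<Node>("Batch");
      auto reader = std::make_shared<Node>("TFReader");
      batch->AddChild(reader);  // Batch -> TFReader
      auto shuffle = std::make_shared<Node>("Shuffle");
      reader->InsertAsParent(shuffle);
      // Prints: Batch -> Shuffle -> TFReader
      std::cout << batch->name << " -> " << batch->child_[0]->name << " -> "
                << batch->child_[0]->child_[0]->name << std::endl;
      return 0;
    }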
@@ -54,19 +54,20 @@ Status MapOp::Builder::sanityCheck() const {
 Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
   RETURN_IF_NOT_OK(sanityCheck());
   *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
-                                 std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_,
-                                 build_perf_mode_);
+                                 std::move(build_tensor_funcs_), std::move(build_col_order_), build_num_workers_,
+                                 build_op_connector_size_, build_perf_mode_);
   return Status::OK();
 }
 
 // Constructor of MapOp
 MapOp::MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
-             bool perf_mode)
+             std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
+             int32_t num_workers, int32_t op_connector_size, bool perf_mode)
     : ParallelOp(num_workers, op_connector_size),
       tfuncs_(std::move(tensor_funcs)),
       in_columns_(in_col_names),
       out_columns_(out_col_names),
+      columns_order_(columns_order),
       perf_mode_(perf_mode) {
   // If caller didn't specify the out_col_names, assume they are same as the in_columns.
   if (out_columns_.empty() || out_columns_[0].empty()) {
@@ -93,6 +93,13 @@ class MapOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetColOrder(const std::vector<std::string> &col_order) {
+      build_col_order_ = col_order;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder setter method returns reference to the builder.
     Builder &SetNumWorkers(int32_t num_workers) {
@@ -123,6 +130,7 @@ class MapOp : public ParallelOp {
     std::vector<std::string> build_in_col_names_;
     std::vector<std::string> build_out_col_names_;
     std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
+    std::vector<std::string> build_col_order_;
     int32_t build_num_workers_;
     int32_t build_op_connector_size_;
     bool build_perf_mode_;  // Default true.
@@ -137,11 +145,12 @@ class MapOp : public ParallelOp {
   // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs).
   // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs).
   // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data.
+  // @param columns_order A full list of column names (should match the whole dataset view post \p tensorFuncs).
   // @param num_workers The number of worker threads.
   // @param op_connector_size The size of each queue in the connector.
   MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
-        bool perf_mode);
+        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, const std::vector<std::string> &columns_order,
+        int32_t num_workers, int32_t op_connector_size, bool perf_mode);
 
   // Destructor
   ~MapOp() = default;
@@ -181,6 +190,10 @@ class MapOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "MapOp"; }
 
+  // Columns order getter
+  // @return The post-map columns order
+  std::vector<std::string> const &ColumnsOrder() const { return columns_order_; }
+
  private:
   // Local queues where worker threads can pop from.
   // Popping directly from the Connector can block if the previous designated threads haven't popped.
@@ -202,6 +215,9 @@ class MapOp : public ParallelOp {
   // Indices of the columns to process.
   std::vector<size_t> to_process_indices_;
 
+  // Variable to store the column order of all columns post tensorOps
+  std::vector<std::string> columns_order_;
+
   // Performance mode is when the main thread creates local queues, pulls databuffers from the previous
   // op's Connector and distributes them to the local queues. Workers pull from the local queues.
   // If this flag is false, each worker pulls directly from the Connector. This uses fewer resources
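Aside (not part of the change): columns_order_ only records the desired full column layout; the reorder itself is performed by a ProjectOp that the new pre-pass inserts above the MapOp (see map_column_reorder.cc further down). The projection amounts to reindexing a row by name, as in this standalone sketch with hypothetical column names:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Reorders a named row into the given full column order, as the inserted
    // ProjectOp does for MapOp's columns_order.
    std::vector<int> Project(const std::map<std::string, int> &row, const std::vector<std::string> &order) {
      std::vector<int> out;
      out.reserve(order.size());
      for (const auto &name : order) {
        out.push_back(row.at(name));  // throws if columns_order names a missing column
      }
      return out;
    }

    int main() {
      std::map<std::string, int> row = {{"out", 0}, {"col1", 1}};  // columns after the map
      for (int v : Project(row, {"col1", "out"})) {
        std::cout << v << " ";  // prints: 1 0
      }
      std::cout << std::endl;
      return 0;
    }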
@@ -31,7 +31,11 @@
 namespace mindspore {
 namespace dataset {
 ClueOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_num_samples_(0),
+      builder_shuffle_files_(false),
+      builder_shuffle_global_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -62,8 +66,8 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
-    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
-    builder_device_id_);
+    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_shuffle_global_,
+    builder_num_devices_, builder_device_id_);
 
   RETURN_IF_NOT_OK(clue_op->Init());
   *op = std::move(clue_op);
@@ -83,7 +87,7 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
 ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
                ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-               bool shuffle_files, int32_t num_device, int32_t device_id)
+               bool shuffle_files, bool shuffle_global, int32_t num_device, int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       rows_per_buffer_(rows_per_buffer),
       num_rows_per_shard_(0),
@@ -94,6 +98,7 @@ ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples
       load_jagged_connector_(true),
       cols_to_keyword_(cols_to_keyword),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       finished_reading_dataset_(false),
       num_devices_(num_device),
       device_id_(device_id),
@@ -104,6 +104,13 @@ class ClueOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetNumSamples(int64_t num_samples) {
@@ -132,13 +139,15 @@ class ClueOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_clue_files_list_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     std::map<std::string, std::string> builder_cols_to_keyword_;
   };
 
   // Constructor of ClueOp
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
          ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-         bool shuffle_files, int32_t num_devices, int32_t device_id);
+         bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~ClueOp() = default;
@@ -169,6 +178,14 @@ class ClueOp : public ParallelOp {
   // @return Status - the error code returned.
   static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return clue_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires a global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -248,6 +265,7 @@ class ClueOp : public ParallelOp {
   int32_t device_id_;
   bool shuffle_files_;
+  bool shuffle_global_;
   bool finished_reading_dataset_;
   int32_t num_devices_;
   int64_t rows_per_buffer_;
@@ -33,7 +33,11 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_shuffle_files_(false),
+      builder_shuffle_global_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
-    builder_num_devices_, builder_device_id_);
+    builder_shuffle_global_, builder_num_devices_, builder_device_id_);
 
   RETURN_IF_NOT_OK(text_file_op->Init());
   *op = std::move(text_file_op);
@@ -73,7 +77,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
 TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
-                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
+                       int32_t op_connector_size, bool shuffle_files, bool shuffle_global, int32_t num_device,
+                       int32_t device_id)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),
@@ -81,6 +86,7 @@ TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t tot
       total_rows_(total_rows),
       text_files_list_(std::move(text_files_list)),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       data_schema_(std::move(schema)),
       all_num_rows_(0),
       num_rows_per_shard_(0),
@@ -105,6 +105,13 @@ class TextFileOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetTotalRows(int64_t total_rows) {
@@ -122,6 +129,7 @@ class TextFileOp : public ParallelOp {
     int32_t builder_worker_connector_size_;
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     std::unique_ptr<DataSchema> builder_schema_;
   };
 
@@ -135,10 +143,11 @@ class TextFileOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
              std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, int32_t num_devices, int32_t device_id);
+             bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id);
 
   // Default destructor
   ~TextFileOp() = default;
@@ -173,6 +182,14 @@ class TextFileOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "TextFileOp"; }
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return text_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires a global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -253,6 +270,7 @@ class TextFileOp : public ParallelOp {
   int64_t total_rows_;
   std::vector<std::string> text_files_list_;
   bool shuffle_files_;
+  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   int64_t all_num_rows_;
   int64_t num_rows_per_shard_;
@@ -56,6 +56,7 @@ TFReaderOp::Builder::Builder()
   builder_op_connector_size_ = config_manager->op_connector_size();
   builder_rows_per_buffer_ = config_manager->rows_per_buffer();
   builder_shuffle_files_ = false;
+  builder_shuffle_global_ = false;
   builder_data_schema_ = std::make_unique<DataSchema>();
 }
@@ -126,7 +127,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
   std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
     builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
     builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
-    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
+    builder_shuffle_files_, builder_shuffle_global_, builder_num_devices_, builder_device_id_,
+    builder_equal_rows_per_shard_);
 
   RETURN_IF_NOT_OK(new_tf_reader_op->Init());
   *out_tf_reader_op = std::move(new_tf_reader_op);
@@ -136,8 +138,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
 TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
                        int64_t total_num_rows, std::vector<std::string> dataset_files_list,
                        std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
-                       std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
-                       int32_t device_id, bool equal_rows_per_shard)
+                       std::vector<std::string> columns_to_load, bool shuffle_files, bool shuffle_global,
+                       int32_t num_device, int32_t device_id, bool equal_rows_per_shard)
     : ParallelOp(num_workers, op_connector_size),
       device_id_(device_id),
       num_devices_(num_device),
@@ -147,6 +149,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
       columns_to_load_(std::move(columns_to_load)),
       finished_reading_dataset_(false),
       shuffle_files_(shuffle_files),
+      shuffle_global_(shuffle_global),
       data_schema_(std::move(data_schema)),
       filename_index_(std::make_unique<StringIndex>()),
       load_io_block_queue_(true),
@@ -172,7 +175,8 @@ void TFReaderOp::Print(std::ostream &out, bool show_all) const {
   // Then show any custom derived-internal stuff
   out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_
       << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
-      << "\nDataset files list:\n";
+      << "\nShuffle global: " << ((shuffle_global_) ? "yes" : "no")
+      << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
   for (int i = 0; i < dataset_files_list_.size(); ++i) {
     out << " " << dataset_files_list_[i];
   }
@@ -217,7 +221,6 @@ Status TFReaderOp::Init() {
   // temporary: make size large enough to hold all files + EOE to avoid hangs
   int32_t safe_queue_size = static_cast<int32_t>(std::ceil(dataset_files_list_.size() / num_workers_)) + 1;
   io_block_queues_.Init(num_workers_, safe_queue_size);
-  dataset_files_list_.clear();  // no longer need the original list of files
 
   return Status::OK();
 }
@@ -146,6 +146,13 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder - setter method returns reference to the builder.
+    Builder &SetShuffleGlobal(bool shuffle_global) {
+      builder_shuffle_global_ = shuffle_global;
+      return *this;
+    }
+
     // Setter method.
     // @return Builder - setter method returns reference to the builder.
     Builder &SetShardEqualRows(bool shard_equal_rows) {
@@ -165,6 +172,7 @@ class TFReaderOp : public ParallelOp {
     std::vector<std::string> builder_dataset_files_list_;
     std::vector<std::string> builder_columns_to_load_;
     bool builder_shuffle_files_;
+    bool builder_shuffle_global_;
     bool builder_equal_rows_per_shard_;
   };
 
@@ -179,11 +187,12 @@ class TFReaderOp : public ParallelOp {
   // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
+  // @param shuffle_global - whether or not to shuffle the entire dataset.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
             int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             bool shuffle_global, int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
 
   // Default destructor
   ~TFReaderOp() = default;
@@ -232,6 +241,14 @@ class TFReaderOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "TFReaderOp"; }
 
+  // File names getter
+  // @return Vector of the input file names
+  std::vector<std::string> FileNames() { return dataset_files_list_; }
+
+  // Global shuffle flag getter
+  // @return Bool - whether this Op requires a global shuffle
+  bool RequireGlobalShuffle() { return shuffle_global_; }
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -372,6 +389,7 @@ class TFReaderOp : public ParallelOp {
   std::vector<std::string> columns_to_load_;
   bool finished_reading_dataset_;
   bool shuffle_files_;
+  bool shuffle_global_;
   std::unique_ptr<DataSchema> data_schema_;
   std::unique_ptr<StringIndex> filename_index_;
   bool load_io_block_queue_;
@@ -19,6 +19,8 @@
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/datasetops/shuffle_op.h"
 #include "dataset/util/task_manager.h"
+#include "dataset/engine/opt/pre/map_column_reorder.h"
+#include "dataset/engine/opt/pre/global_shuffle.h"
 #include "dataset/engine/perf/profiling.h"
 #include "dataset/engine/perf/monitor.h"
@@ -79,8 +81,6 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
   // Then add it as the root.
   root_ = op;
 
-  // The tree has an assigned root now and it's ready to be prepared.
-  tree_state_ = kDeTStatePrepare;
   return Status::OK();
 }
@@ -207,9 +207,24 @@ Status ExecutionTree::Prepare() {
   return Status::OK();
 }
 
-Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
+Status ExecutionTree::PrepareTreePreAction() {
+  bool modified = false;
+  std::vector<std::unique_ptr<Pass>> pre_actions;
+  // Construct pre actions; owning them via unique_ptr so they are released when we return
+  pre_actions.push_back(std::make_unique<MapColumnReorder>());
+  pre_actions.push_back(std::make_unique<GlobalShufflePass>());
+  // Apply pre action passes
+  for (auto &pass : pre_actions) {
+    RETURN_IF_NOT_OK(pass->Run(this, &modified));
+  }
+  return Status::OK();
+}
 
-Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
+Status ExecutionTree::PrepareTreePostAction() {
+  // The tree is ready to be prepared.
+  tree_state_ = kDeTStatePrepare;
+  return Status::OK();
+}
 
 Status ExecutionTree::Optimize() {
   // auto pp = new PrinterPass();
@@ -2,5 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-opt OBJECT
   pass.cc
+  pre/map_column_reorder.cc
+  pre/global_shuffle.cc
   util/printer_pass.cc
 )
@@ -37,10 +37,18 @@ namespace mindspore {
 namespace dataset {
 
 // Driver method for TreePass
-Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }
+Status TreePass::Run(ExecutionTree *tree, bool *modified) {
+  if (!tree || !modified) {
+    return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
+  }
+  return this->RunOnTree(tree, modified);
+}
 
 // Driver method for NodePass
 Status NodePass::Run(ExecutionTree *tree, bool *modified) {
+  if (!tree || !modified) {
+    return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass");
+  }
   std::shared_ptr<DatasetOp> root = tree->root();
   if (traversalOrder_ == Order::DFS) {
     // DFS
@@ -57,10 +57,10 @@ class ImageFolderOp;
 // The actual implementation of the passes will be derived from here.
 class Pass : public std::enable_shared_from_this<Pass> {
  public:
-  // Run the transformation pass again the execution tree.
+  // Run the transformation pass against the execution tree.
   // @param tree - Pointer to the execution tree to be transformed.
   // @param modified - Pointer to the modified flag.
-  virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
+  virtual Status Run(ExecutionTree *tree, bool *modified) = 0;
 };
 
 // TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
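Aside (not part of the change): making Run pure virtual plus the new null guards pins down a small contract for pass authors. A minimal standalone sketch of that contract, where Tree and Status are toy stand-ins rather than the real ExecutionTree and Status:

    #include <iostream>
    #include <string>
    #include <vector>

    // Toy stand-ins for ExecutionTree and Status (illustrative only).
    struct Tree { std::vector<std::string> ops; };
    struct Status {
      bool ok;
      static Status OK() { return Status{true}; }
      static Status Error() { return Status{false}; }
    };

    // Mirrors the Pass/TreePass split: Run is the guarded driver, RunOnTree the hook.
    struct Pass {
      virtual ~Pass() = default;
      virtual Status Run(Tree *tree, bool *modified) = 0;  // pure virtual, as in this diff
    };

    struct TreePass : Pass {
      Status Run(Tree *tree, bool *modified) override {
        if (tree == nullptr || modified == nullptr) return Status::Error();  // the new null guard
        return RunOnTree(tree, modified);
      }
      virtual Status RunOnTree(Tree *tree, bool *modified) = 0;
    };

    // Example pass: appends a Shuffle op and reports the tree as modified.
    struct DemoShufflePass : TreePass {
      Status RunOnTree(Tree *tree, bool *modified) override {
        tree->ops.push_back("Shuffle");
        *modified = true;
        return Status::OK();
      }
    };

    int main() {
      Tree tree{{"TFReader", "Batch"}};
      bool modified = false;
      DemoShufflePass pass;
      Status s = pass.Run(&tree, &modified);
      std::cout << "ok=" << s.ok << " modified=" << modified << " ops=" << tree.ops.size() << std::endl;
      return 0;
    }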
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "dataset/engine/opt/pre/global_shuffle.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/shuffle_op.h"
+#include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/source/text_file_op.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<TFReaderOp>> tf_readers;
+  std::vector<std::shared_ptr<TextFileOp>> text_files;
+  std::vector<std::shared_ptr<ClueOp>> clues;
+
+  // Pass 1, search for all sources which require a global shuffle
+  for (auto &op : *tree) {
+    if (auto ptr = std::dynamic_pointer_cast<TFReaderOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        tf_readers.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<TextFileOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        text_files.push_back(ptr);
+        continue;
+      }
+    }
+    if (auto ptr = std::dynamic_pointer_cast<ClueOp>(op.shared_from_this())) {
+      if (ptr->RequireGlobalShuffle()) {
+        clues.push_back(ptr);
+        continue;
+      }
+    }
+  }
+
+  // Pass 2, insert shuffle nodes
+  // The following blocks could be implemented with a template if we unify CountTotalRows across all source nodes.
+  for (auto node : tf_readers) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true));
+    int32_t avg_rows_per_file = static_cast<int32_t>(total_rows / node->FileNames().size());
+    // Shuffle across 4 files' worth of rows, with a minimum window of 10000 rows
+    (void)builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : text_files) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_rows_per_file = static_cast<int32_t>(total_rows / node->FileNames().size());
+    (void)builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  for (auto node : clues) {
+    std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
+    int64_t total_rows = 0;
+    RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(node->FileNames(), &total_rows));
+    int32_t avg_rows_per_file = static_cast<int32_t>(total_rows / node->FileNames().size());
+    (void)builder->SetShuffleSize(std::max(avg_rows_per_file * 4, 10000));
+    std::shared_ptr<ShuffleOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  *modified = !(tf_readers.empty() && text_files.empty() && clues.empty());
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
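Aside (not part of the change): the shuffle window used above is a heuristic, four times the average rows per file with a floor of 10000, the same numbers the deprecated Python-side _alter_node used. A standalone worked example with hypothetical row counts:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Shuffle window heuristic used by GlobalShufflePass: max(4 * avg rows per file, 10000).
    int32_t ShuffleSize(int64_t total_rows, int64_t num_files) {
      int32_t avg_rows_per_file = static_cast<int32_t>(total_rows / num_files);
      return std::max(avg_rows_per_file * 4, 10000);
    }

    int main() {
      std::cout << ShuffleSize(60000, 3) << std::endl;  // avg 20000 rows/file -> window 80000
      std::cout << ShuffleSize(900, 3) << std::endl;    // avg 300 rows/file  -> floored to 10000
      return 0;
    }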
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Global Shuffle Pass will insert a ShuffleOp when a leaf node requires a global shuffle.
+// Example:
+// Input Tree:  TFReader(GLOBAL_SHUFFLE) -> Batch
+// Output Tree: TFReader -> Shuffle -> Batch
+class GlobalShufflePass : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <vector>
+
+#include "dataset/engine/opt/pre/map_column_reorder.h"
+#include "dataset/engine/execution_tree.h"
+#include "dataset/engine/datasetops/map_op.h"
+#include "dataset/engine/datasetops/project_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) {
+  std::vector<std::shared_ptr<MapOp>> to_process;
+
+  // Pass 1, search for all MapOps that carry a columns_order
+  for (auto &op : *tree) {
+    if (auto map_op = std::dynamic_pointer_cast<MapOp>(op.shared_from_this())) {
+      if (!map_op->ColumnsOrder().empty()) {
+        to_process.push_back(map_op);
+      }
+    }
+  }
+
+  // Pass 2, insert a ProjectOp above every such MapOp
+  for (auto node : to_process) {
+    std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(node->ColumnsOrder());
+    std::shared_ptr<ProjectOp> op;
+    RETURN_IF_NOT_OK(builder->Build(&op));
+    RETURN_IF_NOT_OK(tree->AssociateNode(op));
+    RETURN_IF_NOT_OK(node->InsertAsParent(op));
+  }
+
+  *modified = !to_process.empty();
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
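Aside (not part of the change): both pre-passes share the same matching pattern, walk every op in the tree and dynamic_pointer_cast to the concrete type of interest. A self-contained sketch of that pattern, with Op and MapLikeOp as stand-ins:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Stand-ins for DatasetOp and MapOp (illustrative only).
    struct Op { virtual ~Op() = default; };
    struct MapLikeOp : Op {
      std::vector<std::string> columns_order;  // mirrors MapOp::ColumnsOrder()
    };

    int main() {
      std::vector<std::shared_ptr<Op>> tree;
      tree.push_back(std::make_shared<Op>());
      auto map_op = std::make_shared<MapLikeOp>();
      map_op->columns_order = {"col1", "out"};
      tree.push_back(map_op);

      // Pass 1 pattern: collect only the ops that need rewriting.
      std::vector<std::shared_ptr<MapLikeOp>> to_process;
      for (auto &op : tree) {
        if (auto ptr = std::dynamic_pointer_cast<MapLikeOp>(op)) {
          if (!ptr->columns_order.empty()) {
            to_process.push_back(ptr);
          }
        }
      }
      std::cout << "ops to rewrite: " << to_process.size() << std::endl;  // 1
      return 0;
    }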
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
+
+#include <memory>
+#include "dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Map Column Reorder Pass will insert a ProjectOp when a MapOp requires a full output column reorder.
+// Example:
+// Input Tree:  TFReader -> MapOp(with col_order) -> Batch
+// Output Tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch
+class MapColumnReorder : public TreePass {
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H
@@ -1892,6 +1892,7 @@ class MapDataset(DatasetOp):
         args["input_columns"] = self.input_columns
         args["operations"] = self.operations
         args["output_columns"] = self.output_columns
+        args["columns_order"] = self.columns_order
         return args
 
     def get_dataset_size(self):
@@ -3281,6 +3282,7 @@ class TFRecordDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
@@ -4589,6 +4591,7 @@ class CLUEDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
@@ -4679,6 +4682,7 @@ class TextFileDataset(SourceDataset):
         args["num_samples"] = self.num_samples
         if self.shuffle_files is not None:
             args["shuffle_files"] = self.shuffle_files
+        args["shuffle_global"] = (self.shuffle_level == Shuffle.GLOBAL)
         args["shuffle"] = self.shuffle_level
         args["num_shards"] = self.num_shards
         args["shard_id"] = self.shard_id
@@ -49,33 +49,13 @@ def alter_tree(node):
 
 def _alter_node(node):
-    """Performing some alteration to a dataset node. A common alteration is to insert a node."""
-    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
-            and node.shuffle_level == de.Shuffle.GLOBAL:
-        # Remove the connection between the parent's node to the current node because we are inserting a node.
-        if node.output:
-            node.output.pop()
-        # Perform a fast scan for average rows per file
-        if isinstance(node, de.TFRecordDataset):
-            avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
-        else:
-            avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
-        # Shuffle between 4 files with a minimum size of 10000 rows
-        new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
-        return new_shuffle
+    """DEPRECATED"""
+    # Please check ccsrc/dataset/engine/opt for tree transformation.
 
     if isinstance(node, de.MapDataset):
         if node.python_multiprocessing:
             # Bootstrap can only be performed on a copy of the original dataset node.
             # Bootstrap on original dataset node will make all iterators share the same process pool
             node.iterator_bootstrap()
-        if node.columns_order is not None:
-            # Remove the connection between the parent's node to the current node because we are inserting a node.
-            if node.output:
-                node.output.pop()
-            return node.project(node.columns_order)
     return node
@@ -51,6 +51,7 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
   ASSERT_NE(my_tfreader_op, nullptr);
   parent_op->AddChild(std::move(my_tfreader_op));
   MS_LOG(INFO) << parent_op;
+  my_tree->AssignRoot(parent_op);
   my_tree->Prepare();
 
   RepeatOp RepeatOpOp();
@@ -0,0 +1,90 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+
+import mindspore.dataset as ds
+
+
+def test_map_reorder_pass_0():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]))
+
+    # Generator -> Map
+    data0 = ds.GeneratorDataset(generator_mc, ["col0", "col1"])
+    data0 = data0.map(input_columns="col0", output_columns="out", columns_order=["col1", "out"],
+                      operations=(lambda x: x))
+
+    for item in data0.create_tuple_iterator():  # each item is a list of ndarrays
+        assert item == [np.array(1), np.array(0)]
+
+
+def test_map_reorder_pass_1():
+    def generator_mc(maxid=1):
+        for _ in range(maxid):
+            yield (np.array([0]), np.array([1]), np.array([2]))
+
+    # Three maps and a zip
+    data0 = ds.GeneratorDataset(generator_mc, ["a0", "a1", "a2"])
+    data0 = data0.map(input_columns="a0", columns_order=["a2", "a1", "a0"], operations=(lambda x: x))
+    data1 = ds.GeneratorDataset(generator_mc, ["b0", "b1", "b2"])
+    data1 = data1.map(input_columns="b0", columns_order=["b1", "b2", "b0"], operations=(lambda x: x))
+    data2 = ds.zip((data0, data1))
+    data2 = data2.map(input_columns="a0", columns_order=["b2", "a2", "b1", "a1", "b0", "a0"], operations=(lambda x: x))
+
+    for item in data2.create_tuple_iterator():
+        assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)]
+
+
+def test_global_shuffle_pass():
+    FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
+    SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
+
+    ds.config.set_seed(1)
+    data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+    ds.config.set_seed(1)
+    DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
+    data1 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+    ds.config.set_seed(1)
+    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
+    data1 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.GLOBAL)
+    data2 = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=ds.Shuffle.FILES)
+    data2 = data2.shuffle(10000)
+
+    for d1, d2 in zip(data1, data2):
+        for t1, t2 in zip(d1, d2):
+            assert np.array_equal(t1, t2)
+
+
+if __name__ == "__main__":
+    test_map_reorder_pass_0()
+    test_map_reorder_pass_1()
+    test_global_shuffle_pass()