| @@ -128,7 +128,7 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) { | |||||
| Status ConcatOp::PrepareNodePostAction() { | Status ConcatOp::PrepareNodePostAction() { | ||||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | ||||
| tree_->AddToRepeatStack(shared_from_this()); | |||||
| tree_->AddToEOEOpStack(shared_from_this()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -18,23 +18,26 @@ | |||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <regex> | |||||
| #include <utility> | #include <utility> | ||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "dataset/engine/execution_tree.h" | #include "dataset/engine/execution_tree.h" | ||||
| #include "dataset/engine/datasetops/device_queue_op.h" | #include "dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "dataset/engine/datasetops/source/sampler/sampler.h" | |||||
| #include "dataset/engine/data_buffer.h" | #include "dataset/engine/data_buffer.h" | ||||
| #include "dataset/engine/db_connector.h" | #include "dataset/engine/db_connector.h" | ||||
| #include "dataset/engine/opt/pass.h" | #include "dataset/engine/opt/pass.h" | ||||
| #include "utils/system/crc32c.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| DatasetOp::DatasetOp(int32_t op_connector_size) | |||||
| DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| : oc_queue_size_(op_connector_size), | : oc_queue_size_(op_connector_size), | ||||
| sampler_(sampler), | |||||
| operator_id_(kInvalidOperatorId), | operator_id_(kInvalidOperatorId), | ||||
| tree_(nullptr), | tree_(nullptr), | ||||
| state_(OpState::kDeOpIdle), | state_(OpState::kDeOpIdle), | ||||
| @@ -150,6 +153,9 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex | out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex | ||||
| << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); | << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); | ||||
| if (sampler_) { | |||||
| sampler_->Print(out, show_all); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -222,11 +228,10 @@ Status DatasetOp::PrepareNodePreAction() { | |||||
| Status DatasetOp::PrepareNodePostAction() { | Status DatasetOp::PrepareNodePostAction() { | ||||
| // If this op does not have any children and it is in a repeat path of the tree... | // If this op does not have any children and it is in a repeat path of the tree... | ||||
| if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { | if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { | ||||
| // push ourselves onto the tree repeat stack. Later, the repeat operator | |||||
| // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator | |||||
| // above us will consume them. | // above us will consume them. | ||||
| tree_->AddToRepeatStack(shared_from_this()); | |||||
| tree_->AddToEOEOpStack(shared_from_this()); | |||||
| } | } | ||||
| // Creating Connector object for each op. | // Creating Connector object for each op. | ||||
| // The consumer of the root node is assumed to be one thread. | // The consumer of the root node is assumed to be one thread. | ||||
| // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. | // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. | ||||
| @@ -289,5 +294,56 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) { | |||||
| // This method will only be called if its derived class does not implement one. | // This method will only be called if its derived class does not implement one. | ||||
| return p->RunOnNode(shared_from_this(), modified); | return p->RunOnNode(shared_from_this(), modified); | ||||
| } | } | ||||
| // A helper function with some common code that leaf nodes can use during | |||||
| // prepare phase for checking if they need to assign a sampler to the cache. | |||||
| // @return - Status | |||||
| Status DatasetOp::SaveSamplerForCache(bool random_access_op) { | |||||
| // If we are a descendant under a cache op and we have a sampler, then save this sampler | |||||
| // to a stack so that the cache can pick it up during it's processing above us. | |||||
| if (sampler_) { | |||||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { | |||||
| // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is | |||||
| // useless to a random data op. It was only being used as a temporary holding until the cache can | |||||
| // be created | |||||
| tree_->AddToSamplerStack(sampler_); | |||||
| MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling."; | |||||
| } else if (!random_access_op) { | |||||
| // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf. | |||||
| // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into. | |||||
| return Status( | |||||
| StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||||
| "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree"); | |||||
| } | |||||
| } | |||||
| if (!random_access_op) { | |||||
| // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache | |||||
| // we can remove it now from the base. | |||||
| sampler_.reset(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) { | |||||
| std::stringstream ss; | |||||
| op->tree_->Print(ss, op); | |||||
| std::string ss_str = ss.str(); | |||||
| // Filter out the Operator control flags field when generating the check sum | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), ""); | |||||
| // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); | |||||
| // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same | |||||
| // cache_client later. So we filter out these two fields to allow cache sharing. | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), ""); | |||||
| ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), ""); | |||||
| uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); | |||||
| return cache_crc; | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
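The normalization step inside the new GenerateCRC can be exercised on its own. The following is a minimal stand-alone sketch, not the MindSpore implementation: std::hash stands in for system::Crc32c::GetMaskCrc32cValue, and the two pipeline dumps are invented for illustration.

    #include <functional>
    #include <iostream>
    #include <regex>
    #include <string>

    // Approximates the field filtering in DatasetOp::GenerateCRC: volatile fields
    // are stripped so that two runs of the same pipeline (e.g. on different
    // devices) hash to the same value.
    size_t NormalizedPipelineHash(std::string dump) {
      for (const char *pattern : {"Operator control flags.*\n", "Device id.*\n", "device_id.*\n",
                                  "Cache crc.*\n", "Server cache id.*\n"}) {
        dump = std::regex_replace(dump, std::regex(pattern), "");
      }
      return std::hash<std::string>{}(dump);  // stand-in for the Crc32c call
    }

    int main() {
      std::string run_a = "MapOp\nDevice id: 0\nConnector queue size : 16\n";
      std::string run_b = "MapOp\nDevice id: 3\nConnector queue size : 16\n";
      // Prints 1: the same pipeline on two devices normalizes to the same hash.
      std::cout << (NormalizedPipelineHash(run_a) == NormalizedPipelineHash(run_b)) << '\n';
    }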
| @@ -34,6 +34,8 @@ class DataBuffer; | |||||
| class NodePass; | class NodePass; | ||||
| class Sampler; | |||||
| // The base class DatasetOp is the main tree node. It is an abstract class, so | // The base class DatasetOp is the main tree node. It is an abstract class, so | ||||
| // the actual implementation of the operators will be derived from here. | // the actual implementation of the operators will be derived from here. | ||||
| class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | ||||
| @@ -55,7 +57,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // Constructor | // Constructor | ||||
| // @param op_connector_size - The size for the output connector of this operator. | // @param op_connector_size - The size for the output connector of this operator. | ||||
| explicit DatasetOp(int32_t op_connector_size); | |||||
| // @param sampler - The sampler for the op | |||||
| explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler); | |||||
| // Destructor | // Destructor | ||||
| virtual ~DatasetOp() { tree_ = nullptr; } | virtual ~DatasetOp() { tree_ = nullptr; } | ||||
| @@ -204,6 +207,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @return Sets the control flags | // @return Sets the control flags | ||||
| void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } | void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } | ||||
| // Setter function | |||||
| // @return Sets the control flags | |||||
| void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } | |||||
| // Register the internal worker connectors. No op unless it is a parallel op | // Register the internal worker connectors. No op unless it is a parallel op | ||||
| // @return Status | // @return Status | ||||
| virtual Status RegisterWorkerConnectors() { return Status::OK(); } | virtual Status RegisterWorkerConnectors() { return Status::OK(); } | ||||
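The new ClearControlFlag mirrors set_control_flag. BitSet and BitClear themselves are not part of this diff; assuming they are the usual mask helpers, the semantics reduce to the following sketch (the flag value is hypothetical):

    #include <cstdint>
    #include <iostream>

    // Assumed implementations of the BitSet/BitClear helpers used by
    // set_control_flag/ClearControlFlag: OR a mask in, AND it back out.
    inline void BitSet(uint64_t *bits, uint64_t flag) { *bits |= flag; }
    inline void BitClear(uint64_t *bits, uint64_t flag) { *bits &= ~flag; }

    int main() {
      constexpr uint64_t kDeOpRepeated = 1u << 0;  // hypothetical flag value
      uint64_t op_ctrl_flags = 0;
      BitSet(&op_ctrl_flags, kDeOpRepeated);    // what set_control_flag does
      BitClear(&op_ctrl_flags, kDeOpRepeated);  // what ClearControlFlag does
      std::cout << op_ctrl_flags << '\n';       // prints 0 again
    }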
| @@ -271,6 +278,13 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @return Pointer to the ExecutionTree the current op belongs to, no ownership | // @return Pointer to the ExecutionTree the current op belongs to, no ownership | ||||
| ExecutionTree *Tree() { return tree_; } | ExecutionTree *Tree() { return tree_; } | ||||
| // Getter for the sampler | |||||
| // @return Shared pointer to the sampler (may return nullptr) | |||||
| std::shared_ptr<Sampler> sampler() { return sampler_; } | |||||
| // Computes a CRC value for the operator | |||||
| static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op); | |||||
| protected: | protected: | ||||
| // Adds a parent operator to this operator | // Adds a parent operator to this operator | ||||
| // @notes External callers do not have access to this function. | // @notes External callers do not have access to this function. | ||||
| @@ -289,8 +303,15 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| // @return - Status | // @return - Status | ||||
| virtual Status ComputeColMap(); | virtual Status ComputeColMap(); | ||||
| // A helper function with some common code that leaf nodes can use during | |||||
| // prepare phase for checking if they need to assign a sampler to the cache. | |||||
| // @param random_access_op - indicate if this is a mappable random access leaf or not | |||||
| // @return - Status | |||||
| Status SaveSamplerForCache(bool random_access_op); | |||||
| std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes | ||||
| std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | std::vector<DatasetOp *> parent_; // Parent nodes. No ownership | ||||
| std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler | |||||
| int32_t oc_queue_size_; // Capacity for each out_connector_ | int32_t oc_queue_size_; // Capacity for each out_connector_ | ||||
| int32_t operator_id_; // Generated id for the node | int32_t operator_id_; // Generated id for the node | ||||
| ExecutionTree *tree_; // Back pointer to our tree. | ExecutionTree *tree_; // Back pointer to our tree. | ||||
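The helper declared here (implemented in the dataset_op.cc hunk above) has three outcomes. A stand-alone model, with a plain std::stack in place of the ExecutionTree sampler stack and a string in place of Status:

    #include <iostream>
    #include <memory>
    #include <stack>
    #include <string>

    struct Sampler {};

    // Toy model of DatasetOp::SaveSamplerForCache: (1) in a cache-prep path the
    // sampler is handed up a stack for the cache op to adopt; (2) a non-mappable
    // leaf holding a sampler outside a cache path is an error; (3) a mappable
    // (random access) leaf simply keeps its sampler.
    std::string SaveSamplerForCache(bool in_cache_path, bool random_access_op,
                                    std::shared_ptr<Sampler> *sampler,
                                    std::stack<std::shared_ptr<Sampler>> *sampler_stack) {
      if (*sampler) {
        if (in_cache_path) {
          sampler_stack->push(*sampler);
        } else if (!random_access_op) {
          return "error: non-mappable leaf has a sampler but no cache above it";
        }
      }
      if (!random_access_op) sampler->reset();  // the leaf itself never samples
      return "ok";
    }

    int main() {
      std::stack<std::shared_ptr<Sampler>> stack;
      auto s = std::make_shared<Sampler>();
      std::cout << SaveSamplerForCache(true, false, &s, &stack) << '\n';   // ok, pushed
      auto t = std::make_shared<Sampler>();
      std::cout << SaveSamplerForCache(false, false, &t, &stack) << '\n';  // error
    }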
| @@ -100,7 +100,7 @@ void MapOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| out << "\n TensorOps:"; | out << "\n TensorOps:"; | ||||
| for (size_t i = 0; i < tfuncs_.size(); i++) { | for (size_t i = 0; i < tfuncs_.size(); i++) { | ||||
| out << " " << tfuncs_[i]; | |||||
| out << " " << *(tfuncs_[i].get()); | |||||
| } | } | ||||
| out << "\n\n"; | out << "\n\n"; | ||||
| } | } | ||||
| @@ -26,8 +26,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size) | |||||
| : DatasetOp(op_connector_size), | |||||
| ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| : DatasetOp(op_connector_size, sampler), | |||||
| num_workers_(num_workers), | num_workers_(num_workers), | ||||
| num_producers_(num_workers), | num_producers_(num_workers), | ||||
| worker_connector_size_(1), | worker_connector_size_(1), | ||||
| @@ -38,7 +38,8 @@ class ParallelOp : public DatasetOp { | |||||
| // Constructor | // Constructor | ||||
| // @param num_workers | // @param num_workers | ||||
| // @param op_connector_size - size of the output connector for this operator | // @param op_connector_size - size of the output connector for this operator | ||||
| ParallelOp(int32_t num_workers, int32_t op_connector_size); | |||||
| // @param sampler - The sampler for the op | |||||
| ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||||
| // Destructor | // Destructor | ||||
| ~ParallelOp() = default; | ~ParallelOp() = default; | ||||
| @@ -20,7 +20,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor | // Constructor | ||||
| PipelineOp::PipelineOp(int32_t op_connector_size) : DatasetOp(op_connector_size) {} | |||||
| PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler) | |||||
| : DatasetOp(op_connector_size, sampler) {} | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void PipelineOp::Print(std::ostream &out, bool show_all) const { | void PipelineOp::Print(std::ostream &out, bool show_all) const { | ||||
| @@ -32,7 +32,8 @@ class PipelineOp : public DatasetOp { | |||||
| // Constructor | // Constructor | ||||
| // @param op_connector_size - size of the output connector | // @param op_connector_size - size of the output connector | ||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| explicit PipelineOp(int32_t op_connector_size); | |||||
| // @param sampler - The sampler for the op | |||||
| explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr); | |||||
| // Destructor | // Destructor | ||||
| ~PipelineOp() = default; | ~PipelineOp() = default; | ||||
| @@ -82,14 +82,14 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { | |||||
| Status RepeatOp::PrepareNodePostAction() { | Status RepeatOp::PrepareNodePostAction() { | ||||
| // Run any common code from super class first before adding our own specific logic | // Run any common code from super class first before adding our own specific logic | ||||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | ||||
| std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack(); | |||||
| std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack(); | |||||
| while (leaf_op != nullptr) { | while (leaf_op != nullptr) { | ||||
| // Track the leaf operators that are under this repeat op. | // Track the leaf operators that are under this repeat op. | ||||
| eoe_ops_.push_back(leaf_op); | eoe_ops_.push_back(leaf_op); | ||||
| leaf_op = tree_->PopFromRepeatStack(); | |||||
| leaf_op = tree_->PopFromEOEOpStack(); | |||||
| } | } | ||||
| // Push ourselves to the stack in case one of our ascendants is repeat too. | // Push ourselves to the stack in case one of our ascendants is repeat too. | ||||
| tree_->AddToRepeatStack(shared_from_this()); | |||||
| tree_->AddToEOEOpStack(shared_from_this()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -70,13 +70,12 @@ Status CelebAOp::Builder::SanityCheck() { | |||||
| CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | ||||
| bool decode, const std::string &dataset_type, const std::set<std::string> &exts, | bool decode, const std::string &dataset_type, const std::set<std::string> &exts, | ||||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler) | std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_workers, queue_size), | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(dir), | folder_path_(dir), | ||||
| decode_(decode), | decode_(decode), | ||||
| extensions_(exts), | extensions_(exts), | ||||
| data_schema_(std::move(schema)), | data_schema_(std::move(schema)), | ||||
| sampler_(std::move(sampler)), | |||||
| num_rows_in_attr_file_(0), | num_rows_in_attr_file_(0), | ||||
| dataset_type_(dataset_type) { | dataset_type_(dataset_type) { | ||||
| attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size); | attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size); | ||||
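This CelebAOp change is the template for every mappable leaf that follows (Cifar, ImageFolder, Manifest, Mnist, VOC): the duplicate sampler_ member disappears and the shared_ptr is forwarded to the base class, which now owns it. A minimal stand-alone model of the hoisting:

    #include <iostream>
    #include <memory>
    #include <utility>

    struct Sampler {};

    // DatasetOp stands in as Base: sampler_ now lives here.
    struct Base {
      explicit Base(std::shared_ptr<Sampler> sampler) : sampler_(std::move(sampler)) {}
      std::shared_ptr<Sampler> sampler_;
    };

    struct LeafOp : Base {
      // Before: Base(...), plus a second sampler_ member initialized from the
      // same argument. After: a single owner; std::move avoids a refcount bump.
      explicit LeafOp(std::shared_ptr<Sampler> sampler) : Base(std::move(sampler)) {}
    };

    int main() {
      LeafOp op(std::make_shared<Sampler>());
      std::cout << (op.sampler_ != nullptr) << '\n';  // prints 1
    }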
| @@ -221,7 +221,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||||
| bool decode_; | bool decode_; | ||||
| std::set<std::string> extensions_; // extensions allowed | std::set<std::string> extensions_; // extensions allowed | ||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; | std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; | ||||
| int64_t num_rows_in_attr_file_; // rows number specified in attr file | int64_t num_rows_in_attr_file_; // rows number specified in attr file | ||||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | ||||
| @@ -79,12 +79,11 @@ Status CifarOp::Builder::SanityCheck() { | |||||
| CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, | CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, | ||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_works, queue_size), | |||||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | |||||
| cifar_type_(type), | cifar_type_(type), | ||||
| rows_per_buffer_(rows_per_buf), | rows_per_buffer_(rows_per_buf), | ||||
| folder_path_(file_dir), | folder_path_(file_dir), | ||||
| data_schema_(std::move(data_schema)), | data_schema_(std::move(data_schema)), | ||||
| sampler_(std::move(sampler)), | |||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0) { | buf_cnt_(0) { | ||||
| constexpr uint64_t kUtilQueueSize = 512; | constexpr uint64_t kUtilQueueSize = 512; | ||||
| @@ -216,7 +216,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| std::string folder_path_; | std::string folder_path_; | ||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| int64_t row_cnt_; | int64_t row_cnt_; | ||||
| int64_t buf_cnt_; | int64_t buf_cnt_; | ||||
| @@ -65,7 +65,7 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str | |||||
| bool recursive, bool do_decode, const std::set<std::string> &exts, | bool recursive, bool do_decode, const std::set<std::string> &exts, | ||||
| const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler) | std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_wkrs, queue_size), | |||||
| : ParallelOp(num_wkrs, queue_size, std::move(sampler)), | |||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| folder_path_(file_dir), | folder_path_(file_dir), | ||||
| recursive_(recursive), | recursive_(recursive), | ||||
| @@ -73,7 +73,6 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str | |||||
| extensions_(exts), | extensions_(exts), | ||||
| class_index_(map), | class_index_(map), | ||||
| data_schema_(std::move(data_schema)), | data_schema_(std::move(data_schema)), | ||||
| sampler_(std::move(sampler)), | |||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| sampler_ind_(0), | sampler_ind_(0), | ||||
| @@ -259,7 +259,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||||
| std::set<std::string> extensions_; // extensions allowed | std::set<std::string> extensions_; // extensions allowed | ||||
| std::map<std::string, int32_t> class_index_; | std::map<std::string, int32_t> class_index_; | ||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| int64_t row_cnt_; | int64_t row_cnt_; | ||||
| int64_t buf_cnt_; | int64_t buf_cnt_; | ||||
| int64_t sampler_ind_; | int64_t sampler_ind_; | ||||
| @@ -64,7 +64,7 @@ Status ManifestOp::Builder::SanityCheck() { | |||||
| ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | ||||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | ||||
| std::shared_ptr<Sampler> sampler, std::string usage) | std::shared_ptr<Sampler> sampler, std::string usage) | ||||
| : ParallelOp(num_works, queue_size), | |||||
| : ParallelOp(num_works, queue_size, std::move(sampler)), | |||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| io_block_pushed_(0), | io_block_pushed_(0), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| @@ -72,7 +72,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f | |||||
| data_schema_(std::move(data_schema)), | data_schema_(std::move(data_schema)), | ||||
| file_(file), | file_(file), | ||||
| class_index_(class_index), | class_index_(class_index), | ||||
| sampler_(std::move(sampler)), | |||||
| decode_(decode), | decode_(decode), | ||||
| usage_(usage), | usage_(usage), | ||||
| buf_cnt_(0) { | buf_cnt_(0) { | ||||
| @@ -230,7 +230,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::string file_; // file that store the information of images | std::string file_; // file that store the information of images | ||||
| std::map<std::string, int32_t> class_index_; | std::map<std::string, int32_t> class_index_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| bool decode_; | bool decode_; | ||||
| std::string usage_; | std::string usage_; | ||||
| int64_t buf_cnt_; | int64_t buf_cnt_; | ||||
| @@ -66,12 +66,11 @@ Status MnistOp::Builder::SanityCheck() { | |||||
| MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | ||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_workers, queue_size), | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| folder_path_(folder_path), | folder_path_(folder_path), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| sampler_(std::move(sampler)), | |||||
| data_schema_(std::move(data_schema)) { | data_schema_(std::move(data_schema)) { | ||||
| io_block_queues_.Init(num_workers, queue_size); | io_block_queues_.Init(num_workers, queue_size); | ||||
| } | } | ||||
| @@ -235,7 +235,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| WaitPost wp_; | WaitPost wp_; | ||||
| std::string folder_path_; // directory of image folder | std::string folder_path_; // directory of image folder | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::vector<MnistLabelPair> image_label_pairs_; | std::vector<MnistLabelPair> image_label_pairs_; | ||||
| std::vector<std::string> image_names_; | std::vector<std::string> image_names_; | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "dataset/core/config_manager.h" | #include "dataset/core/config_manager.h" | ||||
| #include "dataset/util/random.h" | #include "dataset/util/random.h" | ||||
| #include "dataset/util/wait_post.h" | #include "dataset/util/wait_post.h" | ||||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -30,7 +31,8 @@ RandomDataOp::Builder::Builder() | |||||
| builder_num_workers_(0), | builder_num_workers_(0), | ||||
| builder_op_connector_size_(0), | builder_op_connector_size_(0), | ||||
| builder_rows_per_buffer_(0), | builder_rows_per_buffer_(0), | ||||
| builder_total_rows_(0) { | |||||
| builder_total_rows_(0), | |||||
| builder_sampler_(nullptr) { | |||||
| // Some arguments to the RandomDataOp have a default argument that is taken from the config. | // Some arguments to the RandomDataOp have a default argument that is taken from the config. | ||||
| // The user may override these defaults by using the builder set methods. | // The user may override these defaults by using the builder set methods. | ||||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | ||||
| @@ -43,8 +45,9 @@ RandomDataOp::Builder::Builder() | |||||
| Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) { | Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) { | ||||
| RETURN_IF_NOT_OK(SanityCheck()); | RETURN_IF_NOT_OK(SanityCheck()); | ||||
| *out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, | |||||
| builder_total_rows_, std::move(builder_data_schema_)); | |||||
| *out_op = | |||||
| std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, | |||||
| builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); | |||||
| // If the user did not provide a schema, then we will ask the op to generate a pseudo-random | // If the user did not provide a schema, then we will ask the op to generate a pseudo-random | ||||
| // schema. | // schema. | ||||
| @@ -66,8 +69,8 @@ Status RandomDataOp::Builder::SanityCheck() const { | |||||
| // Constructor for RandomDataOp | // Constructor for RandomDataOp | ||||
| RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | ||||
| std::unique_ptr<DataSchema> data_schema) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| buffer_id_(0), | buffer_id_(0), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| total_rows_(total_rows), | total_rows_(total_rows), | ||||
| @@ -124,7 +127,7 @@ Status RandomDataOp::GenerateSchema() { | |||||
| // For each column: | // For each column: | ||||
| // - choose a datatype | // - choose a datatype | ||||
| // - generate a shape that randomly chooses the number of dimensions and the dimension values. | // - generate a shape that randomly chooses the number of dimensions and the dimension values. | ||||
| DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(0, DataType::NUM_OF_TYPES - 2)); | |||||
| DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - 2)); | |||||
| int32_t rank = GenRandomInt(1, kMaxRank); | int32_t rank = GenRandomInt(1, kMaxRank); | ||||
| std::vector<dsize_t> dims; | std::vector<dsize_t> dims; | ||||
| for (int32_t d = 0; d < rank; d++) { | for (int32_t d = 0; d < rank; d++) { | ||||
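A plausible reading of moving the lower bound from 0 to 1: index 0 of the type enum is an unknown/sentinel value that GenerateSchema must never pick (and the upper bound of NUM_OF_TYPES - 2 already skips the last entry). The enum layout below is an assumption with invented enumerator names, sketched only to show the arithmetic:

    #include <iostream>

    // Assumed layout: 0 is a sentinel, the final enumerator is a count.
    enum Type : int { DE_UNKNOWN = 0, DE_BOOL, DE_INT8, DE_FLOAT32, DE_STRING, NUM_OF_TYPES };

    int main() {
      // An inclusive draw from [1, NUM_OF_TYPES - 2] covers DE_BOOL..DE_FLOAT32
      // and can no longer land on DE_UNKNOWN.
      std::cout << "valid type ids: 1.." << NUM_OF_TYPES - 2 << '\n';
    }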
| @@ -412,5 +415,15 @@ Status RandomDataOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| Status RandomDataOp::PrepareNodePostAction() { | |||||
| // Run common code from super class before adding RandomDataOp specific handling | |||||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||||
| // Specific handling for this op, we need to do cache op work to assign the sampler to the cache. | |||||
| RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false)); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,7 @@ class RandomDataOp : public ParallelOp { | |||||
| // Some constants to provide limits to random generation. | // Some constants to provide limits to random generation. | ||||
| static constexpr int32_t kMaxNumColumns = 4; | static constexpr int32_t kMaxNumColumns = 4; | ||||
| static constexpr int32_t kMaxRank = 4; | static constexpr int32_t kMaxRank = 4; | ||||
| static constexpr int32_t kMaxDimValue = 2048; | |||||
| static constexpr int32_t kMaxDimValue = 32; | |||||
| static constexpr int32_t kMaxTotalRows = 1024; | static constexpr int32_t kMaxTotalRows = 1024; | ||||
| // A nested builder class to aid in the construction of a RandomDataOp | // A nested builder class to aid in the construction of a RandomDataOp | ||||
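The effect of shrinking kMaxDimValue is easiest to see as a bound on element count; a quick stand-alone check of the arithmetic:

    #include <cstdint>
    #include <iostream>

    int main() {
      // With kMaxRank = 4, the largest possible random tensor shape drops from
      // 2048^4 elements to 32^4 elements.
      const int64_t kMaxRank = 4;
      int64_t before = 1, after = 1;
      for (int64_t d = 0; d < kMaxRank; ++d) {
        before *= 2048;
        after *= 32;
      }
      std::cout << before << " -> " << after << '\n';  // 17592186044416 -> 1048576
    }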
| @@ -117,6 +117,14 @@ class RandomDataOp : public ParallelOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method | |||||
| // @param std::shared_ptr<Sampler> sampler | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | |||||
| return *this; | |||||
| } | |||||
| private: | private: | ||||
| /** | /** | ||||
| * Check if the required parameters are set by the builder. | * Check if the required parameters are set by the builder. | ||||
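The same builder extension recurs in TextFileOp and TFReaderOp below. A condensed stand-alone model of the pattern, a defaulted-to-null sampler member plus a fluent setter that moves the shared_ptr in:

    #include <iostream>
    #include <memory>
    #include <utility>

    struct Sampler {};

    class Builder {
     public:
      Builder() : builder_sampler_(nullptr) {}
      Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
        builder_sampler_ = std::move(sampler);
        return *this;  // enables chaining, matching the other setters
      }
      bool HasSampler() const { return builder_sampler_ != nullptr; }

     private:
      std::shared_ptr<Sampler> builder_sampler_;
    };

    int main() {
      std::cout << Builder().SetSampler(std::make_shared<Sampler>()).HasSampler() << '\n';  // 1
    }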
| @@ -125,6 +133,7 @@ class RandomDataOp : public ParallelOp { | |||||
| Status SanityCheck() const; | Status SanityCheck() const; | ||||
| std::unique_ptr<DataSchema> builder_data_schema_; | std::unique_ptr<DataSchema> builder_data_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| int64_t builder_rows_per_buffer_; | int64_t builder_rows_per_buffer_; | ||||
| @@ -139,10 +148,11 @@ class RandomDataOp : public ParallelOp { | |||||
| * @param rows_per_buffer - The number of rows in each DataBuffer | * @param rows_per_buffer - The number of rows in each DataBuffer | ||||
| * @param data_schema - A user-provided schema | * @param data_schema - A user-provided schema | ||||
| * @param total_rows - The total number of rows in the dataset | * @param total_rows - The total number of rows in the dataset | ||||
| * @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | |||||
| * @return Builder - The modified builder by reference | * @return Builder - The modified builder by reference | ||||
| */ | */ | ||||
| RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, | ||||
| std::unique_ptr<DataSchema> data_schema); | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| /** | /** | ||||
| * Destructor | * Destructor | ||||
| @@ -193,6 +203,12 @@ class RandomDataOp : public ParallelOp { | |||||
| // @return Name of the current Op | // @return Name of the current Op | ||||
| std::string Name() const override { return "RandomDataOp"; } | std::string Name() const override { return "RandomDataOp"; } | ||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| // @notes Derived versions of this function should always call it's superclass version first | |||||
| // before providing their own implementations. | |||||
| Status PrepareNodePostAction() override; | |||||
| private: | private: | ||||
| /** | /** | ||||
| * The entry point code for when workers are launched | * The entry point code for when workers are launched | ||||
| @@ -107,12 +107,11 @@ Status DistributedSampler::ResetSampler() { | |||||
| } | } | ||||
| void DistributedSampler::Print(std::ostream &out, bool show_all) const { | void DistributedSampler::Print(std::ostream &out, bool show_all) const { | ||||
| out << "(sampler): DistributedSampler\n"; | |||||
| out << "\nSampler: DistributedSampler"; | |||||
| if (show_all) { | if (show_all) { | ||||
| out << "seed_: " << seed_ << '\n'; | |||||
| out << "device_id_: " << device_id_ << '\n'; | |||||
| out << "num_devices_: " << num_devices_ << '\n'; | |||||
| out << "shuffle_: " << shuffle_ << '\n'; | |||||
| Sampler::Print(out, show_all); | |||||
| out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | |||||
| << "\nshuffle: " << shuffle_; | |||||
| } | } | ||||
| } | } | ||||
| @@ -113,5 +113,13 @@ Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PKSampler::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PKSampler"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -56,6 +56,11 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| // Printer for debugging purposes. | |||||
| // @param out - output stream to write to | |||||
| // @param show_all - bool to show detailed vs summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| bool shuffle_; | bool shuffle_; | ||||
| uint32_t seed_; | uint32_t seed_; | ||||
| @@ -103,5 +103,14 @@ Status PythonSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PythonSampler::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PythonSampler"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -50,6 +50,11 @@ class PythonSampler : public Sampler { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | |||||
| // @param out - output stream to write to | |||||
| // @param show_all - bool to show detailed vs summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | ||||
| @@ -113,13 +113,12 @@ Status RandomSampler::ResetSampler() { | |||||
| } | } | ||||
| void RandomSampler::Print(std::ostream &out, bool show_all) const { | void RandomSampler::Print(std::ostream &out, bool show_all) const { | ||||
| out << "(sampler): RandomSampler\n"; | |||||
| out << "\nSampler: RandomSampler"; | |||||
| if (show_all) { | if (show_all) { | ||||
| out << "num_samples_: " << num_samples_ << '\n'; | |||||
| out << "next_id_: " << next_id_ << '\n'; | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | } | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -80,11 +80,12 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t | |||||
| } | } | ||||
| void Sampler::Print(std::ostream &out, bool show_all) const { | void Sampler::Print(std::ostream &out, bool show_all) const { | ||||
| out << "(sampler): base\n"; | |||||
| // Sampler printing is usually only called in the show_all mode. | |||||
| // Derived classes will display the name, then call back to this base | |||||
| // for common info. | |||||
| // No-op in the summary mode. | |||||
| if (show_all) { | if (show_all) { | ||||
| out << "num_rows_: " << num_rows_ << '\n'; | |||||
| out << "num_samples_: " << num_samples_ << '\n'; | |||||
| out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_; | |||||
| } | } | ||||
| } | } | ||||
| @@ -89,7 +89,14 @@ Status SequentialSampler::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; } | |||||
| void SequentialSampler::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SequentialSampler"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info | |||||
| out << "\nStart index: " << start_index_; | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
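All the refactored printers follow the two-level scheme the base-class hunk above establishes: the derived sampler writes its name unconditionally, and the base contributes the common fields only in show_all mode. A stand-alone model using SequentialSampler as the derived example:

    #include <iostream>

    struct Sampler {
      virtual ~Sampler() = default;
      // Base prints only the common detail fields, and only in show_all mode.
      virtual void Print(std::ostream &out, bool show_all) const {
        if (show_all) out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
      }
      long num_rows_ = 0, num_samples_ = 0;
    };

    struct SequentialSampler : Sampler {
      void Print(std::ostream &out, bool show_all) const override {
        out << "\nSampler: SequentialSampler";  // name is always shown
        if (show_all) {
          Sampler::Print(out, show_all);        // common info from the base
          out << "\nStart index: " << start_index_;
        }
      }
      long start_index_ = 0;
    };

    int main() {
      SequentialSampler s;
      s.Print(std::cout, true);
      std::cout << '\n';
    }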
| @@ -49,6 +49,9 @@ class SequentialSampler : public Sampler { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | |||||
| // @param out - output stream to write to | |||||
| // @param show_all - bool to show detailed vs summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| private: | private: | ||||
| @@ -119,5 +119,14 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SubsetRandomSampler"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -51,6 +51,11 @@ class SubsetRandomSampler : public Sampler { | |||||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | |||||
| // @param out - output stream to write to | |||||
| // @param show_all - bool to show detailed vs summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| // A list of indices (already randomized in constructor). | // A list of indices (already randomized in constructor). | ||||
| std::vector<int64_t> indices_; | std::vector<int64_t> indices_; | ||||
| @@ -156,5 +156,14 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: WeightedRandomSampler"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| Sampler::Print(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -53,6 +53,11 @@ class WeightedRandomSampler : public Sampler { | |||||
| // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| // Printer for debugging purposes. | |||||
| // @param out - output stream to write to | |||||
| // @param show_all - bool to show detailed vs summary | |||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| // A list of weights for each sample. | // A list of weights for each sample. | ||||
| std::vector<double> weights_; | std::vector<double> weights_; | ||||
| @@ -33,7 +33,11 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| TextFileOp::Builder::Builder() | TextFileOp::Builder::Builder() | ||||
| : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { | |||||
| : builder_device_id_(0), | |||||
| builder_num_devices_(1), | |||||
| builder_total_rows_(0), | |||||
| builder_shuffle_files_(false), | |||||
| builder_sampler_(nullptr) { | |||||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | ||||
| builder_num_workers_ = config_manager->num_parallel_workers(); | builder_num_workers_ = config_manager->num_parallel_workers(); | ||||
| builder_op_connector_size_ = config_manager->op_connector_size(); | builder_op_connector_size_ = config_manager->op_connector_size(); | ||||
| @@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | ||||
| builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, | builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, | ||||
| std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | ||||
| builder_num_devices_, builder_device_id_); | |||||
| builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); | |||||
| RETURN_IF_NOT_OK(text_file_op->Init()); | RETURN_IF_NOT_OK(text_file_op->Init()); | ||||
| *op = std::move(text_file_op); | *op = std::move(text_file_op); | ||||
| @@ -73,8 +77,9 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | ||||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | ||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, | |||||
| std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| device_id_(device_id), | device_id_(device_id), | ||||
| num_devices_(num_device), | num_devices_(num_device), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "dataset/util/status.h" | #include "dataset/util/status.h" | ||||
| @@ -112,6 +113,14 @@ class TextFileOp : public ParallelOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method | |||||
| // @param std::shared_ptr<Sampler> sampler | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | |||||
| return *this; | |||||
| } | |||||
| private: | private: | ||||
| int32_t builder_device_id_; | int32_t builder_device_id_; | ||||
| int32_t builder_num_devices_; | int32_t builder_num_devices_; | ||||
| @@ -123,6 +132,7 @@ class TextFileOp : public ParallelOp { | |||||
| std::vector<std::string> builder_text_files_list_; | std::vector<std::string> builder_text_files_list_; | ||||
| bool builder_shuffle_files_; | bool builder_shuffle_files_; | ||||
| std::unique_ptr<DataSchema> builder_schema_; | std::unique_ptr<DataSchema> builder_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| }; | }; | ||||
| // Constructor of TextFileOp | // Constructor of TextFileOp | ||||
| @@ -136,9 +146,10 @@ class TextFileOp : public ParallelOp { | |||||
| // @param columns_to_load - the names of the columns to load data from. | // @param columns_to_load - the names of the columns to load data from. | ||||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | // @param shuffle_files - whether or not to shuffle the files before reading data. | ||||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | // @param equal_rows_per_shard - whether or not to get equal rows for each process. | ||||
| // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes | |||||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | ||||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||||
| bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~TextFileOp() = default; | ~TextFileOp() = default; | ||||
| @@ -48,7 +48,11 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| TFReaderOp::Builder::Builder() | TFReaderOp::Builder::Builder() | ||||
| : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) { | |||||
| : builder_device_id_(0), | |||||
| builder_num_devices_(1), | |||||
| builder_total_rows_(0), | |||||
| builder_equal_rows_per_shard_(false), | |||||
| builder_sampler_(nullptr) { | |||||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | ||||
| builder_num_workers_ = config_manager->num_parallel_workers(); | builder_num_workers_ = config_manager->num_parallel_workers(); | ||||
| builder_worker_connector_size_ = config_manager->worker_connector_size(); | builder_worker_connector_size_ = config_manager->worker_connector_size(); | ||||
| @@ -87,11 +91,6 @@ Status TFReaderOp::Builder::ValidateInputs() const { | |||||
| err_msg += "Number of parallel workers is smaller or equal to 0\n"; | err_msg += "Number of parallel workers is smaller or equal to 0\n"; | ||||
| } | } | ||||
| if (!builder_equal_rows_per_shard_ && | |||||
| builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) { | |||||
| err_msg += "Not enough tfrecord files provided\n"; | |||||
| } | |||||
| if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { | if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { | ||||
| err_msg += "Wrong sharding configs\n"; | err_msg += "Wrong sharding configs\n"; | ||||
| } | } | ||||
| @@ -125,7 +124,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op) | |||||
| std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>( | std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>( | ||||
| builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, | builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, | ||||
| builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, | builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, | ||||
| builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); | |||||
| builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, | |||||
| std::move(builder_sampler_)); | |||||
| RETURN_IF_NOT_OK(new_tf_reader_op->Init()); | RETURN_IF_NOT_OK(new_tf_reader_op->Init()); | ||||
| *out_tf_reader_op = std::move(new_tf_reader_op); | *out_tf_reader_op = std::move(new_tf_reader_op); | ||||
| @@ -136,8 +136,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 | |||||
| int64_t total_num_rows, std::vector<std::string> dataset_files_list, | int64_t total_num_rows, std::vector<std::string> dataset_files_list, | ||||
| std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | ||||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | ||||
| int32_t device_id, bool equal_rows_per_shard) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, op_connector_size, std::move(sampler)), | |||||
| device_id_(device_id), | device_id_(device_id), | ||||
| num_devices_(num_device), | num_devices_(num_device), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| @@ -1018,5 +1018,40 @@ Status TFReaderOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| Status TFReaderOp::PrepareNodePostAction() { | |||||
| // Run common code from super class before adding TFReaderOp specific handling | |||||
| RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); | |||||
| // Specific handling for this op, we need to do cache op work so assign the sampler to the cache | |||||
| // TF is a special case because it can support file-based sharding/shuffling, or, if there | |||||
| // is a cache, then it can also do row-based sampler using the sampler on the cache. | |||||
| // Thus, pass true for random access op flag when saving the sampler. This is a special case, | |||||
| // since usually a non-mappable dataset would pass false here. | |||||
| RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true)); | |||||
| // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into | |||||
| // a simpler producer of all data (no shuffling or sharding or anything) | |||||
| if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { | |||||
| device_id_ = 0; | |||||
| num_devices_ = 1; | |||||
| total_rows_ = 0; | |||||
| shuffle_files_ = false; | |||||
| equal_rows_per_shard_ = false; | |||||
| sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment) | |||||
| } else { | |||||
| // This sanity check had been delayed until now in the prepare loop. | |||||
| // If we are not in a cache path, then we can validate the the file-based sharding config. | |||||
| // If we are in a cache path, there is no file-based sharding so the check is not correct in that | |||||
| // situation. | |||||
| if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) { | |||||
| RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n"); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
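The sanity check relocated from ValidateInputs reduces to a one-line predicate over the file list and device count, now evaluated only on the non-cache path. A stand-alone sketch (the function name is invented for illustration):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // File-based sharding (equal_rows_per_shard == false) needs at least one
    // file per device; row-based sharding has no such requirement.
    bool EnoughFilesForSharding(const std::vector<std::string> &files, int32_t num_devices,
                                bool equal_rows_per_shard) {
      return equal_rows_per_shard || files.size() >= static_cast<uint32_t>(num_devices);
    }

    int main() {
      std::vector<std::string> files = {"a.tfrecord", "b.tfrecord"};
      std::cout << EnoughFilesForSharding(files, 4, false) << '\n';  // 0: too few files
      std::cout << EnoughFilesForSharding(files, 4, true) << '\n';   // 1: fine
    }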
| @@ -153,8 +153,17 @@ class TFReaderOp : public ParallelOp { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method | |||||
| // @param std::shared_ptr<Sampler> sampler | |||||
| // @return Builder setter method returns reference to the builder. | |||||
| Builder &SetSampler(std::shared_ptr<Sampler> sampler) { | |||||
| builder_sampler_ = std::move(sampler); | |||||
| return *this; | |||||
| } | |||||
| private: | private: | ||||
| std::unique_ptr<DataSchema> builder_data_schema_; | std::unique_ptr<DataSchema> builder_data_schema_; | ||||
| std::shared_ptr<Sampler> builder_sampler_; | |||||
| int32_t builder_device_id_; | int32_t builder_device_id_; | ||||
| int32_t builder_num_devices_; | int32_t builder_num_devices_; | ||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| @@ -180,10 +189,11 @@ class TFReaderOp : public ParallelOp { | |||||
| // @param columns_to_load - the names of the columns to load data from. | // @param columns_to_load - the names of the columns to load data from. | ||||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | // @param shuffle_files - whether or not to shuffle the files before reading data. | ||||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | // @param equal_rows_per_shard - whether or not to get equal rows for each process. | ||||
| // @param sampler - allow a sampler. Only valid if a cache exists in ancestor tree nodes | |||||
| TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | ||||
| std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema, | ||||
| int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files, | ||||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); | |||||
| int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler); | |||||
| // Default destructor | // Default destructor | ||||
| ~TFReaderOp() = default; | ~TFReaderOp() = default; | ||||
| @@ -236,6 +246,12 @@ class TFReaderOp : public ParallelOp { | |||||
| // @return Vector of the input file names | // @return Vector of the input file names | ||||
| std::vector<std::string> FileNames() { return dataset_files_list_; } | std::vector<std::string> FileNames() { return dataset_files_list_; } | ||||
| // During tree prepare phase, operators may have specific post-operations to perform depending on | |||||
| // their role. | |||||
| // @notes Derived versions of this function should always call their superclass version first | |||||
| // before providing their own implementations. | |||||
| Status PrepareNodePostAction() override; | |||||
| private: | private: | ||||
| // The entry point for when workers are launched. | // The entry point for when workers are launched. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| @@ -88,7 +88,7 @@ Status VOCOp::Builder::SanityCheck() { | |||||
| VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | ||||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | ||||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | ||||
| : ParallelOp(num_workers, queue_size), | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | |||||
| decode_(decode), | decode_(decode), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| @@ -97,7 +97,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std: | |||||
| folder_path_(folder_path), | folder_path_(folder_path), | ||||
| class_index_(class_index), | class_index_(class_index), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| sampler_(std::move(sampler)), | |||||
| data_schema_(std::move(data_schema)) { | data_schema_(std::move(data_schema)) { | ||||
| io_block_queues_.Init(num_workers_, queue_size); | io_block_queues_.Init(num_workers_, queue_size); | ||||
| } | } | ||||
| @@ -274,7 +274,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| TaskType task_type_; | TaskType task_type_; | ||||
| std::string task_mode_; | std::string task_mode_; | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| std::shared_ptr<Sampler> sampler_; | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| WaitPost wp_; | WaitPost wp_; | ||||
| @@ -129,7 +129,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D | |||||
| Status TakeOp::PrepareNodePostAction() { | Status TakeOp::PrepareNodePostAction() { | ||||
| RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); | ||||
| tree_->AddToRepeatStack(shared_from_this()); | |||||
| tree_->AddToEOEOpStack(shared_from_this()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
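TakeOp, ConcatOp, and the leaf ops all participate in the same handshake: ops in a repeat path push themselves onto the eoe op stack during the post-order prepare walk, and the repeat/epoch-ctrl op above them later pops the stack to learn which descendants it owns at EOE time. A toy model of that flow (simplified Op type, not the real classes):

#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <vector>

struct Op {
  std::string name;
};

int main() {
  std::stack<std::shared_ptr<Op>> eoe_stack;

  // The post-order prepare walk reaches the leaves first; ops in a repeat
  // path push themselves as they finish their post-actions.
  eoe_stack.push(std::make_shared<Op>(Op{"TFReaderOp"}));
  eoe_stack.push(std::make_shared<Op>(Op{"TakeOp"}));

  // When the walk reaches the repeat/epoch-ctrl op above them, it drains
  // the stack to collect the ops it must handle at EOE time.
  std::vector<std::shared_ptr<Op>> eoe_ops;
  while (!eoe_stack.empty()) {
    eoe_ops.push_back(eoe_stack.top());
    eoe_stack.pop();
  }
  for (const auto &op : eoe_ops) {
    std::cout << "repeat op tracks: " << op->name << "\n";
  }
}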
| @@ -88,13 +88,13 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) { | |||||
| } | } | ||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void ExecutionTree::Print(std::ostream &out) const { | |||||
| void ExecutionTree::Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op) const { | |||||
| out << "Execution tree summary:\n" | out << "Execution tree summary:\n" | ||||
| << "-----------------------\n"; | << "-----------------------\n"; | ||||
| this->PrintNode(out, root_, "", true, false); | |||||
| this->PrintNode(out, op == nullptr ? root_ : op, "", true, false); | |||||
| out << "\nExecution tree operator details:\n" | out << "\nExecution tree operator details:\n" | ||||
| << "--------------------------------\n"; | << "--------------------------------\n"; | ||||
| this->PrintNode(out, root_, "", true, true); | |||||
| this->PrintNode(out, op == nullptr ? root_ : op, "", true, true); | |||||
| } | } | ||||
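The defaulted op parameter keeps existing Print(out) callers working while letting debugging code print a subtree. A self-contained sketch of the "fall back to root when op is null" idiom, using a toy Node chain instead of the real tree:

#include <iostream>
#include <memory>
#include <string>

struct Node {
  std::string name;
  std::shared_ptr<Node> child;
};

void Print(std::ostream &out, const std::shared_ptr<Node> &root,
           const std::shared_ptr<Node> &op = nullptr) {
  // Fall back to the root unless the caller asked for a specific subtree.
  for (auto n = (op == nullptr ? root : op); n != nullptr; n = n->child) {
    out << n->name << "\n";
  }
}

int main() {
  auto leaf = std::make_shared<Node>(Node{"TFReaderOp", nullptr});
  auto root = std::make_shared<Node>(Node{"RepeatOp", leaf});
  Print(std::cout, root);        // Whole tree, as before.
  Print(std::cout, root, leaf);  // Just the subtree rooted at leaf.
}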
| // A helper function for doing the recursive printing | // A helper function for doing the recursive printing | ||||
| @@ -269,27 +269,40 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op) | |||||
| RETURN_IF_NOT_OK(this->PrepareNode(i)); | RETURN_IF_NOT_OK(this->PrepareNode(i)); | ||||
| } | } | ||||
| // Then clear the flags from this op now that we have prepared it. | |||||
| BitClear(&prepare_flags_, op_prep_flags); | |||||
| // No more children, now we execute any prepare actions before going back up the | // No more children, now we execute any prepare actions before going back up the | ||||
| // the tree on recursive function | // the tree on recursive function | ||||
| RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); | RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); | ||||
| // Then clear the flags from this op now that we have prepared it. | |||||
| BitClear(&prepare_flags_, op_prep_flags); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
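The reordering above is subtle but load-bearing: BitClear now runs after PrepareNodePostAction, so post-actions such as TFReaderOp's can still test kDePrepCache while they execute. A toy model of the recursion, assuming simplified types:

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct Node {
  uint32_t sets_flags = 0;  // Flag bits this op raises during its pre-action.
  std::vector<std::shared_ptr<Node>> children;
};

void PrepareNode(const std::shared_ptr<Node> &n, uint32_t *flags) {
  *flags |= n->sets_flags;  // Pre-action: raise this op's flags.
  for (const auto &c : n->children) {
    PrepareNode(c, flags);
  }
  // Post-action runs while the flags are still set, matching the new order.
  std::cout << "post-action sees flags=" << *flags << "\n";
  *flags &= ~n->sets_flags;  // Cleared only after the post-action.
}

int main() {
  auto leaf = std::make_shared<Node>();
  auto cache = std::make_shared<Node>();
  cache->sets_flags = 2;  // Stands in for kDePrepCache.
  cache->children.push_back(leaf);
  uint32_t flags = 0;
  PrepareNode(cache, &flags);  // The leaf's post-action sees flags=2.
}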
| // Adds an operator to the repeat stack during prepare phase. | |||||
| void ExecutionTree::AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op) { repeat_stack_.push(dataset_op); } | |||||
| // Adds an operator to the eoe operator stack during prepare phase. | |||||
| void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); } | |||||
| // Pops an operator from the repeat stack during prepare phase. | |||||
| std::shared_ptr<DatasetOp> ExecutionTree::PopFromRepeatStack() { | |||||
| // Pops an operator from the eoe operator stack during prepare phase. | |||||
| std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() { | |||||
| std::shared_ptr<DatasetOp> top_op = nullptr; | std::shared_ptr<DatasetOp> top_op = nullptr; | ||||
| if (!repeat_stack_.empty()) { | |||||
| top_op = repeat_stack_.top(); | |||||
| repeat_stack_.pop(); | |||||
| if (!eoe_stack_.empty()) { | |||||
| top_op = eoe_stack_.top(); | |||||
| eoe_stack_.pop(); | |||||
| } | } | ||||
| return top_op; | return top_op; | ||||
| } | } | ||||
| // Adds a sampler to the sampler stack during prepare phase. | |||||
| void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); } | |||||
| // Pops a sampler from the sampler stack during prepare phase. | |||||
| std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() { | |||||
| std::shared_ptr<Sampler> top_sampler = nullptr; | |||||
| if (!sampler_stack_.empty()) { | |||||
| top_sampler = sampler_stack_.top(); | |||||
| sampler_stack_.pop(); | |||||
| } | |||||
| return top_sampler; | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
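PopFromEOEOpStack and PopFromSamplerStack share the same "pop the top or return null" shape; if further stacks were added, a small generic helper (not present in the source, shown only as a refactoring sketch) could express it once:

#include <iostream>
#include <memory>
#include <stack>

// Generic "pop the top or return null" helper.
template <typename T>
std::shared_ptr<T> PopOrNull(std::stack<std::shared_ptr<T>> *stk) {
  std::shared_ptr<T> top = nullptr;
  if (!stk->empty()) {
    top = stk->top();  // Grab the element before popping removes it.
    stk->pop();
  }
  return top;  // nullptr tells the caller the stack was empty.
}

int main() {
  std::stack<std::shared_ptr<int>> s;
  s.push(std::make_shared<int>(42));
  auto first = PopOrNull(&s);   // Holds 42.
  auto second = PopOrNull(&s);  // nullptr: stack exhausted.
  std::cout << (first ? *first : -1) << " " << (second ? *second : -1) << "\n";
}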
| @@ -37,7 +37,8 @@ class ExecutionTree { | |||||
| // Prepare flags used during tree prepare phase | // Prepare flags used during tree prepare phase | ||||
| enum PrepareFlags { | enum PrepareFlags { | ||||
| kDePrepNone = 0, | kDePrepNone = 0, | ||||
| kDePrepRepeat = 1 // Processing a repeat operation | |||||
| kDePrepRepeat = 1, // Processing a repeat operation | |||||
| kDePrepCache = 2 // Processing a cache operation | |||||
| }; | }; | ||||
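The enum values are powers of two on purpose: prepare flags for nested repeat and cache branches are OR-ed into the single prepare_flags_ word and tested/cleared bit by bit. A small demonstration using plain operators (the source uses its BitTest/BitClear-style helpers):

#include <cstdint>
#include <iostream>

constexpr uint32_t kDePrepNone = 0;
constexpr uint32_t kDePrepRepeat = 1;  // Bit 0
constexpr uint32_t kDePrepCache = 2;   // Bit 1

int main() {
  uint32_t flags = kDePrepNone;
  flags |= kDePrepRepeat;  // Entered a repeat branch...
  flags |= kDePrepCache;   // ...that also sits under a cache.
  std::cout << "in cache path: " << ((flags & kDePrepCache) != 0) << "\n";    // 1
  flags &= ~kDePrepRepeat;  // Finished the repeat subtree.
  std::cout << "in repeat path: " << ((flags & kDePrepRepeat) != 0) << "\n";  // 0
}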
| // State flags for the lifecycle of the tree | // State flags for the lifecycle of the tree | ||||
| @@ -118,9 +119,9 @@ class ExecutionTree { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status Launch(); | Status Launch(); | ||||
| // A print method typically used for debugging | |||||
| // @param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const; | |||||
| /// A print method typically used for debugging | |||||
| /// \param out - The output stream to write output to | |||||
| /// \param op - The op to begin printing from; prints the whole tree from the root when null | |||||
| void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const; | |||||
| // Returns an iterator positioned at the start | // Returns an iterator positioned at the start | ||||
| // @return Iterator - The iterator | // @return Iterator - The iterator | ||||
| @@ -199,14 +200,23 @@ class ExecutionTree { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op); | ||||
| // Adds an operator to the repeat stack during prepare phase. | |||||
| // @param op - The dataset op to work add to repeat stack | |||||
| // @return Status - The error code return | |||||
| void AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op); | |||||
| /// Adds an operator to the eoe operator stack during prepare phase. | |||||
| /// \param dataset_op - The dataset op to add to the eoe stack | |||||
| void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op); | |||||
| /// Pops an operator from the eoe operator stack during prepare phase. | |||||
| /// \return shared_ptr to the popped operator | |||||
| std::shared_ptr<DatasetOp> PopFromEOEOpStack(); | |||||
| /// Adds a sampler to the sampler stack during prepare phase. | |||||
| /// \param sampler - The sampler to add to the sampler stack | |||||
| void AddToSamplerStack(std::shared_ptr<Sampler> sampler); | |||||
| // Pops an operator from the repeat stack during prepare phase. | |||||
| // @return shared_ptr to the popped operator | |||||
| std::shared_ptr<DatasetOp> PopFromRepeatStack(); | |||||
| /// Pops a sampler from the sampler stack during prepare phase. | |||||
| /// \return shared_ptr to the popped sampler | |||||
| std::shared_ptr<Sampler> PopFromSamplerStack(); | |||||
| // Return the pointer to the TaskGroup | // Return the pointer to the TaskGroup | ||||
| // @return raw pointer to the TaskGroup | // @return raw pointer to the TaskGroup | ||||
| @@ -236,9 +246,10 @@ class ExecutionTree { | |||||
| int32_t id_count_; // Counter for generating operator id's | int32_t id_count_; // Counter for generating operator id's | ||||
| uint32_t prepare_flags_; // Flags used during tree prepare | uint32_t prepare_flags_; // Flags used during tree prepare | ||||
| TreeState tree_state_; // Tracking the current tree state | TreeState tree_state_; // Tracking the current tree state | ||||
| std::stack<std::shared_ptr<DatasetOp>> repeat_stack_; // A stack used during prepare phase | |||||
| std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor | std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor | ||||
| std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager | ||||
| std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A stack used during prepare phase | |||||
| std::stack<std::shared_ptr<Sampler>> sampler_stack_; // A stack used during prepare phase | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||