diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc index c5aac523d2..4bada31e7e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc @@ -128,7 +128,7 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) { Status ConcatOp::PrepareNodePostAction() { RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - tree_->AddToRepeatStack(shared_from_this()); + tree_->AddToEOEOpStack(shared_from_this()); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc index 727c543958..91ed7fbc5f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc @@ -18,23 +18,26 @@ #include #include #include +#include #include #include #include #include "dataset/engine/execution_tree.h" #include "dataset/engine/datasetops/device_queue_op.h" +#include "dataset/engine/datasetops/source/sampler/sampler.h" #include "dataset/engine/data_buffer.h" #include "dataset/engine/db_connector.h" #include "dataset/engine/opt/pass.h" - +#include "utils/system/crc32c.h" #include "utils/log_adapter.h" namespace mindspore { namespace dataset { // Constructor -DatasetOp::DatasetOp(int32_t op_connector_size) +DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) : oc_queue_size_(op_connector_size), + sampler_(sampler), operator_id_(kInvalidOperatorId), tree_(nullptr), state_(OpState::kDeOpIdle), @@ -150,6 +153,9 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { } out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); + if (sampler_) { + sampler_->Print(out, show_all); + } } } @@ -222,11 +228,10 @@ Status DatasetOp::PrepareNodePreAction() { Status DatasetOp::PrepareNodePostAction() { // If this op does not have any children and it is in a repeat path of the tree... if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { - // push ourselves onto the tree repeat stack. Later, the repeat operator + // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator // above us will consume them. - tree_->AddToRepeatStack(shared_from_this()); + tree_->AddToEOEOpStack(shared_from_this()); } - // Creating Connector object for each op. // The consumer of the root node is assumed to be one thread. // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. @@ -289,5 +294,56 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) { // This method will only be called if its derived class does not implement one. return p->RunOnNode(shared_from_this(), modified); } + +// A helper function with some common code that leaf nodes can use during +// prepare phase for checking if they need to assign a sampler to the cache. +// @return - Status +Status DatasetOp::SaveSamplerForCache(bool random_access_op) { + // If we are a descendant under a cache op and we have a sampler, then save this sampler + // to a stack so that the cache can pick it up during it's processing above us. + if (sampler_) { + if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { + // use move semantic to set our sampler_ to null after the move. 
This is okay because a sampler is
+      // useless to a random data op. It was only being used as a temporary holder until the cache op can
+      // be created.
+      tree_->AddToSamplerStack(sampler_);
+      MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
+    } else if (!random_access_op) {
+      // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
+      // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
+      return Status(
+        StatusCode::kUnexpectedError, __LINE__, __FILE__,
+        "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
+    }
+  }
+
+  if (!random_access_op) {
+    // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache,
+    // we can remove it now from the base class.
+    sampler_.reset();
+  }
+
+  return Status::OK();
+}
+
+uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
+  std::stringstream ss;
+  op->tree_->Print(ss, op);
+  std::string ss_str = ss.str();
+
+  // Filter out the Operator control flags field when generating the checksum
+  ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
+
+  // Filter out the Device id fields to allow cache sharing for a distributed run of the same pipeline
+  ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
+
+  // The Cache crc and Server cache id fields differ between creating a new cache_client and re-using the
+  // same cache_client later, so we filter out these two fields to allow cache sharing.
+  ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
+
+  uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
+  return cache_crc;
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
index 2516fdbee1..254cd411c5 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -34,6 +34,8 @@ class DataBuffer;
 
 class NodePass;
 
+class Sampler;
+
 // The base class DatasetOp is the main tree node. It is an abstract class, so
 // the actual implementation of the operators will be derived from here.
 class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
@@ -55,7 +57,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // Constructor
   // @param op_connector_size - The size for the output connector of this operator.
-  explicit DatasetOp(int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
 
   // Destructor
   virtual ~DatasetOp() { tree_ = nullptr; }
@@ -204,6 +207,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Sets the control flags
   void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }
 
+  // Setter function
+  // @return Clears the given control flag
+  void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
+
   // Register the internal worker connectors.
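// The following standalone sketch (not part of the patch) condenses the fingerprinting idea used
// by GenerateCRC above: serialize a description of the pipeline, strip the volatile fields with
// std::regex_replace, then checksum what remains. system::Crc32c is project-internal, so
// std::hash is used here as a stand-in for the masked CRC32C.
#include <functional>
#include <regex>
#include <string>

size_t FingerprintPipeline(std::string desc) {
  // Drop fields that vary per device or per run but do not change the data the pipeline produces.
  const char *volatile_fields[] = {"Operator control flags.*\n", "Device id.*\n", "device_id.*\n",
                                   "Cache crc.*\n", "Server cache id.*\n"};
  for (const char *pattern : volatile_fields) {
    desc = std::regex_replace(desc, std::regex(pattern), "");
  }
  return std::hash<std::string>{}(desc);  // stand-in for Crc32c::GetMaskCrc32cValue
}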
No op unless it is a parallel op // @return Status virtual Status RegisterWorkerConnectors() { return Status::OK(); } @@ -271,6 +278,13 @@ class DatasetOp : public std::enable_shared_from_this { // @return Pointer to the ExecutionTree the current op belongs to, no ownership ExecutionTree *Tree() { return tree_; } + // Getter for the sampler + // @return Shared pointer to the sampler (may return nullptr) + std::shared_ptr sampler() { return sampler_; } + + // Computes a CRC value for the operator + static uint32_t GenerateCRC(const std::shared_ptr &op); + protected: // Adds a parent operator to this operator // @notes External callers do not have access to this function. @@ -289,8 +303,15 @@ class DatasetOp : public std::enable_shared_from_this { // @return - Status virtual Status ComputeColMap(); + // A helper function with some common code that leaf nodes can use during + // prepare phase for checking if they need to assign a sampler to the cache. + // @param random_access_op - indicate if this is a mappable random access leaf or not + // @return - Status + Status SaveSamplerForCache(bool random_access_op); + std::vector> child_; // Child nodes std::vector parent_; // Parent nodes. No ownership + std::shared_ptr sampler_; // Some leaf ops might have a sampler int32_t oc_queue_size_; // Capacity for each out_connector_ int32_t operator_id_; // Generated id for the node ExecutionTree *tree_; // Back pointer to our tree. diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index fcb2e357e8..020f40d268 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -100,7 +100,7 @@ void MapOp::Print(std::ostream &out, bool show_all) const { } out << "\n TensorOps:"; for (size_t i = 0; i < tfuncs_.size(); i++) { - out << " " << tfuncs_[i]; + out << " " << *(tfuncs_[i].get()); } out << "\n\n"; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc index c0a8d95f15..244861a6c8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc @@ -26,8 +26,8 @@ namespace mindspore { namespace dataset { // Constructor -ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size) - : DatasetOp(op_connector_size), +ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) + : DatasetOp(op_connector_size, sampler), num_workers_(num_workers), num_producers_(num_workers), worker_connector_size_(1), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h index 142ec78360..f59d4bfc53 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h @@ -38,7 +38,8 @@ class ParallelOp : public DatasetOp { // Constructor // @param num_workers // @param op_connector_size - size of the output connector for this operator - ParallelOp(int32_t num_workers, int32_t op_connector_size); + // @param sampler - The sampler for the op + ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler = nullptr); // Destructor ~ParallelOp() = default; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc index 46eded8ea1..1d017a4d3e 100644 --- 
a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc @@ -20,7 +20,8 @@ namespace mindspore { namespace dataset { // Constructor -PipelineOp::PipelineOp(int32_t op_connector_size) : DatasetOp(op_connector_size) {} +PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampler) + : DatasetOp(op_connector_size, sampler) {} // A print method typically used for debugging void PipelineOp::Print(std::ostream &out, bool show_all) const { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h index a14279032d..cb3c76813b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h @@ -32,7 +32,8 @@ class PipelineOp : public DatasetOp { // Constructor // @param op_connector_size - size of the output connector // @return Builder setter method returns reference to the builder. - explicit PipelineOp(int32_t op_connector_size); + // @param sampler - The sampler for the op + explicit PipelineOp(int32_t op_connector_size, std::shared_ptr sampler = nullptr); // Destructor ~PipelineOp() = default; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc index 66e2177636..86903e540a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc @@ -82,14 +82,14 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const { Status RepeatOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own specific logic RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - std::shared_ptr leaf_op = tree_->PopFromRepeatStack(); + std::shared_ptr leaf_op = tree_->PopFromEOEOpStack(); while (leaf_op != nullptr) { // Track the leaf operators that are under this repeat op. eoe_ops_.push_back(leaf_op); - leaf_op = tree_->PopFromRepeatStack(); + leaf_op = tree_->PopFromEOEOpStack(); } // Push ourselves to the stack in case one of our ascendants is repeat too. 
- tree_->AddToRepeatStack(shared_from_this()); + tree_->AddToEOEOpStack(shared_from_this()); return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc index 7889362555..c7a4269a39 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc @@ -70,13 +70,12 @@ Status CelebAOp::Builder::SanityCheck() { CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, const std::string &dataset_type, const std::set &exts, std::unique_ptr schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size), + : ParallelOp(num_workers, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(dir), decode_(decode), extensions_(exts), data_schema_(std::move(schema)), - sampler_(std::move(sampler)), num_rows_in_attr_file_(0), dataset_type_(dataset_type) { attr_info_queue_ = std::make_unique>>(queue_size); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h index f8a49dabb2..a6fa495a14 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h @@ -221,7 +221,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { bool decode_; std::set extensions_; // extensions allowed std::unique_ptr data_schema_; - std::shared_ptr sampler_; std::unique_ptr>> attr_info_queue_; int64_t num_rows_in_attr_file_; // rows number specified in attr file QueueList> io_block_queues_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc index e7c418b146..8dd615a8c1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc @@ -79,12 +79,11 @@ Status CifarOp::Builder::SanityCheck() { CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_works, queue_size), + : ParallelOp(num_works, queue_size, std::move(sampler)), cifar_type_(type), rows_per_buffer_(rows_per_buf), folder_path_(file_dir), data_schema_(std::move(data_schema)), - sampler_(std::move(sampler)), row_cnt_(0), buf_cnt_(0) { constexpr uint64_t kUtilQueueSize = 512; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h index 21ed80ceab..917b23db94 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h @@ -216,7 +216,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { int32_t rows_per_buffer_; std::string folder_path_; std::unique_ptr data_schema_; - std::shared_ptr sampler_; int64_t row_cnt_; int64_t buf_cnt_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc index c28ed2d3ab..cb17158bff 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc @@ -65,7 +65,7 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str bool recursive, bool 
do_decode, const std::set &exts, const std::map &map, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_wkrs, queue_size), + : ParallelOp(num_wkrs, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), folder_path_(file_dir), recursive_(recursive), @@ -73,7 +73,6 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str extensions_(exts), class_index_(map), data_schema_(std::move(data_schema)), - sampler_(std::move(sampler)), row_cnt_(0), buf_cnt_(0), sampler_ind_(0), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h index 06f39deee0..6629fd6092 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h @@ -259,7 +259,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { std::set extensions_; // extensions allowed std::map class_index_; std::unique_ptr data_schema_; - std::shared_ptr sampler_; int64_t row_cnt_; int64_t buf_cnt_; int64_t sampler_ind_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc index e26bb7de65..e65da8707b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc @@ -64,7 +64,7 @@ Status ManifestOp::Builder::SanityCheck() { ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, const std::map &class_index, std::unique_ptr data_schema, std::shared_ptr sampler, std::string usage) - : ParallelOp(num_works, queue_size), + : ParallelOp(num_works, queue_size, std::move(sampler)), rows_per_buffer_(rows_per_buffer), io_block_pushed_(0), row_cnt_(0), @@ -72,7 +72,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f data_schema_(std::move(data_schema)), file_(file), class_index_(class_index), - sampler_(std::move(sampler)), decode_(decode), usage_(usage), buf_cnt_(0) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h index 1bdf683084..c180ea581d 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h @@ -230,7 +230,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { std::unique_ptr data_schema_; std::string file_; // file that store the information of images std::map class_index_; - std::shared_ptr sampler_; bool decode_; std::string usage_; int64_t buf_cnt_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc index 67e7757da5..e98f8ae8c1 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc @@ -66,12 +66,11 @@ Status MnistOp::Builder::SanityCheck() { MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size), + : ParallelOp(num_workers, queue_size, std::move(sampler)), buf_cnt_(0), row_cnt_(0), folder_path_(folder_path), rows_per_buffer_(rows_per_buffer), - sampler_(std::move(sampler)), data_schema_(std::move(data_schema)) { 
io_block_queues_.Init(num_workers, queue_size); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h index c22ee24acd..9bd6276a11 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h @@ -235,7 +235,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { WaitPost wp_; std::string folder_path_; // directory of image folder int32_t rows_per_buffer_; - std::shared_ptr sampler_; std::unique_ptr data_schema_; std::vector image_label_pairs_; std::vector image_names_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc index afd7ebcc55..3a865d8d69 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc @@ -21,6 +21,7 @@ #include "dataset/core/config_manager.h" #include "dataset/util/random.h" #include "dataset/util/wait_post.h" +#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" namespace mindspore { namespace dataset { @@ -30,7 +31,8 @@ RandomDataOp::Builder::Builder() builder_num_workers_(0), builder_op_connector_size_(0), builder_rows_per_buffer_(0), - builder_total_rows_(0) { + builder_total_rows_(0), + builder_sampler_(nullptr) { // Some arguments to the RandomDataOp have a default argument that is taken from the config. // The user may override these defaults by using the builder set methods. std::shared_ptr cfg = GlobalContext::config_manager(); @@ -43,8 +45,9 @@ RandomDataOp::Builder::Builder() Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { RETURN_IF_NOT_OK(SanityCheck()); - *out_op = std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, - builder_total_rows_, std::move(builder_data_schema_)); + *out_op = + std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, + builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); // If the user did not provide a schema, then we will ask the op to generate a pseudo-random // schema. @@ -66,8 +69,8 @@ Status RandomDataOp::Builder::SanityCheck() const { // Constructor for RandomDataOp RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema) - : ParallelOp(num_workers, op_connector_size), + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), buffer_id_(0), rows_per_buffer_(rows_per_buffer), total_rows_(total_rows), @@ -124,7 +127,7 @@ Status RandomDataOp::GenerateSchema() { // For each column: // - choose a datatype // - generate a shape that randomly chooses the number of dimensions and the dimension values. - DataType::Type newType = static_cast(GenRandomInt(0, DataType::NUM_OF_TYPES - 2)); + DataType::Type newType = static_cast(GenRandomInt(1, DataType::NUM_OF_TYPES - 2)); int32_t rank = GenRandomInt(1, kMaxRank); std::vector dims; for (int32_t d = 0; d < rank; d++) { @@ -412,5 +415,15 @@ Status RandomDataOp::ComputeColMap() { } return Status::OK(); } + +// During tree prepare phase, operators may have specific post-operations to perform depending on +// their role. 
+Status RandomDataOp::PrepareNodePostAction() { + // Run common code from super class before adding RandomDataOp specific handling + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + // Specific handling for this op, we need to do cache op work to assign the sampler to the cache. + RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false)); + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h index 020c9a6e09..b2af27dda3 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h @@ -42,7 +42,7 @@ class RandomDataOp : public ParallelOp { // Some constants to provide limits to random generation. static constexpr int32_t kMaxNumColumns = 4; static constexpr int32_t kMaxRank = 4; - static constexpr int32_t kMaxDimValue = 2048; + static constexpr int32_t kMaxDimValue = 32; static constexpr int32_t kMaxTotalRows = 1024; // A nested builder class to aid in the construction of a RandomDataOp @@ -117,6 +117,14 @@ class RandomDataOp : public ParallelOp { return *this; } + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + private: /** * Check if the required parameters are set by the builder. @@ -125,6 +133,7 @@ class RandomDataOp : public ParallelOp { Status SanityCheck() const; std::unique_ptr builder_data_schema_; + std::shared_ptr builder_sampler_; int32_t builder_num_workers_; int32_t builder_op_connector_size_; int64_t builder_rows_per_buffer_; @@ -139,10 +148,11 @@ class RandomDataOp : public ParallelOp { * @param rows_per_buffer - The number of rows in each DataBuffer * @param data_schema - A user-provided schema * @param total_rows - The total number of rows in the dataset + * @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes * @return Builder - The modified builder by reference */ RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema); + std::unique_ptr data_schema, std::shared_ptr sampler); /** * Destructor @@ -193,6 +203,12 @@ class RandomDataOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "RandomDataOp"; } + // During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. 
+ Status PrepareNodePostAction() override; + private: /** * The entry point code for when workers are launched diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc index 226647df14..9f4a9cf55c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -107,12 +107,11 @@ Status DistributedSampler::ResetSampler() { } void DistributedSampler::Print(std::ostream &out, bool show_all) const { - out << "(sampler): DistributedSampler\n"; + out << "\nSampler: DistributedSampler"; if (show_all) { - out << "seed_: " << seed_ << '\n'; - out << "device_id_: " << device_id_ << '\n'; - out << "num_devices_: " << num_devices_ << '\n'; - out << "shuffle_: " << shuffle_ << '\n'; + Sampler::Print(out, show_all); + out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ + << "\nshuffle: " << shuffle_; } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc index 92a880d599..cd2cadb9ff 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -113,5 +113,13 @@ Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { return Status::OK(); } +void PKSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: PKSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h index 7b1423326a..cde8a75b5b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -56,6 +56,11 @@ class PKSampler : public Sampler { // NOT YET FINISHED // @return - The error code return Status ResetSampler() override; + // Printer for debugging purposes. 
+ // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + private: bool shuffle_; uint32_t seed_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc index af4aa20bb2..d204c55ce9 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -103,5 +103,14 @@ Status PythonSampler::ResetSampler() { return Status::OK(); } + +void PythonSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: PythonSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h index 49ff12878d..7d653b2087 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -50,6 +50,11 @@ class PythonSampler : public Sampler { // @return - The error code return Status GetNextSample(std::unique_ptr *out_buffer) override; + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + private: bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc index b3dfaad7f7..db0a96ea3a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -113,13 +113,12 @@ Status RandomSampler::ResetSampler() { } void RandomSampler::Print(std::ostream &out, bool show_all) const { - out << "(sampler): RandomSampler\n"; - + out << "\nSampler: RandomSampler"; if (show_all) { - out << "num_samples_: " << num_samples_ << '\n'; - out << "next_id_: " << next_id_ << '\n'; + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any } } - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index b3c595870f..1584166dc3 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -80,11 +80,12 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t } void Sampler::Print(std::ostream &out, bool show_all) const { - out << "(sampler): base\n"; - + // Sampler printing is usually only called in the show_all mode. + // Derived classes will display the name, then call back to this base + // for common info. + // No-op in the summary mode. 
if (show_all) { - out << "num_rows_: " << num_rows_ << '\n'; - out << "num_samples_: " << num_samples_ << '\n'; + out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_; } } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index f0ff6a2c02..28598da55f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -89,7 +89,14 @@ Status SequentialSampler::ResetSampler() { return Status::OK(); } -void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; } - +void SequentialSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SequentialSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info + out << "\nStart index: " << start_index_; + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h index 2cb7a9ff8d..06f084fb7a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -49,6 +49,9 @@ class SequentialSampler : public Sampler { // @return - The error code return Status GetNextSample(std::unique_ptr *out_buffer) override; + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary void Print(std::ostream &out, bool show_all) const override; private: diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc index 54491889fc..08a623ed1b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -119,5 +119,14 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffe return Status::OK(); } + +void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SubsetRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h index 980ffe578a..ffc7cb17bc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -51,6 +51,11 @@ class SubsetRandomSampler : public Sampler { // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. Status GetNextSample(std::unique_ptr *out_buffer) override; + // Printer for debugging purposes. 
+ // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + private: // A list of indices (already randomized in constructor). std::vector indices_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc index 759af99352..6bf3d2d85e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -156,5 +156,14 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buf return Status::OK(); } + +void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: WeightedRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h index 257501250d..1fbe29ed80 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -53,6 +53,11 @@ class WeightedRandomSampler : public Sampler { // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. Status GetNextSample(std::unique_ptr *out_buffer) override; + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + private: // A list of weights for each sample. 
std::vector weights_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc index fbba73de21..818b5ab3f4 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -33,7 +33,11 @@ namespace mindspore { namespace dataset { TextFileOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_shuffle_files_(false), + builder_sampler_(nullptr) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr text_file_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, - builder_num_devices_, builder_device_id_); + builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); RETURN_IF_NOT_OK(text_file_op->Init()); *op = std::move(text_file_op); @@ -73,8 +77,9 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr schema, std::vector text_files_list, - int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) - : ParallelOp(num_workers, op_connector_size), + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, + std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), device_id_(device_id), num_devices_(num_device), rows_per_buffer_(rows_per_buffer), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h index 5379263979..5b787d4dad 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "dataset/util/status.h" @@ -112,6 +113,14 @@ class TextFileOp : public ParallelOp { return *this; } + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + private: int32_t builder_device_id_; int32_t builder_num_devices_; @@ -123,6 +132,7 @@ class TextFileOp : public ParallelOp { std::vector builder_text_files_list_; bool builder_shuffle_files_; std::unique_ptr builder_schema_; + std::shared_ptr builder_sampler_; }; // Constructor of TextFileOp @@ -136,9 +146,10 @@ class TextFileOp : public ParallelOp { // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. // @param equal_rows_per_shard - whether or not to get equal rows for each process. + // @param sampler - allow a sampler. 
Only valid if a cache exists in ascendent tree nodes TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); + bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); // Default destructor ~TextFileOp() = default; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index b05fa54978..a2b04bcc01 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -48,7 +48,11 @@ namespace mindspore { namespace dataset { TFReaderOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) { + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_equal_rows_per_shard_(false), + builder_sampler_(nullptr) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_worker_connector_size_ = config_manager->worker_connector_size(); @@ -87,11 +91,6 @@ Status TFReaderOp::Builder::ValidateInputs() const { err_msg += "Number of parallel workers is smaller or equal to 0\n"; } - if (!builder_equal_rows_per_shard_ && - builder_dataset_files_list_.size() < static_cast(builder_num_devices_)) { - err_msg += "Not enough tfrecord files provided\n"; - } - if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { err_msg += "Wrong sharding configs\n"; } @@ -125,7 +124,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) std::shared_ptr new_tf_reader_op = std::make_shared( builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); + builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, + std::move(builder_sampler_)); RETURN_IF_NOT_OK(new_tf_reader_op->Init()); *out_tf_reader_op = std::move(new_tf_reader_op); @@ -136,8 +136,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard) - : ParallelOp(num_workers, op_connector_size), + int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), device_id_(device_id), num_devices_(num_device), rows_per_buffer_(rows_per_buffer), @@ -1018,5 +1018,40 @@ Status TFReaderOp::ComputeColMap() { } return Status::OK(); } + +// During tree prepare phase, operators may have specific post-operations to perform depending on +// their role. 
+Status TFReaderOp::PrepareNodePostAction() {
+  // Run common code from the super class before adding TFReaderOp specific handling
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+
+  // Specific handling for this op: we need to do cache op work, so assign the sampler to the cache.
+  // TF is a special case because it can support file-based sharding/shuffling, or, if there
+  // is a cache, it can also do row-based sampling using the sampler on the cache.
+  // Thus, pass true for the random access op flag when saving the sampler. This is a special case,
+  // since usually a non-mappable dataset would pass false here.
+  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
+
+  // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
+  // a simple producer of all data (no shuffling, sharding, or anything)
+  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
+    device_id_ = 0;
+    num_devices_ = 1;
+    total_rows_ = 0;
+    shuffle_files_ = false;
+    equal_rows_per_shard_ = false;
+    sampler_.reset();  // Normally the SaveSamplerForCache code did this for us, but we passed in true above (see comment)
+  } else {
+    // This sanity check had been delayed until now in the prepare loop.
+    // If we are not in a cache path, then we can validate the file-based sharding config.
+    // If we are in a cache path, there is no file-based sharding, so the check is not correct in that
+    // situation.
+    if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<size_t>(num_devices_)) {
+      RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n");
+    }
+  }
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
index 9d2e38ec6b..9226c4c6c5 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
@@ -153,8 +153,17 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method
+    // @param sampler - The shared pointer to the sampler to be used
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     std::unique_ptr<DataSchema> builder_data_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
     int32_t builder_device_id_;
     int32_t builder_num_devices_;
     int32_t builder_num_workers_;
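// A hypothetical usage sketch of the new SetSampler hook (the file path and the my_sampler
// variable are made up, and SetDatasetFilesList is assumed from the rest of this builder): a
// sampler on a TFReaderOp is only meaningful when a cache op sits above it in the tree.
//
//   std::shared_ptr<TFReaderOp> tf_op;
//   Status rc = TFReaderOp::Builder()
//                   .SetDatasetFilesList({"/data/train-0001.tfrecord"})  // assumed setter
//                   .SetSampler(std::move(my_sampler))  // picked up by the cache op during prepare
//                   .Build(&tf_op);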
@@ -180,10 +189,11 @@ class TFReaderOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+  // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
              int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);
 
   // Default destructor
   ~TFReaderOp() = default;
 
@@ -236,6 +246,12 @@ class TFReaderOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return dataset_files_list_; }
 
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
index 5d9f0ee92c..958aa65b06 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
@@ -88,7 +88,7 @@ Status VOCOp::Builder::SanityCheck() {
 VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
              const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
              int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       decode_(decode),
       row_cnt_(0),
       buf_cnt_(0),
@@ -97,7 +97,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
       folder_path_(folder_path),
       class_index_(class_index),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers_, queue_size);
 }
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
index a0f5eba4d6..89875341ca 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
@@ -274,7 +274,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   TaskType task_type_;
   std::string task_mode_;
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;
 
   WaitPost wp_;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
index 05c224ee2e..259ae8e62b 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
@@ -129,7 +129,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
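// Side note on the flag machinery these prepare-phase changes lean on: kDePrepCache and friends
// are plain bitmasks queried with the project's BitSet/BitClear/BitTest helpers. A minimal
// stand-in for those helpers (simplified signatures, for illustration only) looks like this:
#include <cstdint>

inline void BitSet(uint32_t *bits, uint32_t flag) { *bits |= flag; }
inline void BitClear(uint32_t *bits, uint32_t flag) { *bits &= ~flag; }
inline bool BitTest(uint32_t bits, uint32_t flag) { return (bits & flag) == flag; }

// Usage mirrors ExecutionTree::PrepareNode below: a cache op sets kDePrepCache on the way down
// the tree, leaves beneath it test the flag in their post-actions, and the walk clears it on the
// way back up.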
diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc
index 8dd622912b..2f88ee1795 100644
--- a/mindspore/ccsrc/dataset/engine/execution_tree.cc
+++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc
@@ -88,13 +88,13 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
 }
 
 // A print method typically used for debugging
-void ExecutionTree::Print(std::ostream &out) const {
+void ExecutionTree::Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op) const {
   out << "Execution tree summary:\n"
       << "-----------------------\n";
-  this->PrintNode(out, root_, "", true, false);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, false);
   out << "\nExecution tree operator details:\n"
       << "--------------------------------\n";
-  this->PrintNode(out, root_, "", true, true);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, true);
 }
 
 // A helper function for doing the recursive printing
@@ -269,27 +269,40 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
     RETURN_IF_NOT_OK(this->PrepareNode(i));
   }
 
-  // Then clear the flags from this op now that we have prepared it.
-  BitClear(&prepare_flags_, op_prep_flags);
-
   // No more children, now we execute any prepare actions before going back up
   // the tree on the recursive function
   RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
 
+  // Then clear the flags from this op now that we have prepared it.
+  BitClear(&prepare_flags_, op_prep_flags);
+
   return Status::OK();
 }
 
-// Adds an operator to the repeat stack during prepare phase.
-void ExecutionTree::AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op) { repeat_stack_.push(dataset_op); }
+// Adds an operator to the eoe operator stack during prepare phase.
+void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
 
-// Pops an operator from the repeat stack during prepare phase.
-std::shared_ptr<DatasetOp> ExecutionTree::PopFromRepeatStack() {
+// Pops an operator from the eoe operator stack during prepare phase.
+std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
   std::shared_ptr<DatasetOp> top_op = nullptr;
-  if (!repeat_stack_.empty()) {
-    top_op = repeat_stack_.top();
-    repeat_stack_.pop();
+  if (!eoe_stack_.empty()) {
+    top_op = eoe_stack_.top();
+    eoe_stack_.pop();
   }
   return top_op;
 }
+
+// Adds a sampler to the sampler stack during prepare phase.
+void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
+
+// Pops a sampler from the sampler stack during prepare phase.
+std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
+  std::shared_ptr<Sampler> top_sampler = nullptr;
+  if (!sampler_stack_.empty()) {
+    top_sampler = sampler_stack_.top();
+    sampler_stack_.pop();
+  }
+  return top_sampler;
+}
 }  // namespace dataset
 }  // namespace mindspore
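// A condensed sketch (stand-in types, not part of the patch) of the handshake these stack helpers
// enable: leaf operators push themselves while the prepare walk unwinds, and the first
// repeat/epoch-ctrl operator above them pops everything that was pushed below it.
#include <memory>
#include <stack>
#include <vector>

struct Op {};

std::stack<std::shared_ptr<Op>> eoe_stack;  // plays the role of ExecutionTree::eoe_stack_

void LeafPostAction(const std::shared_ptr<Op> &leaf) {
  eoe_stack.push(leaf);  // make this leaf visible to the next repeat op above it
}

std::vector<std::shared_ptr<Op>> RepeatPostAction(const std::shared_ptr<Op> &repeat_op) {
  std::vector<std::shared_ptr<Op>> eoe_ops;
  while (!eoe_stack.empty()) {  // claim every op pushed below this repeat
    eoe_ops.push_back(eoe_stack.top());
    eoe_stack.pop();
  }
  eoe_stack.push(repeat_op);  // in case one of our ancestors is a repeat too
  return eoe_ops;
}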
diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.h b/mindspore/ccsrc/dataset/engine/execution_tree.h
index b0391bf77b..5ebfa539ad 100644
--- a/mindspore/ccsrc/dataset/engine/execution_tree.h
+++ b/mindspore/ccsrc/dataset/engine/execution_tree.h
@@ -37,7 +37,8 @@ class ExecutionTree {
   // Prepare flags used during tree prepare phase
   enum PrepareFlags {
     kDePrepNone = 0,
-    kDePrepRepeat = 1  // Processing a repeat operation
+    kDePrepRepeat = 1,  // Processing a repeat operation
+    kDePrepCache = 2    // Processing a cache operation
   };
 
   // State flags for the lifecycle of the tree
@@ -118,9 +119,10 @@ class ExecutionTree {
   // @return Status - The error code return
   Status Launch();
 
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  void Print(std::ostream &out) const;
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  /// \param op - The operator to start the printing from (defaults to the root when null)
+  void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const;
 
   // Returns an iterator positioned at the start
   // @return Iterator - The iterator
@@ -199,14 +200,21 @@ class ExecutionTree {
   // @return Status - The error code return
   Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
 
-  // Adds an operator to the repeat stack during prepare phase.
-  // @param op - The dataset op to work add to repeat stack
-  // @return Status - The error code return
-  void AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op);
+  /// Adds an operator to the eoe operator stack during prepare phase.
+  /// \param dataset_op - The dataset op to add to the eoe stack
+  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
+
+  /// Pops an operator from the eoe operator stack during prepare phase.
+  /// \return shared_ptr to the popped operator
+  std::shared_ptr<DatasetOp> PopFromEOEOpStack();
+
+  /// Adds a sampler to the sampler stack during prepare phase.
+  /// \param sampler - The sampler to add to the sampler stack
+  void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
 
-  // Pops an operator from the repeat stack during prepare phase.
-  // @return shared_ptr to the popped operator
-  std::shared_ptr<DatasetOp> PopFromRepeatStack();
+  /// Pops a sampler from the sampler stack during prepare phase.
+  /// \return shared_ptr to the popped sampler
+  std::shared_ptr<Sampler> PopFromSamplerStack();
 
   // Return the pointer to the TaskGroup
   // @return raw pointer to the TaskGroup
@@ -236,9 +246,10 @@ class ExecutionTree {
   int32_t id_count_;                                     // Counter for generating operator id's
   uint32_t prepare_flags_;                               // Flags used during tree prepare
   TreeState tree_state_;                                 // Tracking the current tree state
-  std::stack<std::shared_ptr<DatasetOp>> repeat_stack_;  // A stack used during prepare phase
   std::unique_ptr<Monitor> perf_monitor_;                // Performance Monitor
   std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
+  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;     // A stack used during prepare phase
+  std::stack<std::shared_ptr<Sampler>> sampler_stack_;   // A stack used during prepare phase
 };
 }  // namespace dataset
 }  // namespace mindspore
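// Finally, an end-to-end sketch (stand-in types, not part of the patch) of the sampler handoff
// that SaveSamplerForCache and the new sampler stack implement together: a leaf under a cache
// donates its sampler to the stack, while a non-mappable leaf outside a cache must not carry one.
#include <memory>
#include <stack>
#include <stdexcept>

struct Sampler {};

std::stack<std::shared_ptr<Sampler>> sampler_stack;  // plays the role of ExecutionTree::sampler_stack_

void SaveSamplerForCache(std::shared_ptr<Sampler> &sampler, bool preparing_cache, bool random_access_op) {
  if (sampler) {
    if (preparing_cache) {
      sampler_stack.push(sampler);  // the cache op above will pop this during its own prepare step
    } else if (!random_access_op) {
      // a non-mappable leaf only supports sampling when a cache sits above it in the tree
      throw std::runtime_error("non-mappable leaf op has a sampler but no cache above it");
    }
  }
  if (!random_access_op) {
    sampler.reset();  // the non-mappable leaf itself never samples
  }
}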