| @@ -15,6 +15,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||||
| csv_op.cc | csv_op.cc | ||||
| album_op.cc | album_op.cc | ||||
| mappable_leaf_op.cc | mappable_leaf_op.cc | ||||
| nonmappable_leaf_op.cc | |||||
| ) | ) | ||||
| set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | ||||
| @@ -89,23 +89,11 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim | |||||
| ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | ||||
| ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, | ||||
| bool shuffle_files, int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| rows_per_buffer_(rows_per_buffer), | |||||
| num_rows_per_shard_(0), | |||||
| all_num_rows_(0), | |||||
| num_samples_(num_samples), | |||||
| filename_index_(std::make_unique<StringIndex>()), | |||||
| bool shuffle_files, int32_t num_devices, int32_t device_id) | |||||
| : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size, | |||||
| shuffle_files, num_devices, device_id), | |||||
| clue_files_list_(std::move(clue_files_list)), | clue_files_list_(std::move(clue_files_list)), | ||||
| load_jagged_connector_(true), | |||||
| cols_to_keyword_(cols_to_keyword), | |||||
| shuffle_files_(shuffle_files), | |||||
| finished_reading_dataset_(false), | |||||
| num_devices_(num_device), | |||||
| device_id_(device_id), | |||||
| load_io_block_queue_(true) { | |||||
| worker_connector_size_ = worker_connector_size; | |||||
| } | |||||
| cols_to_keyword_(cols_to_keyword) {} | |||||
| Status ClueOp::Init() { | Status ClueOp::Init() { | ||||
| RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); | RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); | ||||
| @@ -119,16 +107,6 @@ Status ClueOp::Init() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ClueOp::Reset() { | |||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||||
| load_jagged_connector_ = true; | |||||
| load_io_block_queue_ = true; | |||||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) { | Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) { | ||||
| nlohmann::json cursor = js; | nlohmann::json cursor = js; | ||||
| for (int i = 0; i < key_chain.size(); i++) { | for (int i = 0; i < key_chain.size(); i++) { | ||||
| @@ -161,8 +139,7 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_c | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id) { | |||||
| Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||||
| std::ifstream handle(file); | std::ifstream handle(file); | ||||
| if (!handle.is_open()) { | if (!handle.is_open()) { | ||||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); | RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); | ||||
| @@ -228,93 +205,6 @@ Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, con | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ClueOp::operator()() { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||||
| // Register the wait post before launching any threads; this fixes the problem where | |||||
| // registration occasionally fails when a thread exits abnormally. | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||||
| // launch one thread, responsible for filling IoBlockQueue | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this), "", id())); | |||||
| RETURN_IF_NOT_OK( | |||||
| tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1), "", id())); | |||||
| // must be called after launching workers. | |||||
| TaskManager::FindMe()->Post(); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| while (!finished_reading_dataset_) { | |||||
| int64_t buffer_id = 0; | |||||
| int32_t workers_done = 0; | |||||
| int64_t rows_read = 0; | |||||
| load_io_block_queue_ = true; | |||||
| while (workers_done < num_workers_) { | |||||
| std::unique_ptr<DataBuffer> buffer; | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||||
| if (buffer->eoe()) { | |||||
| workers_done++; | |||||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||||
| } | |||||
| rows_read += buffer->NumRows(); | |||||
| buffer->set_id(buffer_id++); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||||
| } else { | |||||
| // end of epoch | |||||
| load_jagged_connector_ = false; | |||||
| load_io_block_queue_ = false; | |||||
| } | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | |||||
| NotifyToFillIOBlockQueue(); | |||||
| } else { | |||||
| jagged_buffer_connector_->DoReset(); | |||||
| buffer_id = 0; | |||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ClueOp::WorkerEntry(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::unique_ptr<FilenameBlock> io_block; | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| while (!io_block->eof()) { | |||||
| if (!io_block->eoe()) { | |||||
| if (load_jagged_connector_) { | |||||
| std::string filename; | |||||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||||
| int64_t start_offset = io_block->GetStartOffset(); | |||||
| int64_t end_offset = io_block->GetEndOffset(); | |||||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||||
| } | |||||
| } else { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void ClueOp::Print(std::ostream &out, bool show_all) const { | void ClueOp::Print(std::ostream &out, bool show_all) const { | ||||
| if (!show_all) { | if (!show_all) { | ||||
| @@ -326,7 +216,7 @@ void ClueOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| ParallelOp::Print(out, show_all); | ParallelOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ | |||||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_ | |||||
| << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | ||||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n"; | << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n"; | ||||
| for (int i = 0; i < clue_files_list_.size(); ++i) { | for (int i = 0; i < clue_files_list_.size(); ++i) { | ||||
| @@ -336,52 +226,6 @@ void ClueOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| } | } | ||||
| // Pops an element from a queue in io_block_queues | |||||
| Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes an element to a queue in io_block_queues | |||||
| Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||||
| return Status::OK(); | |||||
| } | |||||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||||
| std::mt19937 rng(seed); | |||||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||||
| } | |||||
| Status ClueOp::WaitToFillIOBlockQueue() { | |||||
| // must be called first if called by a worker spawned by the task group | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<int64_t> i_keys; | |||||
| if (shuffle_files_) { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| i_keys.push_back(it.key()); | |||||
| } | |||||
| } | |||||
| uint32_t seed = 0; | |||||
| while (true) { | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||||
| io_block_queue_wait_post_.Clear(); | |||||
| if (finished_reading_dataset_) { | |||||
| break; | |||||
| } | |||||
| if (shuffle_files_) { | |||||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||||
| } | |||||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | ||||
| int32_t queue_index = 0; | int32_t queue_index = 0; | ||||
| int64_t pre_count = 0; | int64_t pre_count = 0; | ||||
| @@ -431,66 +275,18 @@ Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||||
| bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count) { | |||||
| *start_offset = 0; | |||||
| *end_offset = 0; | |||||
| bool push = false; | |||||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||||
| if (device_id_ + 1 < 0) { | |||||
| MS_LOG(ERROR) << "Device id is invalid"; | |||||
| return false; | |||||
| } | |||||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||||
| *start_offset = start_index - pre_count; | |||||
| push = true; | |||||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| if (pre_count >= start_index && pre_count < end_index) { | |||||
| *start_offset = 0; | |||||
| push = true; | |||||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| return push; | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| Status ClueOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ClueOp::CalculateNumRowsPerShard() { | Status ClueOp::CalculateNumRowsPerShard() { | ||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | ||||
| int64_t count = CountTotalRows(it.value()); | int64_t count = CountTotalRows(it.value()); | ||||
| filename_numrows_[it.value()] = count; | filename_numrows_[it.value()] = count; | ||||
| all_num_rows_ += count; | |||||
| num_rows_ += count; | |||||
| } | } | ||||
| if (all_num_rows_ == 0) { | |||||
| if (num_rows_ == 0) { | |||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API."); | "Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API."); | ||||
| } | } | ||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_)); | |||||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -513,17 +309,6 @@ int64_t ClueOp::CountTotalRows(const std::string &file) { | |||||
| return count; | return count; | ||||
| } | } | ||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| Status ClueOp::PostEndOfData() { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) { | Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) { | ||||
| std::shared_ptr<ClueOp> op; | std::shared_ptr<ClueOp> op; | ||||
| *count = 0; | *count = 0; | ||||
| @@ -26,6 +26,8 @@ | |||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -34,7 +36,7 @@ using ColKeyMap = std::map<std::string, std::vector<std::string>>; | |||||
| class JaggedConnector; | class JaggedConnector; | ||||
| class ClueOp : public ParallelOp { | |||||
| class ClueOp : public NonMappableLeafOp { | |||||
| public: | public: | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| @@ -150,18 +152,7 @@ class ClueOp : public ParallelOp { | |||||
| // Instantiates the internal queues and connectors | // Instantiates the internal queues and connectors | ||||
| // @return Status - the error code returned | // @return Status - the error code returned | ||||
| Status Init(); | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| // @return Status - the error code returned. | |||||
| Status operator()() override; | |||||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| // @return Status - the error code returned. | |||||
| Status Reset() override; | |||||
| Status Init() override; | |||||
| // Get total rows in files. | // Get total rows in files. | ||||
| // @param files - all clue files. | // @param files - all clue files. | ||||
| @@ -178,72 +169,28 @@ class ClueOp : public ParallelOp { | |||||
| std::string Name() const override { return "ClueOp"; } | std::string Name() const override { return "ClueOp"; } | ||||
| private: | private: | ||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Reads a clue file and loads the data into multiple buffers. | // Reads a clue file and loads the data into multiple buffers. | ||||
| // @param file - the file to read. | // @param file - the file to read. | ||||
| // @param start_offset - the start offset of file. | // @param start_offset - the start offset of file. | ||||
| // @param end_offset - the end offset of file. | // @param end_offset - the end offset of file. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||||
| // Fill the IOBlockQueue. | // Fill the IOBlockQueue. | ||||
| // @param i_keys - keys of file to fill to the IOBlockQueue | // @param i_keys - keys of file to fill to the IOBlockQueue | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||||
| // Notifies the thread which called FillIoBlockQueue to resume execution | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Decide whether a file holds rows for this device's shard and, if so, which row range to read. | |||||
| // @param file_name - File name. | |||||
| // @param start_offset - Output: row offset within the file where this shard's rows begin. | |||||
| // @param end_offset - Output: row offset within the file where this shard's rows end. | |||||
| // @param pre_count - Total rows of previous files. | |||||
| // @return bool - true if the file should be pushed to the block queue. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override; | |||||
| // Calculate number of rows in each shard. | // Calculate number of rows in each shard. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status CalculateNumRowsPerShard(); | |||||
| Status CalculateNumRowsPerShard() override; | |||||
| // Count number of rows in each file. | // Count number of rows in each file. | ||||
| // @param file - clue file name. | // @param file - clue file name. | ||||
| // @return int64_t - the total number of rows in file. | // @return int64_t - the total number of rows in file. | ||||
| int64_t CountTotalRows(const std::string &file); | int64_t CountTotalRows(const std::string &file); | ||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t); | Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t); | ||||
| @@ -251,22 +198,7 @@ class ClueOp : public ParallelOp { | |||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| int32_t device_id_; | |||||
| bool shuffle_files_; | |||||
| bool finished_reading_dataset_; | |||||
| int32_t num_devices_; | |||||
| int64_t rows_per_buffer_; | |||||
| bool load_io_block_queue_; | |||||
| int64_t num_rows_per_shard_; | |||||
| int64_t all_num_rows_; | |||||
| int64_t num_samples_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| std::vector<std::string> clue_files_list_; | std::vector<std::string> clue_files_list_; | ||||
| WaitPost io_block_queue_wait_post_; | |||||
| std::shared_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| bool load_jagged_connector_; | |||||
| ColKeyMap cols_to_keyword_; | ColKeyMap cols_to_keyword_; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -71,25 +71,13 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim, | |||||
| const std::vector<std::shared_ptr<BaseRecord>> &column_default, | const std::vector<std::shared_ptr<BaseRecord>> &column_default, | ||||
| const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, | ||||
| int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, | ||||
| int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| int32_t num_devices, int32_t device_id) | |||||
| : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size, | |||||
| shuffle_files, num_devices, device_id), | |||||
| csv_files_list_(std::move(csv_files_list)), | csv_files_list_(std::move(csv_files_list)), | ||||
| field_delim_(field_delim), | field_delim_(field_delim), | ||||
| column_default_list_(column_default), | column_default_list_(column_default), | ||||
| column_name_list_(column_name), | |||||
| rows_per_buffer_(rows_per_buffer), | |||||
| num_rows_per_shard_(0), | |||||
| all_num_rows_(0), | |||||
| num_samples_(num_samples), | |||||
| filename_index_(std::make_unique<StringIndex>()), | |||||
| load_jagged_connector_(true), | |||||
| shuffle_files_(shuffle_files), | |||||
| finished_reading_dataset_(false), | |||||
| num_devices_(num_device), | |||||
| device_id_(device_id), | |||||
| load_io_block_queue_(true) { | |||||
| worker_connector_size_ = worker_connector_size; | |||||
| } | |||||
| column_name_list_(column_name) {} | |||||
| Status CsvOp::Init() { | Status CsvOp::Init() { | ||||
| RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); | RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); | ||||
| @@ -98,14 +86,13 @@ Status CsvOp::Init() { | |||||
| io_block_queues_.Init(num_workers_, safe_queue_size); | io_block_queues_.Init(num_workers_, safe_queue_size); | ||||
| RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); | RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); | ||||
| jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||||
| jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| CsvOp::CsvParser::CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, | |||||
| char field_delim, std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, | |||||
| std::string file_path) | |||||
| CsvOp::CsvParser::CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim, | |||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path) | |||||
| : worker_id_(worker_id), | : worker_id_(worker_id), | ||||
| buffer_connector_(connector), | buffer_connector_(connector), | ||||
| csv_rows_per_buffer_(rows_per_buffer), | csv_rows_per_buffer_(rows_per_buffer), | ||||
| @@ -221,6 +208,7 @@ int CsvOp::CsvParser::PutRow(int c) { | |||||
| if (cur_row_ == csv_rows_per_buffer_) { | if (cur_row_ == csv_rows_per_buffer_) { | ||||
| cur_buffer_->set_tensor_table(std::move(tensor_table_)); | cur_buffer_->set_tensor_table(std::move(tensor_table_)); | ||||
| buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); | buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); | ||||
| cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); | ||||
| @@ -499,19 +487,9 @@ Status CsvOp::CsvParser::InitCsvParser() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CsvOp::Reset() { | |||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||||
| load_jagged_connector_ = true; | |||||
| load_io_block_queue_ = true; | |||||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id) { | |||||
| CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file); | |||||
| Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||||
| CsvParser csv_parser(worker_id, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_, | |||||
| file); | |||||
| csv_parser.SetStartOffset(start_offset); | csv_parser.SetStartOffset(start_offset); | ||||
| csv_parser.SetEndOffset(end_offset); | csv_parser.SetEndOffset(end_offset); | ||||
| std::ifstream ifs; | std::ifstream ifs; | ||||
| @@ -546,93 +524,6 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status CsvOp::operator()() { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||||
| // Register the wait post before launching any threads; this fixes the problem where | |||||
| // registration occasionally fails when a thread exits abnormally. | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||||
| // launch one thread, responsible for filling IoBlockQueue | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this), "", id())); | |||||
| RETURN_IF_NOT_OK( | |||||
| tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1), "", id())); | |||||
| // must be called after launching workers. | |||||
| TaskManager::FindMe()->Post(); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| while (!finished_reading_dataset_) { | |||||
| int64_t buffer_id = 0; | |||||
| int32_t workers_done = 0; | |||||
| int64_t rows_read = 0; | |||||
| load_io_block_queue_ = true; | |||||
| while (workers_done < num_workers_) { | |||||
| std::unique_ptr<DataBuffer> buffer; | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||||
| if (buffer->eoe()) { | |||||
| workers_done++; | |||||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||||
| } | |||||
| rows_read += buffer->NumRows(); | |||||
| buffer->set_id(buffer_id++); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||||
| } else { | |||||
| // end of epoch | |||||
| load_jagged_connector_ = false; | |||||
| load_io_block_queue_ = false; | |||||
| } | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | |||||
| NotifyToFillIOBlockQueue(); | |||||
| } else { | |||||
| jagged_buffer_connector_->DoReset(); | |||||
| buffer_id = 0; | |||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::WorkerEntry(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::unique_ptr<FilenameBlock> io_block; | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| while (!io_block->eof()) { | |||||
| if (!io_block->eoe()) { | |||||
| if (load_jagged_connector_) { | |||||
| std::string filename; | |||||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||||
| int64_t start_offset = io_block->GetStartOffset(); | |||||
| int64_t end_offset = io_block->GetEndOffset(); | |||||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||||
| } | |||||
| } else { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void CsvOp::Print(std::ostream &out, bool show_all) const { | void CsvOp::Print(std::ostream &out, bool show_all) const { | ||||
| if (!show_all) { | if (!show_all) { | ||||
| @@ -644,7 +535,7 @@ void CsvOp::Print(std::ostream &out, bool show_all) const { | |||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| ParallelOp::Print(out, show_all); | ParallelOp::Print(out, show_all); | ||||
| // Then show any custom derived-internal stuff | // Then show any custom derived-internal stuff | ||||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ | |||||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_ | |||||
| << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | ||||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; | << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; | ||||
| for (int i = 0; i < csv_files_list_.size(); ++i) { | for (int i = 0; i < csv_files_list_.size(); ++i) { | ||||
| @@ -654,52 +545,6 @@ void CsvOp::Print(std::ostream &out, bool show_all) const { | |||||
| } | } | ||||
| } | } | ||||
| // Pops an element from a queue in io_block_queues | |||||
| Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes an element to a queue in io_block_queues | |||||
| Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||||
| return Status::OK(); | |||||
| } | |||||
// Shuffles the given keys in place with a std::mt19937 engine constructed from `seed`.
// @param i_keys - the vector of keys to reorder; must not be null.
// @param seed - value used to seed the engine.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 engine(seed);
  auto first = i_keys->begin();
  auto last = i_keys->end();
  std::shuffle(first, last, engine);
}
| Status CsvOp::WaitToFillIOBlockQueue() { | |||||
| // must be called first if called by worker spanwed by taskgroup | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<int64_t> i_keys; | |||||
| if (shuffle_files_) { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| i_keys.push_back(it.key()); | |||||
| } | |||||
| } | |||||
| uint32_t seed = 0; | |||||
| while (true) { | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||||
| io_block_queue_wait_post_.Clear(); | |||||
| if (finished_reading_dataset_) { | |||||
| break; | |||||
| } | |||||
| if (shuffle_files_) { | |||||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||||
| } | |||||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | ||||
| int32_t queue_index = 0; | int32_t queue_index = 0; | ||||
| int64_t pre_count = 0; | int64_t pre_count = 0; | ||||
| @@ -749,72 +594,24 @@ Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||||
| bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count) { | |||||
| *start_offset = 0; | |||||
| *end_offset = 0; | |||||
| bool push = false; | |||||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||||
| if (device_id_ + 1 < 0) { | |||||
| MS_LOG(ERROR) << "Device id is invalid"; | |||||
| return false; | |||||
| } | |||||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||||
| *start_offset = start_index - pre_count; | |||||
| push = true; | |||||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| if (pre_count >= start_index && pre_count < end_index) { | |||||
| *start_offset = 0; | |||||
| push = true; | |||||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| return push; | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| Status CsvOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::CalculateNumRowsPerShard() { | Status CsvOp::CalculateNumRowsPerShard() { | ||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | ||||
| int64_t count = CountTotalRows(it.value()); | int64_t count = CountTotalRows(it.value()); | ||||
| filename_numrows_[it.value()] = count; | filename_numrows_[it.value()] = count; | ||||
| all_num_rows_ += count; | |||||
| num_rows_ += count; | |||||
| } | } | ||||
| if (all_num_rows_ == 0) { | |||||
| if (num_rows_ == 0) { | |||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format."); | "Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format."); | ||||
| } | } | ||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_)); | |||||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t CsvOp::CountTotalRows(const std::string &file) { | int64_t CsvOp::CountTotalRows(const std::string &file) { | ||||
| CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file); | |||||
| CsvParser csv_parser(0, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_, file); | |||||
| std::ifstream ifs; | std::ifstream ifs; | ||||
| ifs.open(file, std::ifstream::in); | ifs.open(file, std::ifstream::in); | ||||
| if (!ifs.is_open()) { | if (!ifs.is_open()) { | ||||
| @@ -835,17 +632,6 @@ int64_t CsvOp::CountTotalRows(const std::string &file) { | |||||
| return csv_parser.GetTotalRows(); | return csv_parser.GetTotalRows(); | ||||
| } | } | ||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| Status CsvOp::PostEndOfData() { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) { | Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) { | ||||
| std::shared_ptr<CsvOp> op; | std::shared_ptr<CsvOp> op; | ||||
| *count = 0; | *count = 0; | ||||
| @@ -26,6 +26,8 @@ | |||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | #include "minddata/dataset/engine/datasetops/source/io_block.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -34,7 +36,7 @@ const size_t CSV_BUFFER_SIZE = 4096; | |||||
| using StringIndex = AutoIndexObj<std::string>; | using StringIndex = AutoIndexObj<std::string>; | ||||
| class JaggedConnector; | class JaggedConnector; | ||||
| class CsvOp : public ParallelOp { | |||||
| class CsvOp : public NonMappableLeafOp { | |||||
| public: | public: | ||||
| enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; | enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; | ||||
| @@ -63,7 +65,7 @@ class CsvOp : public ParallelOp { | |||||
| public: | public: | ||||
| CsvParser() = delete; | CsvParser() = delete; | ||||
| CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim, | |||||
| CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim, | |||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path); | std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path); | ||||
| ~CsvParser() = default; | ~CsvParser() = default; | ||||
| @@ -125,7 +127,7 @@ class CsvOp : public ParallelOp { | |||||
| int CatchException(int c); | int CatchException(int c); | ||||
| int32_t worker_id_; | int32_t worker_id_; | ||||
| std::shared_ptr<JaggedConnector> buffer_connector_; | |||||
| JaggedConnector *buffer_connector_; | |||||
| int64_t csv_rows_per_buffer_; | int64_t csv_rows_per_buffer_; | ||||
| const char csv_field_delim_; | const char csv_field_delim_; | ||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_; | std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_; | ||||
| @@ -274,18 +276,7 @@ class CsvOp : public ParallelOp { | |||||
| // Instantiates the internal queues and connectors | // Instantiates the internal queues and connectors | ||||
| // @return Status - the error code returned | // @return Status - the error code returned | ||||
| Status Init(); | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| // @return Status - the error code returned. | |||||
| Status operator()() override; | |||||
| // Overrides base class reset method. Cleans up any state info from it's previous execution | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| // @return Status - the error code returned. | |||||
| Status Reset() override; | |||||
| Status Init() override; | |||||
| // Get total rows in files. | // Get total rows in files. | ||||
| // @param files - all csv files. | // @param files - all csv files. | ||||
| @@ -303,11 +294,6 @@ class CsvOp : public ParallelOp { | |||||
| std::string Name() const override { return "CsvOp"; } | std::string Name() const override { return "CsvOp"; } | ||||
| private: | private: | ||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Parses a single row and puts the data into a tensor table. | // Parses a single row and puts the data into a tensor table. | ||||
| // @param line - the content of the row. | // @param line - the content of the row. | ||||
| // @param tensor_table - the tensor table to put the parsed data in. | // @param tensor_table - the tensor table to put the parsed data in. | ||||
| @@ -321,61 +307,22 @@ class CsvOp : public ParallelOp { | |||||
| // @param end_offset - the end offset of file. | // @param end_offset - the end offset of file. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||||
| // Fill the IOBlockQueue. | // Fill the IOBlockQueue. | ||||
| // @param i_keys - keys of file to fill to the IOBlockQueue | // @param i_keys - keys of file to fill to the IOBlockQueue | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||||
| // Notifies the thread which called FillIoBlockQueue to resume execution | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Select file and push it to the block queue. | |||||
| // @param file_name - File name. | |||||
| // @param start_offset - If file contains the first sample of data. | |||||
| // @param end_offset - If file contains the end sample of data. | |||||
| // @param pre_count - Total rows of previous files. | |||||
| // @return Status - the error code returned. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override; | |||||
| // Calculate number of rows in each shard. | // Calculate number of rows in each shard. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status CalculateNumRowsPerShard(); | |||||
| Status CalculateNumRowsPerShard() override; | |||||
| // Count number of rows in each file. | // Count number of rows in each file. | ||||
| // @param filename - csv file name. | // @param filename - csv file name. | ||||
| // @return int64_t - the total number of rows in file. | // @return int64_t - the total number of rows in file. | ||||
| int64_t CountTotalRows(const std::string &file); | int64_t CountTotalRows(const std::string &file); | ||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| @@ -394,22 +341,7 @@ class CsvOp : public ParallelOp { | |||||
| // @return bool - whether column name identical in all CSV files | // @return bool - whether column name identical in all CSV files | ||||
| bool ColumnNameValidate(); | bool ColumnNameValidate(); | ||||
| int32_t device_id_; | |||||
| bool shuffle_files_; | |||||
| bool finished_reading_dataset_; | |||||
| int32_t num_devices_; | |||||
| int64_t rows_per_buffer_; | |||||
| bool load_io_block_queue_; | |||||
| int64_t num_rows_per_shard_; | |||||
| int64_t all_num_rows_; | |||||
| int64_t num_samples_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| std::vector<std::string> csv_files_list_; | std::vector<std::string> csv_files_list_; | ||||
| WaitPost io_block_queue_wait_post_; | |||||
| std::shared_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| bool load_jagged_connector_; | |||||
| char field_delim_; | char field_delim_; | ||||
| std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; | std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; | ||||
| std::vector<std::string> column_name_list_; | std::vector<std::string> column_name_list_; | ||||
| @@ -0,0 +1,304 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/io_block.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | |||||
| #include "minddata/dataset/engine/jagged_connector.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/util/task_manager.h" | |||||
| #include "minddata/dataset/util/wait_post.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, | |||||
| int64_t total_num_rows, int32_t op_connector_size, bool shuffle_files, | |||||
| int32_t num_devices, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| device_id_(device_id), | |||||
| num_devices_(num_devices), | |||||
| rows_per_buffer_(rows_per_buffer), | |||||
| filename_index_(std::make_unique<StringIndex>()), | |||||
| load_io_block_queue_(true), | |||||
| load_jagged_connector_(true), | |||||
| total_rows_(total_num_rows), | |||||
| finished_reading_dataset_(false), | |||||
| shuffle_files_(shuffle_files), | |||||
| num_rows_per_shard_(0), | |||||
| num_rows_(0) { | |||||
| worker_connector_size_ = worker_connector_size; | |||||
| } | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| Status NonMappableLeafOp::operator()() { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||||
| // Put here to avoid register failed when Worker_Entry thread exits unexpected | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||||
| // launch one thread, responsible for filling mIOBlockQueue | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&NonMappableLeafOp::WaitToFillIOBlockQueue, this), "", id())); | |||||
| // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading | |||||
| // data from disk into buffers | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers( | |||||
| num_workers_, std::bind(&NonMappableLeafOp::WorkerEntry, this, std::placeholders::_1), "", id())); | |||||
| // must be called after launching workers. workers can't be spawned after this post, | |||||
| // so workers have to be kept alive until the end of the program | |||||
| TaskManager::FindMe()->Post(); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| while (!finished_reading_dataset_) { | |||||
| int64_t buffer_id = 0; | |||||
| int32_t workers_done = 0; | |||||
| int64_t rows_read = 0; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = true; | |||||
| } | |||||
| while (workers_done < num_workers_) { | |||||
| std::unique_ptr<DataBuffer> fetched_buffer; | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer)); | |||||
| if (fetched_buffer->eoe()) { | |||||
| workers_done++; | |||||
| } else if (total_rows_ == 0 || rows_read < total_rows_) { | |||||
| // we need to push a buffer | |||||
| if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) { | |||||
| // this is last buffer we need, and we only need a part of it | |||||
| int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read); | |||||
| RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove)); | |||||
| } | |||||
| rows_read += fetched_buffer->NumRows(); | |||||
| fetched_buffer->set_id(buffer_id); | |||||
| buffer_id++; | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); | |||||
| } else { | |||||
| // IOBlockQueue thread needs to: | |||||
| // -stop pushing stuff to IOBlockQueue | |||||
| // -call PostEndOfEpoch (will send EOE) | |||||
| // -wait for reset | |||||
| // | |||||
| // Worker threads need to: | |||||
| // -stop reading the file they are currently reading and throw it away | |||||
| // -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE) | |||||
| // | |||||
| // Master thread needs to: | |||||
| // -tell IOBlockQueue thread to stop pushing | |||||
| // -tell worker threads to stop reading the file tey are currently reading | |||||
| // -keep pulling until EOE | |||||
| // don't think we need a lock for now | |||||
| load_jagged_connector_ = false; | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = false; | |||||
| } | |||||
| } | |||||
| // all workers finished reading for this epoch, and we have read all the data from all workers | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | |||||
| NotifyToFillIOBlockQueue(); | |||||
| } else { | |||||
| jagged_buffer_connector_->DoReset(); | |||||
| buffer_id = 0; | |||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||||
| return Status::OK(); | |||||
| } | |||||
| // The entry point for when workers are launched. | |||||
| Status NonMappableLeafOp::WorkerEntry(int32_t worker_id) { | |||||
| // must be called first if called by worker spawned by taskgroup | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::unique_ptr<FilenameBlock> io_block; | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| while (!io_block->eof()) { | |||||
| if (!io_block->eoe()) { | |||||
| if (load_jagged_connector_) { | |||||
| std::string filename; | |||||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||||
| int64_t start_offset = io_block->GetStartOffset(); | |||||
| int64_t end_offset = io_block->GetEndOffset(); | |||||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||||
| MS_LOG(DEBUG) << Name() << " operator worker " << worker_id << " loaded file " << filename << "."; | |||||
| } | |||||
| } else { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| Status NonMappableLeafOp::PostEndOfData() { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| Status NonMappableLeafOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. | |||||
| void NonMappableLeafOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||||
| // Pops an element from a queue in io_block_queues | |||||
| Status NonMappableLeafOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes an element to a queue in io_block_queues | |||||
| Status NonMappableLeafOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Overrides base class reset method. Cleans up any state info from it's previous execution and | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| Status NonMappableLeafOp::Reset() { | |||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||||
| // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true | |||||
| load_jagged_connector_ = true; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = true; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| return Status::OK(); | |||||
| } | |||||
| bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, | |||||
| int64_t *end_offset, const int64_t &pre_count) { | |||||
| *start_offset = 0; | |||||
| *end_offset = 0; | |||||
| bool push = false; | |||||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||||
| if (device_id_ + 1 < 0) { | |||||
| MS_LOG(ERROR) << "Device id is invalid"; | |||||
| return false; | |||||
| } | |||||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||||
| *start_offset = start_index - pre_count; | |||||
| push = true; | |||||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| if (pre_count >= start_index && pre_count < end_index) { | |||||
| *start_offset = 0; | |||||
| push = true; | |||||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| return push; | |||||
| } | |||||
| void NonMappableLeafOp::ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||||
| std::mt19937 rng(seed); | |||||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||||
| } | |||||
| Status NonMappableLeafOp::WaitToFillIOBlockQueue() { | |||||
| // must be called first if called by worker spanwed by taskgroup | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<int64_t> i_keys; | |||||
| if (shuffle_files_) { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| i_keys.push_back(it.key()); | |||||
| } | |||||
| } | |||||
| uint32_t seed = 0; | |||||
| while (true) { | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||||
| io_block_queue_wait_post_.Clear(); | |||||
| if (finished_reading_dataset_) { | |||||
| break; | |||||
| } | |||||
| if (shuffle_files_) { | |||||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||||
| } | |||||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,177 @@ | |||||
| /** | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <map> | |||||
| #include "minddata/dataset/util/wait_post.h" | |||||
| #include "minddata/dataset/util/auto_index.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||||
| namespace dataengine { | |||||
| class Example; | |||||
| class Feature; | |||||
| class BytesList; | |||||
| } // namespace dataengine | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| template <typename T> | |||||
| class Queue; | |||||
| template <class T> | |||||
| class Connector; | |||||
| class JaggedConnector; | |||||
| class FilenameBlock; | |||||
| using StringIndex = AutoIndexObj<std::string>; | |||||
| class NonMappableLeafOp : public ParallelOp { | |||||
| public: | |||||
| // Constructor of NonMappableLeafOp, the shared base of the file-based leaf ops. | |||||
| // @note Invoked from the constructors of the derived ops; the class itself is abstract. | |||||
| // @param num_workers - number of worker threads reading data from the dataset files. | |||||
| // @param worker_connector_size - size of each internal queue. | |||||
| // @param rows_per_buffer - number of rows that a full buffer will contain. | |||||
| // @param total_num_rows - number of rows to read. | |||||
| // @param op_connector_size - size of each queue in the connector that the child operator pulls from. | |||||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||||
| // @param num_devices - number of shards the rows are divided into. | |||||
| // @param device_id - index of the shard read by this op. | |||||
| // File lists, schemas and column selections are owned by the derived classes, not by this constructor. | |||||
| NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, | |||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); | |||||
| // Default destructor | |||||
| ~NonMappableLeafOp() = default; | |||||
| // Instantiates the internal queues and connectors. Must be implemented by each derived op. | |||||
| // @return Status - the error code returned. | |||||
| virtual Status Init() = 0; | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work. | |||||
| // @return Status - the error code returned. | |||||
| Status operator()() override; | |||||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| // @return Status - the error code returned. | |||||
| Status Reset() override; | |||||
| // Getter for the number of rows per buffer. | |||||
| int64_t rows_per_buffer() const { return rows_per_buffer_; } | |||||
| // Op name getter | |||||
| // @return Name of the current Op | |||||
| std::string Name() const override { return "NonMappableLeafOp"; } | |||||
| protected: | |||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // Pushes an end-of-epoch control indicator onto the IOBlockQueue for each worker; a worker that pops it waits for the next epoch and then resumes. | |||||
| // @param queue_index - index of the queue to start posting from. NOTE(review): wrap-around order is defined in the .cc, confirm there. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Reads a dataset file and loads the data into multiple buffers. Implemented by each derived op. | |||||
| // @param filename - the file to read. | |||||
| // @param start_offset - the start offset of file. | |||||
| // @param end_offset - the end offset of file. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0; | |||||
| // Decides whether a file holds rows of this op's shard and should be pushed to the block queue. | |||||
| // @param file_name - File name. | |||||
| // @param start_offset - output: offset of the first row to read in the file. | |||||
| // @param end_offset - output: offset at which reading of the file ends. | |||||
| // @param pre_count - Total rows of previous files. | |||||
| // @return bool - true if the file should be pushed to the block queue. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Calculate number of rows in each shard. | |||||
| // @return Status - the error code returned. | |||||
| virtual Status CalculateNumRowsPerShard() = 0; | |||||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed); | |||||
| // Fill the IOBlockQueue. | |||||
| // @param i_keys - keys of file to fill to the IOBlockQueue | |||||
| // @return Status - the error code returned. | |||||
| virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0; | |||||
| int32_t device_id_; | |||||
| int32_t num_devices_; | |||||
| bool load_jagged_connector_; | |||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| bool finished_reading_dataset_; | |||||
| int64_t total_rows_; | |||||
| int64_t rows_per_buffer_; | |||||
| WaitPost io_block_queue_wait_post_; | |||||
| bool load_io_block_queue_; | |||||
| std::mutex load_io_block_queue_mutex_; | |||||
| std::unique_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| bool shuffle_files_; | |||||
| int64_t num_rows_per_shard_; | |||||
| int64_t num_rows_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_ | |||||
| @@ -77,23 +77,11 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | ||||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | ||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| device_id_(device_id), | |||||
| num_devices_(num_device), | |||||
| rows_per_buffer_(rows_per_buffer), | |||||
| total_rows_(total_rows), | |||||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id) | |||||
| : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_rows, op_connector_size, | |||||
| shuffle_files, num_devices, device_id), | |||||
| text_files_list_(std::move(text_files_list)), | text_files_list_(std::move(text_files_list)), | ||||
| shuffle_files_(shuffle_files), | |||||
| data_schema_(std::move(schema)), | |||||
| all_num_rows_(0), | |||||
| num_rows_per_shard_(0), | |||||
| filename_index_(std::make_unique<StringIndex>()), | |||||
| finished_reading_dataset_(false), | |||||
| load_io_block_queue_(true), | |||||
| load_jagged_connector_(true) { | |||||
| worker_connector_size_ = worker_connector_size; | |||||
| } | |||||
| data_schema_(std::move(schema)) {} | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void TextFileOp::Print(std::ostream &out, bool show_all) const { | void TextFileOp::Print(std::ostream &out, bool show_all) const { | ||||
| @@ -129,16 +117,6 @@ Status TextFileOp::Init() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Prepares the op for another iteration: re-enables loading, resets the ParallelOp base, | |||||
| // then notifies the thread that refills the IO block queue. | |||||
| // @return Status - the error code returned. | |||||
| Status TextFileOp::Reset() { | |||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||||
| load_jagged_connector_ = true; | |||||
| load_io_block_queue_ = true; | |||||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) { | Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) { | ||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); | RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); | ||||
| @@ -146,8 +124,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTa | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id) { | |||||
| Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||||
| std::ifstream handle(file); | std::ifstream handle(file); | ||||
| if (!handle.is_open()) { | if (!handle.is_open()) { | ||||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); | RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); | ||||
| @@ -197,106 +174,6 @@ Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Worker loop: pops FilenameBlocks from this worker's queue until an EOF block arrives. | |||||
| // An EOE block is forwarded as an EOE buffer; a file block is loaded unless loading is switched off. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status TextFileOp::WorkerEntry(int32_t worker_id) { | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::unique_ptr<FilenameBlock> block; | |||||
| for (;;) { | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &block)); | |||||
| if (block->eof()) { | |||||
| break; | |||||
| } | |||||
| if (block->eoe()) { | |||||
| std::unique_ptr<DataBuffer> eoe_marker = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_marker))); | |||||
| continue; | |||||
| } | |||||
| if (!load_jagged_connector_) { | |||||
| continue; | |||||
| } | |||||
| std::string file_path; | |||||
| RETURN_IF_NOT_OK(block->GetFilename(&file_path, *filename_index_)); | |||||
| int64_t first_row = block->GetStartOffset(); | |||||
| int64_t last_row = block->GetEndOffset(); | |||||
| RETURN_IF_NOT_OK(LoadFile(file_path, first_row, last_row, worker_id)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pops the front element of the queue at `index` in io_block_queues_ into `out_block`; returns the queue's error status on failure. | |||||
| Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Moves `io_block` onto the queue at `index` in io_block_queues_; returns the queue's error status on failure. | |||||
| Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes one EOF control block onto the IO block queue of every worker (queues 0 .. num_workers_ - 1). | |||||
| // When a worker pops this control block, it leaves its loop and shuts itself down gracefully. | |||||
| Status TextFileOp::PostEndOfData() { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes one EOE control block per worker, starting at queue_index and wrapping around the worker queues. | |||||
| // When a worker pops this control block, it will wait until the next epoch starts and then resume execution. | |||||
| Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Shuffles the file keys in place with a std::mt19937 engine seeded by `seed`. | |||||
| static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||||
| std::mt19937 engine(seed); | |||||
| std::shuffle(std::begin(*i_keys), std::end(*i_keys), engine); | |||||
| } | |||||
| // Decides whether rows of file_name fall into this device's shard and, if so, which row range to read. | |||||
| // @param file_name - file to check. | |||||
| // @param start_offset - output: first row of the file to read (0 when nothing is selected). | |||||
| // @param end_offset - output: row at which reading stops (0 when nothing is selected). | |||||
| // @param pre_count - total rows of the files that precede this one. | |||||
| // @return bool - true if the file should be pushed to the block queue. | |||||
| bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count) { | |||||
| *start_offset = 0; | |||||
| *end_offset = 0; | |||||
| const int64_t shard_begin = device_id_ * num_rows_per_shard_; | |||||
| if (device_id_ + 1 < 0) { | |||||
| MS_LOG(ERROR) << "Device id is invalid"; | |||||
| return false; | |||||
| } | |||||
| const int64_t shard_end = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||||
| bool selected = false; | |||||
| // Case 1: the file starts at or before the shard and reaches into it. | |||||
| if (pre_count <= shard_begin && pre_count + filename_numrows_[file_name] > shard_begin) { | |||||
| *start_offset = shard_begin - pre_count; | |||||
| selected = true; | |||||
| const bool reaches_end = pre_count < shard_end && pre_count + filename_numrows_[file_name] >= shard_end; | |||||
| *end_offset = reaches_end ? shard_end - pre_count : filename_numrows_[file_name]; | |||||
| } | |||||
| // Case 2: the file starts inside the shard. | |||||
| if (pre_count >= shard_begin && pre_count < shard_end) { | |||||
| *start_offset = 0; | |||||
| selected = true; | |||||
| const bool reaches_end = pre_count + filename_numrows_[file_name] >= shard_end; | |||||
| *end_offset = reaches_end ? shard_end - pre_count : filename_numrows_[file_name]; | |||||
| } | |||||
| return selected; | |||||
| } | |||||
| Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | ||||
| int32_t queue_index = 0; | int32_t queue_index = 0; | ||||
| int64_t pre_count = 0; | int64_t pre_count = 0; | ||||
| @@ -346,101 +223,6 @@ Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TextFileOp::WaitToFillIOBlockQueue() { | |||||
| // Must be called first when run by a worker spawned by the task group. After that: wait for a notification, stop if the dataset is finished, otherwise reshuffle the keys (GetSeed() with one device, an incrementing counter otherwise) and refill the IO block queue. | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<int64_t> i_keys; | |||||
| if (shuffle_files_) { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| i_keys.push_back(it.key()); | |||||
| } | |||||
| } | |||||
| uint32_t seed = 0; | |||||
| while (true) { | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||||
| io_block_queue_wait_post_.Clear(); | |||||
| if (finished_reading_dataset_) { | |||||
| break; | |||||
| } | |||||
| if (shuffle_files_) { | |||||
| ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||||
| } | |||||
| RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Sets the wait post that WaitToFillIOBlockQueue blocks on, letting that thread resume. | |||||
| void TextFileOp::NotifyToFillIOBlockQueue() { | |||||
| io_block_queue_wait_post_.Set(); | |||||
| } | |||||
| Status TextFileOp::operator()() { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||||
| // Register the wait post before launching any thread: registering afterwards can fail | |||||
| // occasionally when a thread has already exited abnormally. | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||||
| // Launch one thread, responsible for filling the IoBlockQueue. | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this), Name(), id())); | |||||
| // Launch num_workers_ workers that read data from disk into buffers. | |||||
| RETURN_IF_NOT_OK( | |||||
| tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1), Name(), id())); | |||||
| // Must be called after launching workers. | |||||
| TaskManager::FindMe()->Post(); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| while (!finished_reading_dataset_) { | |||||
| int64_t buffer_id = 0; | |||||
| int32_t workers_done = 0; | |||||
| int64_t rows_read = 0; | |||||
| load_io_block_queue_ = true; | |||||
| while (workers_done < num_workers_) { | |||||
| std::unique_ptr<DataBuffer> buffer; | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||||
| if (buffer->eoe()) { | |||||
| workers_done++; | |||||
| } else if (total_rows_ == 0 || rows_read < total_rows_) { | |||||
| if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { | |||||
| int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); | |||||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||||
| } | |||||
| rows_read += buffer->NumRows(); | |||||
| buffer->set_id(buffer_id++); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); | |||||
| } else { | |||||
| // Row limit (total_rows_) reached: this buffer is not forwarded and further loading is switched off for the epoch. | |||||
| load_jagged_connector_ = false; | |||||
| load_io_block_queue_ = false; | |||||
| } | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | |||||
| NotifyToFillIOBlockQueue(); | |||||
| } else { | |||||
| jagged_buffer_connector_->DoReset(); | |||||
| buffer_id = 0; | |||||
| // Self-reset to start a new iteration. | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||||
| return Status::OK(); | |||||
| } | |||||
| int64_t TextFileOp::CountTotalRows(const std::string &file) { | int64_t TextFileOp::CountTotalRows(const std::string &file) { | ||||
| std::ifstream handle(file); | std::ifstream handle(file); | ||||
| if (!handle.is_open()) { | if (!handle.is_open()) { | ||||
| @@ -463,14 +245,14 @@ Status TextFileOp::CalculateNumRowsPerShard() { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | ||||
| int64_t count = CountTotalRows(it.value()); | int64_t count = CountTotalRows(it.value()); | ||||
| filename_numrows_[it.value()] = count; | filename_numrows_[it.value()] = count; | ||||
| all_num_rows_ += count; | |||||
| num_rows_ += count; | |||||
| } | } | ||||
| if (all_num_rows_ == 0) { | |||||
| if (num_rows_ == 0) { | |||||
| RETURN_STATUS_UNEXPECTED( | RETURN_STATUS_UNEXPECTED( | ||||
| "Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API."); | "Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API."); | ||||
| } | } | ||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_)); | |||||
| num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_)); | |||||
| MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "minddata/dataset/util/auto_index.h" | #include "minddata/dataset/util/auto_index.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||||
| #include "minddata/dataset/util/queue.h" | #include "minddata/dataset/util/queue.h" | ||||
| #include "minddata/dataset/util/wait_post.h" | #include "minddata/dataset/util/wait_post.h" | ||||
| #include "minddata/dataset/engine/jagged_connector.h" | #include "minddata/dataset/engine/jagged_connector.h" | ||||
| @@ -35,7 +36,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| using StringIndex = AutoIndexObj<std::string>; | using StringIndex = AutoIndexObj<std::string>; | ||||
| class TextFileOp : public ParallelOp { | |||||
| class TextFileOp : public NonMappableLeafOp { | |||||
| public: | public: | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| @@ -150,18 +151,7 @@ class TextFileOp : public ParallelOp { | |||||
| // Instantiates the internal queues and connectors | // Instantiates the internal queues and connectors | ||||
| // @return Status - the error code returned | // @return Status - the error code returned | ||||
| Status Init(); | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| // @return Status - the error code returned. | |||||
| Status operator()() override; | |||||
| // Overrides base class reset method. Cleans up any state info from it's previous execution | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| // @return Status - the error code returned. | |||||
| Status Reset() override; | |||||
| Status Init() override; | |||||
| // Get total rows in files. | // Get total rows in files. | ||||
| // @param files - all text files. | // @param files - all text files. | ||||
| @@ -178,11 +168,6 @@ class TextFileOp : public ParallelOp { | |||||
| std::vector<std::string> FileNames() { return text_files_list_; } | std::vector<std::string> FileNames() { return text_files_list_; } | ||||
| private: | private: | ||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Parses a single row and puts the data into a tensor table. | // Parses a single row and puts the data into a tensor table. | ||||
| // @param line - the content of the row. | // @param line - the content of the row. | ||||
| // @param tensor_table - the tensor table to put the parsed data in. | // @param tensor_table - the tensor table to put the parsed data in. | ||||
| @@ -196,82 +181,28 @@ class TextFileOp : public ParallelOp { | |||||
| // @param end_offset - the end offset of file. | // @param end_offset - the end offset of file. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t worker_id); | |||||
| Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||||
| // Calculate number of rows in each shard. | // Calculate number of rows in each shard. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status CalculateNumRowsPerShard(); | |||||
| Status CalculateNumRowsPerShard() override; | |||||
| // Count number of rows in each file. | // Count number of rows in each file. | ||||
| // @param filename - text file name. | // @param filename - text file name. | ||||
| // @return int64_t - the total number of rows in file. | // @return int64_t - the total number of rows in file. | ||||
| int64_t CountTotalRows(const std::string &file); | int64_t CountTotalRows(const std::string &file); | ||||
| // Notifies the thread which called FillIoBlockQueue to resume execution | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| // Fill the IOBlockQueue. | // Fill the IOBlockQueue. | ||||
| // @para i_keys - keys of file to fill to the IOBlockQueue | // @para i_keys - keys of file to fill to the IOBlockQueue | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); | |||||
| // Select file and push it to the block queue. | |||||
| // @param file_name - File name. | |||||
| // @param start_file - If file contains the first sample of data. | |||||
| // @param end_file - If file contains the end sample of data. | |||||
| // @param pre_count - Total rows of previous files. | |||||
| // @return Status - the error code returned. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override; | |||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| int32_t device_id_; | |||||
| int32_t num_devices_; | |||||
| int64_t rows_per_buffer_; | |||||
| int64_t total_rows_; | |||||
| std::vector<std::string> text_files_list_; | std::vector<std::string> text_files_list_; | ||||
| bool shuffle_files_; | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| int64_t all_num_rows_; | |||||
| int64_t num_rows_per_shard_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| WaitPost io_block_queue_wait_post_; | |||||
| bool finished_reading_dataset_; | |||||
| bool load_io_block_queue_; | |||||
| bool load_jagged_connector_; | |||||
| std::unique_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -126,26 +126,14 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op) | |||||
| TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, | TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, | ||||
| int64_t total_num_rows, std::vector<std::string> dataset_files_list, | int64_t total_num_rows, std::vector<std::string> dataset_files_list, | ||||
| std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, | ||||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device, | |||||
| std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_devices, | |||||
| int32_t device_id, bool equal_rows_per_shard) | int32_t device_id, bool equal_rows_per_shard) | ||||
| : ParallelOp(num_workers, op_connector_size), | |||||
| device_id_(device_id), | |||||
| num_devices_(num_device), | |||||
| rows_per_buffer_(rows_per_buffer), | |||||
| total_rows_(total_num_rows), | |||||
| : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_num_rows, op_connector_size, | |||||
| shuffle_files, num_devices, device_id), | |||||
| dataset_files_list_(std::move(dataset_files_list)), | dataset_files_list_(std::move(dataset_files_list)), | ||||
| columns_to_load_(std::move(columns_to_load)), | columns_to_load_(std::move(columns_to_load)), | ||||
| finished_reading_dataset_(false), | |||||
| shuffle_files_(shuffle_files), | |||||
| data_schema_(std::move(data_schema)), | data_schema_(std::move(data_schema)), | ||||
| filename_index_(std::make_unique<StringIndex>()), | |||||
| load_io_block_queue_(true), | |||||
| load_jagged_connector_(true), | |||||
| num_rows_(0), | |||||
| num_rows_per_shard_(0), | |||||
| equal_rows_per_shard_(equal_rows_per_shard) { | |||||
| worker_connector_size_ = worker_connector_size; | |||||
| } | |||||
| equal_rows_per_shard_(equal_rows_per_shard) {} | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void TFReaderOp::Print(std::ostream &out, bool show_all) const { | void TFReaderOp::Print(std::ostream &out, bool show_all) const { | ||||
| @@ -222,194 +210,6 @@ Status TFReaderOp::CalculateNumRowsPerShard() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| Status TFReaderOp::operator()() { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||||
| // Register here to avoid a registration failure if a WorkerEntry thread exits unexpectedly | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); | |||||
| // launch one thread, responsible for filling the IOBlockQueue | |||||
| RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this), "", id())); | |||||
| // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading | |||||
| // data from disk into buffers | |||||
| RETURN_IF_NOT_OK( | |||||
| tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1), "", id())); | |||||
| // must be called after launching workers. workers can't be spawned after this post, | |||||
| // so workers have to be kept alive until the end of the program | |||||
| TaskManager::FindMe()->Post(); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| while (!finished_reading_dataset_) { | |||||
| int64_t buffer_id = 0; | |||||
| int32_t workers_done = 0; | |||||
| int64_t rows_read = 0; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = true; | |||||
| } | |||||
| while (workers_done < num_workers_) { | |||||
| std::unique_ptr<DataBuffer> fetched_buffer; | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer)); | |||||
| if (fetched_buffer->eoe()) { | |||||
| workers_done++; | |||||
| } else if (total_rows_ == 0 || rows_read < total_rows_) { | |||||
| // we need to push a buffer | |||||
| if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) { | |||||
| // this is last buffer we need, and we only need a part of it | |||||
| int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read); | |||||
| RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove)); | |||||
| } | |||||
| rows_read += fetched_buffer->NumRows(); | |||||
| fetched_buffer->set_id(buffer_id); | |||||
| buffer_id++; | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); | |||||
| } else { | |||||
| // user specified number of rows they want, and we read enough rows | |||||
| // | |||||
| // IOBlockQueue thread needs to: | |||||
| // -stop pushing stuff to IOBlockQueue | |||||
| // -call PostEndOfEpoch (will send EOE) | |||||
| // -wait for reset | |||||
| // | |||||
| // Worker threads need to: | |||||
| // -stop reading the file they are currently reading and throw it away | |||||
| // -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE) | |||||
| // | |||||
| // Master thread needs to: | |||||
| // -tell IOBlockQueue thread to stop pushing | |||||
| // -tell worker threads to stop reading the file they are currently reading | |||||
| // -keep pulling until EOE | |||||
| // don't think we need a lock for now | |||||
| load_jagged_connector_ = false; | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = false; | |||||
| } | |||||
| } | |||||
| // all workers finished reading for this epoch, and we have read all the data from all workers | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | |||||
| if (IsLastIteration()) { | |||||
| finished_reading_dataset_ = true; | |||||
| NotifyToFillIOBlockQueue(); | |||||
| } else { | |||||
| jagged_buffer_connector_->DoReset(); | |||||
| buffer_id = 0; | |||||
| // Self-reset to start a new iteration | |||||
| RETURN_IF_NOT_OK(Reset()); | |||||
| } | |||||
| UpdateRepeatAndEpochCounter(); | |||||
| } | |||||
| std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF); | |||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); | |||||
| RETURN_IF_NOT_OK(PostEndOfData()); | |||||
| return Status::OK(); | |||||
| } | |||||
| // static local-only helper function | |||||
| static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) { | |||||
| std::mt19937 rng(seed); | |||||
| std::shuffle(i_keys->begin(), i_keys->end(), rng); | |||||
| } | |||||
| // The entry point for when workers are launched. | |||||
| Status TFReaderOp::WorkerEntry(int32_t worker_id) { | |||||
| // must be called first if called by worker spawned by taskgroup | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::unique_ptr<FilenameBlock> io_block; | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| while (!io_block->eof()) { | |||||
| if (!io_block->eoe()) { | |||||
| if (load_jagged_connector_) { | |||||
| std::string filename; | |||||
| RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); | |||||
| int64_t start_offset = io_block->GetStartOffset(); | |||||
| int64_t end_offset = io_block->GetEndOffset(); | |||||
| RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); | |||||
| MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; | |||||
| } | |||||
| } else { | |||||
| std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE); | |||||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); | |||||
| } | |||||
| RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| Status TFReaderOp::PostEndOfData() { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { | |||||
| for (int i = 0; i < num_workers_; ++i) { | |||||
| std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe); | |||||
| RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count) { | |||||
| *start_offset = 0; | |||||
| *end_offset = 0; | |||||
| bool push = false; | |||||
| int64_t start_index = device_id_ * num_rows_per_shard_; | |||||
| if (device_id_ + 1 < 0) { | |||||
| MS_LOG(ERROR) << "Device id is invalid."; | |||||
| return false; | |||||
| } | |||||
| int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; | |||||
| if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { | |||||
| *start_offset = start_index - pre_count; | |||||
| push = true; | |||||
| if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| if (pre_count >= start_index && pre_count < end_index) { | |||||
| *start_offset = 0; | |||||
| push = true; | |||||
| if (pre_count + filename_numrows_[file_name] >= end_index) { | |||||
| *end_offset = end_index - pre_count; | |||||
| } else { | |||||
| *end_offset = filename_numrows_[file_name]; | |||||
| } | |||||
| } | |||||
| return push; | |||||
| } | |||||
| Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { | Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { | ||||
| int32_t queue_index = 0; | int32_t queue_index = 0; | ||||
| @@ -506,58 +306,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| Status TFReaderOp::WaitToFillIOBlockQueue() { | |||||
| // must be called first if called by worker spawned by taskgroup | |||||
| TaskManager::FindMe()->Post(); | |||||
| std::vector<int64_t> i_keys; | |||||
| // Generate a vector of keys that we can shuffle | |||||
| if (shuffle_files_) { | |||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | |||||
| i_keys.push_back(it.key()); | |||||
| } | |||||
| } | |||||
| uint32_t seed = 0; | |||||
| while (true) { | |||||
| RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); | |||||
| io_block_queue_wait_post_.Clear(); | |||||
| if (finished_reading_dataset_) { | |||||
| break; | |||||
| } | |||||
| if (shuffle_files_) { | |||||
| shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); | |||||
| RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys)); | |||||
| } else { // shuffle_files_ == false | |||||
| RETURN_IF_NOT_OK(FillIOBlockNoShuffle()); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. | |||||
| void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } | |||||
| // Pops an element from a queue in io_block_queues | |||||
| Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Pushes an element to a queue in io_block_queues | |||||
| Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) { | |||||
| RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); | |||||
| return Status::OK(); | |||||
| } | |||||
| // Reads a tf_file file and loads the data into multiple buffers. | // Reads a tf_file file and loads the data into multiple buffers. | ||||
| Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t &worker_id) { | |||||
| Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||||
| std::ifstream reader; | std::ifstream reader; | ||||
| reader.open(filename); | reader.open(filename); | ||||
| if (!reader) { | if (!reader) { | ||||
| @@ -698,24 +448,6 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| Status TFReaderOp::Reset() { | |||||
| MS_LOG(DEBUG) << Name() << " performing a self-reset."; | |||||
| // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true | |||||
| load_jagged_connector_ = true; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | |||||
| load_io_block_queue_ = true; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ParallelOp::Reset()); | |||||
| NotifyToFillIOBlockQueue(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, | Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, | ||||
| int32_t *num_elements, std::shared_ptr<Tensor> *tensor) { | int32_t *num_elements, std::shared_ptr<Tensor> *tensor) { | ||||
| // kBytesList can map to the following DE types ONLY! | // kBytesList can map to the following DE types ONLY! | ||||
| @@ -1029,6 +761,12 @@ Status TFReaderOp::ComputeColMap() { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { | |||||
| if (shuffle_files_) { | |||||
| return FillIOBlockShuffle(i_keys); | |||||
| } | |||||
| return FillIOBlockNoShuffle(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/data_schema.h" | #include "minddata/dataset/engine/data_schema.h" | ||||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | #include "minddata/dataset/engine/datasetops/parallel_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h" | |||||
| namespace dataengine { | namespace dataengine { | ||||
| class Example; | class Example; | ||||
| @@ -51,7 +52,7 @@ class FilenameBlock; | |||||
| using StringIndex = AutoIndexObj<std::string>; | using StringIndex = AutoIndexObj<std::string>; | ||||
| class TFReaderOp : public ParallelOp { | |||||
| class TFReaderOp : public NonMappableLeafOp { | |||||
| public: | public: | ||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| @@ -195,21 +196,7 @@ class TFReaderOp : public ParallelOp { | |||||
| // Instantiates the internal queues and connectors. | // Instantiates the internal queues and connectors. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status Init(); | |||||
| // Class functor operator () override. | |||||
| // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will | |||||
| // provide the master loop that drives the logic for performing the work | |||||
| // @return Status - the error code returned. | |||||
| Status operator()() override; | |||||
| // Overrides base class reset method. Cleans up any state info from its previous execution and | |||||
| // reinitializes itself so that it can be executed again, as if it was just created. | |||||
| // @return Status - the error code returned. | |||||
| Status Reset() override; | |||||
| // Getter method | |||||
| int64_t rows_per_buffer() const { return rows_per_buffer_; } | |||||
| Status Init() override; | |||||
| // Reads all the provided tf_file files and counts the total number of rows. filenames will | // Reads all the provided tf_file files and counts the total number of rows. filenames will | ||||
| // first be sectioned into equal parts, then sections are read in parallel. If threads is | // first be sectioned into equal parts, then sections are read in parallel. If threads is | ||||
| @@ -233,48 +220,13 @@ class TFReaderOp : public ParallelOp { | |||||
| static bool ValidateFirstRowCrc(const std::string &filename); | static bool ValidateFirstRowCrc(const std::string &filename); | ||||
| private: | private: | ||||
| // The entry point for when workers are launched. | |||||
| // @param worker_id - the id of the worker that is executing this function. | |||||
| // @return Status - the error code returned. | |||||
| Status WorkerEntry(int32_t worker_id) override; | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. | |||||
| // When the worker pops this control indicator, it will shut itself down gracefully. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfData(); | |||||
| // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker | |||||
| // pops this control indicator, it will wait until the next epoch starts and then resume execution. | |||||
| // @return Status - the error code returned. | |||||
| Status PostEndOfEpoch(int32_t queue_index); | |||||
| // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. | |||||
| // @return Status - the error code returned. | |||||
| Status WaitToFillIOBlockQueue(); | |||||
| // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. | |||||
| void NotifyToFillIOBlockQueue(); | |||||
| // Pops an element from a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to pop from. | |||||
| // @param out_block - the popped element. | |||||
| // @return Status - the error code returned. | |||||
| Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block); | |||||
| // Pushes an element to a queue in IOBlockQueue. | |||||
| // @param index - the index of the queue to push to. | |||||
| // @param io_block - the element to push onto the queue. | |||||
| // @return Status - the error code returned. | |||||
| Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block); | |||||
| // Reads a tf_file file and loads the data into multiple buffers. | // Reads a tf_file file and loads the data into multiple buffers. | ||||
| // @param filename - the tf_file file to read. | // @param filename - the tf_file file to read. | ||||
| // @param start_offset - the start offset of file. | // @param start_offset - the start offset of file. | ||||
| // @param end_offset - the end offset of file. | // @param end_offset - the end offset of file. | ||||
| // @param worker_id - the id of the worker that is executing this function. | // @param worker_id - the id of the worker that is executing this function. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, | |||||
| const int32_t &worker_id); | |||||
| Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||||
| // Parses a single row and puts the data into a tensor table. | // Parses a single row and puts the data into a tensor table. | ||||
| // @param tf_file - the row to be parsed. | // @param tf_file - the row to be parsed. | ||||
| @@ -339,6 +291,11 @@ class TFReaderOp : public ParallelOp { | |||||
| // @return int64_t - the total number of rows of files read. | // @return int64_t - the total number of rows of files read. | ||||
| static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin, | static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin, | ||||
| const int64_t end); | const int64_t end); | ||||
| protected: | |||||
| Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override; | |||||
| private: | |||||
| // Fill IO block queue if shuffle is true | // Fill IO block queue if shuffle is true | ||||
| // @param i_keys - shuffle keys. | // @param i_keys - shuffle keys. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| @@ -351,43 +308,18 @@ class TFReaderOp : public ParallelOp { | |||||
| */ | */ | ||||
| Status FillIOBlockNoShuffle(); | Status FillIOBlockNoShuffle(); | ||||
| // Decide whether a file should be pushed to the block queue. | |||||
| // @param file_name - File name. | |||||
| // @param start_offset - Output: start offset to read within the file. | |||||
| // @param end_offset - Output: end offset to read within the file. | |||||
| // @param pre_count - Total rows of previous files. | |||||
| // @return bool - true if the file should be pushed to the block queue. | |||||
| bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, | |||||
| const int64_t &pre_count); | |||||
| // Calculate number of rows in each shard. | // Calculate number of rows in each shard. | ||||
| // @return Status - the error code returned. | // @return Status - the error code returned. | ||||
| Status CalculateNumRowsPerShard(); | |||||
| Status CalculateNumRowsPerShard() override; | |||||
| // Private function for computing the assignment of the column name map. | // Private function for computing the assignment of the column name map. | ||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| int32_t device_id_; | |||||
| int32_t num_devices_; | |||||
| int64_t rows_per_buffer_; | |||||
| int64_t total_rows_; | |||||
| std::vector<std::string> dataset_files_list_; | std::vector<std::string> dataset_files_list_; | ||||
| std::vector<std::string> columns_to_load_; | std::vector<std::string> columns_to_load_; | ||||
| bool finished_reading_dataset_; | |||||
| bool shuffle_files_; | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::unique_ptr<StringIndex> filename_index_; | |||||
| bool load_io_block_queue_; | |||||
| bool load_jagged_connector_; | |||||
| std::unique_ptr<JaggedConnector> jagged_buffer_connector_; | |||||
| QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; | |||||
| WaitPost io_block_queue_wait_post_; | |||||
| std::mutex load_io_block_queue_mutex_; | |||||
| std::map<std::string, int64_t> filename_numrows_; | |||||
| int64_t num_rows_; | |||||
| int64_t num_rows_per_shard_; | |||||
| bool equal_rows_per_shard_; | bool equal_rows_per_shard_; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||