Browse Source

Add NonMappableLeafOp and unify the TFReader, TextFile, CSV and CLUE ops

pull/13291/head
hesham 4 years ago
parent
commit
c877ac255b
11 changed files with 559 additions and 1259 deletions
  1. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
  2. +9
    -224
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
  3. +7
    -75
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h
  4. +16
    -230
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  5. +9
    -77
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
  6. +304
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc
  7. +177
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h
  8. +8
    -226
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc
  9. +6
    -75
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h
  10. +11
    -273
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc
  11. +11
    -79
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt View File

@@ -15,6 +15,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
csv_op.cc csv_op.cc
album_op.cc album_op.cc
mappable_leaf_op.cc mappable_leaf_op.cc
nonmappable_leaf_op.cc
) )


set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES


+ 9
- 224
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc View File

@@ -89,23 +89,11 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim


ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size, ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
bool shuffle_files, int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
shuffle_files, num_devices, device_id),
clue_files_list_(std::move(clue_files_list)), clue_files_list_(std::move(clue_files_list)),
load_jagged_connector_(true),
cols_to_keyword_(cols_to_keyword),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
cols_to_keyword_(cols_to_keyword) {}


Status ClueOp::Init() { Status ClueOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
@@ -119,16 +107,6 @@ Status ClueOp::Init() {
return Status::OK(); return Status::OK();
} }


// Prepares the op for another pass over the data.
// Re-enables both loading flags, resets the ParallelOp base class, and then signals the
// thread that fills the IOBlockQueue.
// @return Status - the error code returned.
Status ClueOp::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";

  // Turn loading back on before the filler thread is woken up.
  load_jagged_connector_ = true;
  load_io_block_queue_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());

  NotifyToFillIOBlockQueue();
  return Status::OK();
}

Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) { Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
nlohmann::json cursor = js; nlohmann::json cursor = js;
for (int i = 0; i < key_chain.size(); i++) { for (int i = 0; i < key_chain.size(); i++) {
@@ -161,8 +139,7 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_c
return Status::OK(); return Status::OK();
} }


Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream handle(file); std::ifstream handle(file);
if (!handle.is_open()) { if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
@@ -228,93 +205,6 @@ Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, con
return Status::OK(); return Status::OK();
} }


// Class functor: the master loop of the op.
// Launches the IOBlockQueue filler thread and the worker threads, then repeatedly drains the
// jagged connector into the output connector until the last iteration has been produced.
// Each iteration ends with an EOE buffer; the whole run ends with an EOF buffer.
// @return Status - the error code returned.
Status ClueOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // Register the wait-post before launching any thread: registering afterwards can fail
  // occasionally when a thread exits abnormally.
  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));

  // A single thread is responsible for filling the IOBlockQueue.
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this), "", id()));

  // The worker threads each run WorkerEntry with their own worker id.
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1), "", id()));

  // Must be called after launching workers.
  TaskManager::FindMe()->Post();
  NotifyToFillIOBlockQueue();

  while (!finished_reading_dataset_) {
    int64_t next_buffer_id = 0;
    int32_t finished_workers = 0;
    int64_t emitted_rows = 0;
    load_io_block_queue_ = true;

    // Keep popping until every worker has delivered its EOE for this iteration.
    while (finished_workers < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));

      if (buffer->eoe()) {
        ++finished_workers;
        continue;
      }

      // num_samples_ == 0 means "no limit"; otherwise stop forwarding once the limit is reached.
      const bool quota_reached = (num_samples_ != 0) && (emitted_rows >= num_samples_);
      if (quota_reached) {
        // End of epoch: stop loading; the buffer just popped is discarded.
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
        continue;
      }

      // Trim the buffer when forwarding it whole would exceed the sample limit.
      if ((num_samples_ > 0) && (emitted_rows + buffer->NumRows() > num_samples_)) {
        const int64_t surplus_rows = buffer->NumRows() - (num_samples_ - emitted_rows);
        RETURN_IF_NOT_OK(buffer->SliceOff(surplus_rows));
      }
      emitted_rows += buffer->NumRows();
      buffer->set_id(next_buffer_id++);
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
    }

    auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    if (!IsLastIteration()) {
      jagged_buffer_connector_->DoReset();
      // Self-reset to start a new iteration.
      RETURN_IF_NOT_OK(Reset());
    } else {
      finished_reading_dataset_ = true;
      // Wake the filler thread so it can observe the finished flag and exit.
      NotifyToFillIOBlockQueue();
    }
    UpdateRepeatAndEpochCounter();
  }

  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  RETURN_IF_NOT_OK(PostEndOfData());
  return Status::OK();
}

// The entry point for when workers are launched.
// Pops IO blocks from this worker's queue until an EOF block arrives. A file block is loaded
// through LoadFile (only while load_jagged_connector_ is set); an EOE block is answered with an
// EOE buffer on the jagged connector.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status ClueOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));

  while (!io_block->eof()) {
    if (io_block->eoe()) {
      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    } else if (load_jagged_connector_) {
      std::string filename;
      RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
      const int64_t start_offset = io_block->GetStartOffset();
      const int64_t end_offset = io_block->GetEndOffset();
      RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }
  return Status::OK();
}

// A print method typically used for debugging // A print method typically used for debugging
void ClueOp::Print(std::ostream &out, bool show_all) const { void ClueOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) { if (!show_all) {
@@ -326,7 +216,7 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all); ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff // Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n"; << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
for (int i = 0; i < clue_files_list_.size(); ++i) { for (int i = 0; i < clue_files_list_.size(); ++i) {
@@ -336,52 +226,6 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
} }
} }


// Pops the front element of one queue in io_block_queues_.
// @param index - the index of the queue to pop from.
// @param out_block - receives the popped element.
// @return Status - the error code returned by the queue.
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Appends an element to one queue in io_block_queues_.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue; ownership is moved into the queue.
// @return Status - the error code returned by the queue.
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Shuffles the file keys in place with a std::mt19937 engine seeded by `seed`.
// @param i_keys - the keys to shuffle; must not be null.
// @param seed - the seed handed to the random engine.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 engine(seed);
  std::vector<int64_t> &keys = *i_keys;
  std::shuffle(keys.begin(), keys.end(), engine);
}

// Body of the thread that fills the IOBlockQueue.
// Sleeps on io_block_queue_wait_post_ and, every time it is signalled, (re)fills the queue.
// Exits once finished_reading_dataset_ has been set.
// @return Status - the error code returned.
Status ClueOp::WaitToFillIOBlockQueue() {
  // Must be called first when run by a worker spawned by the task group.
  TaskManager::FindMe()->Post();

  // The key list is only populated when files are shuffled; otherwise it stays empty.
  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }

  uint32_t seed = 0;
  for (;;) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      return Status::OK();
    }

    if (shuffle_files_) {
      // With one device the global seed is used; otherwise a local counter is bumped per fill.
      uint32_t shuffle_seed;
      if (num_devices_ == 1) {
        shuffle_seed = GetSeed();
      } else {
        shuffle_seed = ++seed;
      }
      ShuffleKeys(&i_keys, shuffle_seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
}

Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0; int32_t queue_index = 0;
int64_t pre_count = 0; int64_t pre_count = 0;
@@ -431,66 +275,18 @@ Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK(); return Status::OK();
} }


// Signals io_block_queue_wait_post_, the wait-post that WaitToFillIOBlockQueue blocks on.
void ClueOp::NotifyToFillIOBlockQueue() {
  io_block_queue_wait_post_.Set();
}

// Decides whether a file overlaps the row range of this device's shard and, if so, which rows
// of the file belong to the shard. The shard covers rows
// [device_id_ * num_rows_per_shard_, (device_id_ + 1) * num_rows_per_shard_).
// @param file_name - the file to test; its row count is read from filename_numrows_.
// @param start_offset - output, first row of the file to read (0 when the file is not pushed).
// @param end_offset - output, end row of the file to read (0 when the file is not pushed).
// @param pre_count - total rows of the files that precede this one.
// @return bool - true if the file should be pushed to the block queue.
bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                      const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;

  const int64_t shard_begin = device_id_ * num_rows_per_shard_;
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }
  const int64_t shard_end = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;

  bool should_push = false;

  // Case 1: the file begins at or before the shard start and extends past it.
  if (pre_count <= shard_begin) {
    const int64_t file_rows = filename_numrows_[file_name];
    if (pre_count + file_rows > shard_begin) {
      should_push = true;
      *start_offset = shard_begin - pre_count;
      const bool reaches_shard_end = (pre_count < shard_end) && (pre_count + file_rows >= shard_end);
      *end_offset = reaches_shard_end ? shard_end - pre_count : file_rows;
    }
  }

  // Case 2: the file begins inside the shard. When both cases match, this one wins.
  if (pre_count >= shard_begin && pre_count < shard_end) {
    should_push = true;
    *start_offset = 0;
    const int64_t file_rows = filename_numrows_[file_name];
    *end_offset = (pre_count + file_rows >= shard_end) ? shard_end - pre_count : file_rows;
  }

  return should_push;
}

// Pushes one end-of-epoch control block onto every worker's IO block queue, starting at
// queue_index and wrapping around. When a worker pops this control indicator, it will wait
// until the next epoch starts and then resume execution.
// @param queue_index - the queue that receives the first EOE block.
// @return Status - the error code returned.
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
  for (int offset = 0; offset < num_workers_; ++offset) {
    const auto target_queue = (queue_index + offset) % num_workers_;
    auto eoe_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue(target_queue, std::move(eoe_block)));
  }
  return Status::OK();
}

Status ClueOp::CalculateNumRowsPerShard() { Status ClueOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value()); int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count; filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
} }
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API."); "Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API.");
} }


num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK(); return Status::OK();
} }
@@ -513,17 +309,6 @@ int64_t ClueOp::CountTotalRows(const std::string &file) {
return count; return count;
} }


// Pushes one end-of-data control block onto every worker's IO block queue.
// When a worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status ClueOp::PostEndOfData() {
  for (int worker = 0; worker < num_workers_; ++worker) {
    auto eof_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::move(eof_block)));
  }
  return Status::OK();
}

Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) { Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
std::shared_ptr<ClueOp> op; std::shared_ptr<ClueOp> op;
*count = 0; *count = 0;


+ 7
- 75
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h View File

@@ -26,6 +26,8 @@


#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -34,7 +36,7 @@ using ColKeyMap = std::map<std::string, std::vector<std::string>>;


class JaggedConnector; class JaggedConnector;


class ClueOp : public ParallelOp {
class ClueOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@@ -150,18 +152,7 @@ class ClueOp : public ParallelOp {


// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init();

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;

// Overrides base class reset method. Cleans up any state info from its previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;


// Get total rows in files. // Get total rows in files.
// @param files - all clue files. // @param files - all clue files.
@@ -178,72 +169,28 @@ class ClueOp : public ParallelOp {
std::string Name() const override { return "ClueOp"; } std::string Name() const override { return "ClueOp"; }


private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;

// Reads a clue file and loads the data into multiple buffers. // Reads a clue file and loads the data into multiple buffers.
// @param file - the file to read. // @param file - the file to read.
// @param start_offset - the start offset of file. // @param start_offset - the start offset of file.
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);

// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;


// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @param i_keys - keys of file to fill to the IOBlockQueue // @param i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();

// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;


// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;


// Count number of rows in each file. // Count number of rows in each file.
// @param filename - clue file name. // @param filename - clue file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);


// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();

// @return Status - the error code returned. // @return Status - the error code returned.
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t); Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);


@@ -251,22 +198,7 @@ class ClueOp : public ParallelOp {
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;


int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_; std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_; ColKeyMap cols_to_keyword_;
}; };
} // namespace dataset } // namespace dataset


+ 16
- 230
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc View File

@@ -71,25 +71,13 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer, const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files, int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
shuffle_files, num_devices, device_id),
csv_files_list_(std::move(csv_files_list)), csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim), field_delim_(field_delim),
column_default_list_(column_default), column_default_list_(column_default),
column_name_list_(column_name),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
load_jagged_connector_(true),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
column_name_list_(column_name) {}


Status CsvOp::Init() { Status CsvOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_)); RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_));
@@ -98,14 +86,13 @@ Status CsvOp::Init() {
io_block_queues_.Init(num_workers_, safe_queue_size); io_block_queues_.Init(num_workers_, safe_queue_size);


RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_);
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);


return Status::OK(); return Status::OK();
} }


CsvOp::CsvParser::CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer,
char field_delim, std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default,
std::string file_path)
CsvOp::CsvParser::CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path)
: worker_id_(worker_id), : worker_id_(worker_id),
buffer_connector_(connector), buffer_connector_(connector),
csv_rows_per_buffer_(rows_per_buffer), csv_rows_per_buffer_(rows_per_buffer),
@@ -221,6 +208,7 @@ int CsvOp::CsvParser::PutRow(int c) {


if (cur_row_ == csv_rows_per_buffer_) { if (cur_row_ == csv_rows_per_buffer_) {
cur_buffer_->set_tensor_table(std::move(tensor_table_)); cur_buffer_->set_tensor_table(std::move(tensor_table_));

buffer_connector_->Add(worker_id_, std::move(cur_buffer_)); buffer_connector_->Add(worker_id_, std::move(cur_buffer_));


cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone); cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
@@ -499,19 +487,9 @@ Status CsvOp::CsvParser::InitCsvParser() {
return Status::OK(); return Status::OK();
} }


// Prepares the op for another pass over the data.
// Re-enables both loading flags, resets the ParallelOp base class, and then signals the
// thread that fills the IOBlockQueue.
// @return Status - the error code returned.
Status CsvOp::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";

  // Turn loading back on before the filler thread is woken up.
  load_jagged_connector_ = true;
  load_io_block_queue_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());

  NotifyToFillIOBlockQueue();
  return Status::OK();
}

Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_,
file);
csv_parser.SetStartOffset(start_offset); csv_parser.SetStartOffset(start_offset);
csv_parser.SetEndOffset(end_offset); csv_parser.SetEndOffset(end_offset);
std::ifstream ifs; std::ifstream ifs;
@@ -546,93 +524,6 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons
return Status::OK(); return Status::OK();
} }


// Class functor: the master loop of the op.
// Launches the IOBlockQueue filler thread and the worker threads, then repeatedly drains the
// jagged connector into the output connector until the last iteration has been produced.
// Each iteration ends with an EOE buffer; the whole run ends with an EOF buffer.
// @return Status - the error code returned.
Status CsvOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // Register the wait-post before launching any thread: registering afterwards can fail
  // occasionally when a thread exits abnormally.
  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));

  // A single thread is responsible for filling the IOBlockQueue.
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this), "", id()));

  // The worker threads each run WorkerEntry with their own worker id.
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1), "", id()));

  // Must be called after launching workers.
  TaskManager::FindMe()->Post();
  NotifyToFillIOBlockQueue();

  while (!finished_reading_dataset_) {
    int64_t next_buffer_id = 0;
    int32_t finished_workers = 0;
    int64_t emitted_rows = 0;
    load_io_block_queue_ = true;

    // Keep popping until every worker has delivered its EOE for this iteration.
    while (finished_workers < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));

      if (buffer->eoe()) {
        ++finished_workers;
        continue;
      }

      // num_samples_ == 0 means "no limit"; otherwise stop forwarding once the limit is reached.
      const bool quota_reached = (num_samples_ != 0) && (emitted_rows >= num_samples_);
      if (quota_reached) {
        // End of epoch: stop loading; the buffer just popped is discarded.
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
        continue;
      }

      // Trim the buffer when forwarding it whole would exceed the sample limit.
      if ((num_samples_ > 0) && (emitted_rows + buffer->NumRows() > num_samples_)) {
        const int64_t surplus_rows = buffer->NumRows() - (num_samples_ - emitted_rows);
        RETURN_IF_NOT_OK(buffer->SliceOff(surplus_rows));
      }
      emitted_rows += buffer->NumRows();
      buffer->set_id(next_buffer_id++);
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
    }

    auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    if (!IsLastIteration()) {
      jagged_buffer_connector_->DoReset();
      // Self-reset to start a new iteration.
      RETURN_IF_NOT_OK(Reset());
    } else {
      finished_reading_dataset_ = true;
      // Wake the filler thread so it can observe the finished flag and exit.
      NotifyToFillIOBlockQueue();
    }
    UpdateRepeatAndEpochCounter();
  }

  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  RETURN_IF_NOT_OK(PostEndOfData());
  return Status::OK();
}

// The entry point for when workers are launched.
// Pops IO blocks from this worker's queue until an EOF block arrives. A file block is loaded
// through LoadFile (only while load_jagged_connector_ is set); an EOE block is answered with an
// EOE buffer on the jagged connector.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status CsvOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));

  while (!io_block->eof()) {
    if (io_block->eoe()) {
      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    } else if (load_jagged_connector_) {
      std::string filename;
      RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
      const int64_t start_offset = io_block->GetStartOffset();
      const int64_t end_offset = io_block->GetEndOffset();
      RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }
  return Status::OK();
}

// A print method typically used for debugging // A print method typically used for debugging
void CsvOp::Print(std::ostream &out, bool show_all) const { void CsvOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) { if (!show_all) {
@@ -644,7 +535,7 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info // Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all); ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff // Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n"; << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
for (int i = 0; i < csv_files_list_.size(); ++i) { for (int i = 0; i < csv_files_list_.size(); ++i) {
@@ -654,52 +545,6 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
} }
} }


// Pops the front element of one queue in io_block_queues_.
// @param index - the index of the queue to pop from.
// @param out_block - receives the popped element.
// @return Status - the error code returned by the queue.
Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Appends an element to one queue in io_block_queues_.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue; ownership is moved into the queue.
// @return Status - the error code returned by the queue.
Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Shuffles the file keys in place with a std::mt19937 engine seeded by `seed`.
// @param i_keys - the keys to shuffle; must not be null.
// @param seed - the seed handed to the random engine.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 engine(seed);
  std::vector<int64_t> &keys = *i_keys;
  std::shuffle(keys.begin(), keys.end(), engine);
}

// Body of the thread that fills the IOBlockQueue.
// Sleeps on io_block_queue_wait_post_ and, every time it is signalled, (re)fills the queue.
// Exits once finished_reading_dataset_ has been set.
// @return Status - the error code returned.
Status CsvOp::WaitToFillIOBlockQueue() {
  // Must be called first when run by a worker spawned by the task group.
  TaskManager::FindMe()->Post();

  // The key list is only populated when files are shuffled; otherwise it stays empty.
  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }

  uint32_t seed = 0;
  for (;;) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      return Status::OK();
    }

    if (shuffle_files_) {
      // With one device the global seed is used; otherwise a local counter is bumped per fill.
      uint32_t shuffle_seed;
      if (num_devices_ == 1) {
        shuffle_seed = GetSeed();
      } else {
        shuffle_seed = ++seed;
      }
      ShuffleKeys(&i_keys, shuffle_seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
}

Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0; int32_t queue_index = 0;
int64_t pre_count = 0; int64_t pre_count = 0;
@@ -749,72 +594,24 @@ Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK(); return Status::OK();
} }


// Signals io_block_queue_wait_post_, the wait-post that WaitToFillIOBlockQueue blocks on.
void CsvOp::NotifyToFillIOBlockQueue() {
  io_block_queue_wait_post_.Set();
}

// Decides whether a file overlaps the row range of this device's shard and, if so, which rows
// of the file belong to the shard. The shard covers rows
// [device_id_ * num_rows_per_shard_, (device_id_ + 1) * num_rows_per_shard_).
// @param file_name - the file to test; its row count is read from filename_numrows_.
// @param start_offset - output, first row of the file to read (0 when the file is not pushed).
// @param end_offset - output, end row of the file to read (0 when the file is not pushed).
// @param pre_count - total rows of the files that precede this one.
// @return bool - true if the file should be pushed to the block queue.
bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                     const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;

  const int64_t shard_begin = device_id_ * num_rows_per_shard_;
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }
  const int64_t shard_end = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;

  bool should_push = false;

  // Case 1: the file begins at or before the shard start and extends past it.
  if (pre_count <= shard_begin) {
    const int64_t file_rows = filename_numrows_[file_name];
    if (pre_count + file_rows > shard_begin) {
      should_push = true;
      *start_offset = shard_begin - pre_count;
      const bool reaches_shard_end = (pre_count < shard_end) && (pre_count + file_rows >= shard_end);
      *end_offset = reaches_shard_end ? shard_end - pre_count : file_rows;
    }
  }

  // Case 2: the file begins inside the shard. When both cases match, this one wins.
  if (pre_count >= shard_begin && pre_count < shard_end) {
    should_push = true;
    *start_offset = 0;
    const int64_t file_rows = filename_numrows_[file_name];
    *end_offset = (pre_count + file_rows >= shard_end) ? shard_end - pre_count : file_rows;
  }

  return should_push;
}

// Enqueues one end-of-epoch control block per worker, starting at `queue_index` and wrapping
// around the worker queues. A worker that pops it emits its own EOE buffer.
// @param queue_index - queue that receives the first EOE block.
// @return Status - the error code returned.
Status CsvOp::PostEndOfEpoch(int32_t queue_index) {
  int posted = 0;
  while (posted < num_workers_) {
    auto eoe_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + posted) % num_workers_, std::move(eoe_block)));
    ++posted;
  }
  return Status::OK();
}

Status CsvOp::CalculateNumRowsPerShard() { Status CsvOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value()); int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count; filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
} }
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format."); "Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format.");
} }


num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK(); return Status::OK();
} }


int64_t CsvOp::CountTotalRows(const std::string &file) { int64_t CsvOp::CountTotalRows(const std::string &file) {
CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
CsvParser csv_parser(0, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_, file);
std::ifstream ifs; std::ifstream ifs;
ifs.open(file, std::ifstream::in); ifs.open(file, std::ifstream::in);
if (!ifs.is_open()) { if (!ifs.is_open()) {
@@ -835,17 +632,6 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
return csv_parser.GetTotalRows(); return csv_parser.GetTotalRows();
} }


// Enqueues one end-of-data (EOF) control block on every worker's IO block queue.
// A worker that pops it leaves its processing loop.
// @return Status - the error code returned.
Status CsvOp::PostEndOfData() {
  int worker = 0;
  while (worker < num_workers_) {
    auto eof_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::move(eof_block)));
    ++worker;
  }
  return Status::OK();
}

Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) { Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
std::shared_ptr<CsvOp> op; std::shared_ptr<CsvOp> op;
*count = 0; *count = 0;


+ 9
- 77
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h View File

@@ -26,6 +26,8 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"


namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@@ -34,7 +36,7 @@ const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector; class JaggedConnector;


class CsvOp : public ParallelOp {
class CsvOp : public NonMappableLeafOp {
public: public:
enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; enum RecordType : uint8_t { INT = 0, FLOAT, STRING };


@@ -63,7 +65,7 @@ class CsvOp : public ParallelOp {
public: public:
CsvParser() = delete; CsvParser() = delete;


CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path); std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path);


~CsvParser() = default; ~CsvParser() = default;
@@ -125,7 +127,7 @@ class CsvOp : public ParallelOp {
int CatchException(int c); int CatchException(int c);


int32_t worker_id_; int32_t worker_id_;
std::shared_ptr<JaggedConnector> buffer_connector_;
JaggedConnector *buffer_connector_;
int64_t csv_rows_per_buffer_; int64_t csv_rows_per_buffer_;
const char csv_field_delim_; const char csv_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_; std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
@@ -274,18 +276,7 @@ class CsvOp : public ParallelOp {


// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init();

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;

// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;


// Get total rows in files. // Get total rows in files.
// @param files - all csv files. // @param files - all csv files.
@@ -303,11 +294,6 @@ class CsvOp : public ParallelOp {
std::string Name() const override { return "CsvOp"; } std::string Name() const override { return "CsvOp"; }


private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;

// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param line - the content of the row. // @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in. // @param tensor_table - the tensor table to put the parsed data in.
@@ -321,61 +307,22 @@ class CsvOp : public ParallelOp {
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);

// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;


// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue // @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();

// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - If file contains the first sample of data.
// @param end_offset - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;


// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;


// Count number of rows in each file. // Count number of rows in each file.
// @param filename - csv file name. // @param filename - csv file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);


// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();

// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
@@ -394,22 +341,7 @@ class CsvOp : public ParallelOp {
// @return bool - whether column name identical in all CSV files // @return bool - whether column name identical in all CSV files
bool ColumnNameValidate(); bool ColumnNameValidate();


int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> csv_files_list_; std::vector<std::string> csv_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
char field_delim_; char field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
std::vector<std::string> column_name_list_; std::vector<std::string> column_name_list_;


+ 304
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.cc View File

@@ -0,0 +1,304 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"

#include <algorithm>

#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

// Constructs the shared state of a non-mappable (file-streaming) leaf operator.
// @param num_workers - number of worker threads; forwarded to ParallelOp.
// @param worker_connector_size - stored in worker_connector_size_.
// @param rows_per_buffer - number of rows per buffer.
// @param total_num_rows - cap on rows emitted per epoch; 0 means no cap (see operator()).
// @param op_connector_size - output connector size; forwarded to ParallelOp.
// @param shuffle_files - whether the file keys are shuffled before each fill of the IO block queues.
// @param num_devices - number of shards the dataset is split across.
// @param device_id - index of the shard this operator reads.
NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
                                     int64_t total_num_rows, int32_t op_connector_size, bool shuffle_files,
                                     int32_t num_devices, int32_t device_id)
    // Initializers are listed in member declaration order, which is the order they actually
    // run in; this keeps -Wreorder quiet and the list truthful.
    : ParallelOp(num_workers, op_connector_size),
      device_id_(device_id),
      num_devices_(num_devices),
      load_jagged_connector_(true),
      filename_index_(std::make_unique<StringIndex>()),
      finished_reading_dataset_(false),
      total_rows_(total_num_rows),
      rows_per_buffer_(rows_per_buffer),
      load_io_block_queue_(true),
      shuffle_files_(shuffle_files),
      num_rows_per_shard_(0),
      num_rows_(0) {
  worker_connector_size_ = worker_connector_size;
}

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
//
// Flow: compute the per-shard row count, launch one thread running WaitToFillIOBlockQueue and
// num_workers_ threads running WorkerEntry, then, once per epoch, drain jagged_buffer_connector_
// into out_connector_ until every worker has reported EOE. After the last iteration an EOF buffer
// is emitted and the workers are told to exit via PostEndOfData.
// @return Status - the error code returned.
Status NonMappableLeafOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

// Registered here (before any worker is launched) to avoid a failed registration when a
// WorkerEntry thread exits unexpectedly.
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));

// launch one thread, responsible for filling the IOBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&NonMappableLeafOp::WaitToFillIOBlockQueue, this), "", id()));

// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
// data from disk into buffers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
num_workers_, std::bind(&NonMappableLeafOp::WorkerEntry, this, std::placeholders::_1), "", id()));

// must be called after launching workers. workers can't be spawned after this post,
// so workers have to be kept alive until the end of the program
TaskManager::FindMe()->Post();

// Kick off the first fill of the IOBlockQueue.
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
// Per-epoch counters.
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
{
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}

// Drain the worker connector until every worker has delivered its EOE buffer.
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> fetched_buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
if (fetched_buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
// we need to push a buffer (total_rows_ == 0 means there is no row cap)
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
// this is the last buffer we need, and we only need a part of it
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
}

rows_read += fetched_buffer->NumRows();
fetched_buffer->set_id(buffer_id);
buffer_id++;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
} else {
// The row cap has been reached; the fetched buffer is dropped and everyone is told to wind down.
//
// IOBlockQueue thread needs to:
// -stop pushing stuff to IOBlockQueue
// -call PostEndOfEpoch (will send EOE)
// -wait for reset
//
// Worker threads need to:
// -stop reading the file they are currently reading and throw it away
// -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
//
// Master thread needs to:
// -tell IOBlockQueue thread to stop pushing
// -tell worker threads to stop reading the file they are currently reading
// -keep pulling until EOE

// don't think we need a lock for now
// NOTE(review): load_jagged_connector_ is a plain bool that is written here and read by the
// worker threads in WorkerEntry without synchronization - consider std::atomic<bool>.
load_jagged_connector_ = false;

std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = false;
}
}

// all workers finished reading for this epoch, and we have read all the data from all workers
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

if (IsLastIteration()) {
// Wake the fill thread one last time so it can observe finished_reading_dataset_ and exit.
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
// Redundant: buffer_id is re-declared (as 0) at the top of the next loop iteration.
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}

// Signal end-of-data downstream, then tell the worker threads to shut down.
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

RETURN_IF_NOT_OK(PostEndOfData());

return Status::OK();
}

// Worker thread body: repeatedly pops IO blocks from this worker's queue until an EOF block
// arrives. A file block is loaded through LoadFile (unless loading has been switched off via
// load_jagged_connector_); an EOE block is answered with an EOE buffer on the jagged connector.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status NonMappableLeafOp::WorkerEntry(int32_t worker_id) {
  // must be called first if called by worker spawned by taskgroup
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> block;
  for (;;) {
    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &block));
    if (block->eof()) {
      break;
    }

    if (block->eoe()) {
      // End of epoch: forward the marker so the master thread can count this worker as done.
      auto eoe_marker = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_marker)));
      continue;
    }

    if (!load_jagged_connector_) {
      // Loading is switched off for the rest of this epoch; discard the block.
      continue;
    }

    std::string filename;
    RETURN_IF_NOT_OK(block->GetFilename(&filename, *filename_index_));
    const int64_t start_offset = block->GetStartOffset();
    const int64_t end_offset = block->GetEndOffset();
    RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
    MS_LOG(DEBUG) << Name() << " operator worker " << worker_id << " loaded file " << filename << ".";
  }

  return Status::OK();
}

// Enqueues one end-of-data (EOF) control block on every worker's IO block queue.
// A worker that pops it leaves its WorkerEntry loop.
Status NonMappableLeafOp::PostEndOfData() {
  int worker = 0;
  while (worker < num_workers_) {
    auto eof_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::move(eof_block)));
    ++worker;
  }
  return Status::OK();
}

// Enqueues one end-of-epoch (EOE) control block per worker, starting at `queue_index` and
// wrapping around the worker queues. A worker that pops it emits its own EOE buffer.
Status NonMappableLeafOp::PostEndOfEpoch(int32_t queue_index) {
  int posted = 0;
  while (posted < num_workers_) {
    auto eoe_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + posted) % num_workers_, std::move(eoe_block)));
    ++posted;
  }
  return Status::OK();
}

// Wakes the thread blocked in WaitToFillIOBlockQueue by setting its wait post.
void NonMappableLeafOp::NotifyToFillIOBlockQueue() {
  io_block_queue_wait_post_.Set();
}

// Removes the front element of the IO block queue selected by `index` and hands it back
// through `out_block`; the queue's status is returned unchanged.
Status NonMappableLeafOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Moves `io_block` onto the IO block queue selected by `index`; the queue's status is
// returned unchanged.
Status NonMappableLeafOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Overrides the base class reset. Re-arms both loading flags, resets the ParallelOp base and
// then wakes the fill thread so the next iteration can begin.
Status NonMappableLeafOp::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";

  // Re-enable the workers first; otherwise IOBlocks would fall through if a worker saw one
  // before this flag is true again.
  load_jagged_connector_ = true;

  {
    std::lock_guard<std::mutex> guard(load_io_block_queue_mutex_);
    load_io_block_queue_ = true;
  }

  RETURN_IF_NOT_OK(ParallelOp::Reset());

  NotifyToFillIOBlockQueue();
  return Status::OK();
}

// Decides whether `file_name` overlaps the row range owned by this shard and, if so, which
// rows of the file should be read.
// @param file_name - file to test; its row count is read from filename_numrows_.
// @param start_offset - output, first row of the file to read (0 when the file is not pushed).
// @param end_offset - output, end row of the file to read (0 when the file is not pushed).
// @param pre_count - total number of rows in the files that precede this one.
// @return bool - true if the file contributes rows to this shard.
bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset,
                                                 int64_t *end_offset, const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;

  // Validate before doing any arithmetic with the device id.
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }

  // Look the row count up once. find() is used instead of operator[] so that an unknown file
  // is treated as having 0 rows without inserting a new entry into the map.
  const auto rows_it = filename_numrows_.find(file_name);
  const int64_t file_rows = (rows_it != filename_numrows_.end()) ? rows_it->second : 0;

  // Shard row range is [start_index, end_index); the file covers [pre_count, file_end).
  const int64_t start_index = static_cast<int64_t>(device_id_) * num_rows_per_shard_;
  const int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  const int64_t file_end = pre_count + file_rows;

  bool push = false;

  // The shard starts inside this file.
  if (pre_count <= start_index && file_end > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && file_end >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }

  // The file starts inside the shard.
  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (file_end >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = file_rows;
    }
  }

  return push;
}

// Shuffles `i_keys` in place with a Mersenne Twister engine seeded by `seed`, so the same
// seed always yields the same permutation.
void NonMappableLeafOp::ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::shuffle(i_keys->begin(), i_keys->end(), std::mt19937(seed));
}

// Body of the thread that fills the IO block queues. Each time it is woken through
// io_block_queue_wait_post_ it (optionally) reshuffles the file keys and calls FillIOBlockQueue,
// until finished_reading_dataset_ is observed.
Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
  // must be called first if called by worker spawned by taskgroup
  TaskManager::FindMe()->Post();

  // The key list is only populated when shuffling is requested.
  std::vector<int64_t> file_keys;
  if (shuffle_files_) {
    auto cursor = filename_index_->begin();
    while (cursor != filename_index_->end()) {
      file_keys.push_back(cursor.key());
      ++cursor;
    }
  }

  uint32_t epoch_seed = 0;
  for (;;) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      break;
    }

    if (shuffle_files_) {
      // Single device: use the configured seed; otherwise a counter that advances on every wake-up.
      ShuffleKeys(&file_keys, num_devices_ == 1 ? GetSeed() : ++epoch_seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(file_keys));
  }
  return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 177
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h View File

@@ -0,0 +1,177 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_

#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <map>

#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"

namespace dataengine {
class Example;
class Feature;
class BytesList;
} // namespace dataengine

namespace mindspore {
namespace dataset {
template <typename T>
class Queue;

template <class T>
class Connector;

class JaggedConnector;
class FilenameBlock;

using StringIndex = AutoIndexObj<std::string>;

// Common base for leaf operators that stream rows out of a list of files. It owns the master
// loop, the worker loop and the IO-block-queue plumbing; derived classes supply Init, LoadFile,
// CalculateNumRowsPerShard and FillIOBlockQueue.
class NonMappableLeafOp : public ParallelOp {
public:
// Constructor of NonMappableLeafOp.
// @note Intended to be called from the constructors of derived operators.
// @param num_workers - number of worker threads reading data from the dataset files.
// @param worker_connector_size - size of each internal queue.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param total_num_rows - maximum number of rows to emit per epoch; 0 means no limit.
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param num_devices - number of shards the dataset is divided into.
// @param device_id - index of the shard read by this operator.
NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);

// Default destructor
~NonMappableLeafOp() = default;

// Instantiates the internal queues and connectors. Implemented by each derived operator.
// @return Status - the error code returned.
virtual Status Init() = 0;

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;

// Overrides base class reset method. Re-enables loading, resets the ParallelOp base and
// notifies the thread that fills the IOBlockQueue so that a new iteration can start.
// @return Status - the error code returned.
Status Reset() override;

// Getter method
// @return number of rows per buffer
int64_t rows_per_buffer() const { return rows_per_buffer_; }

// Op name getter
// @return Name of the current Op
std::string Name() const override { return "NonMappableLeafOp"; }

protected:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @param queue_index - index of the queue that receives the first indicator; the remaining
//     indicators go to the following queues, wrapping around.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);

// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();

// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();

// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

// Reads a dataset file and loads its data into buffers. Implemented by each derived operator.
// @param filename - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;

// Decide whether a file belongs to this shard and which part of it should be read.
// @param file_name - File name.
// @param start_offset - output: first row of the file to read.
// @param end_offset - output: end row of the file to read.
// @param pre_count - Total rows of previous files.
// @return bool - true if the file should be pushed to the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);

// Calculate number of rows in each shard.
// @return Status - the error code returned.
virtual Status CalculateNumRowsPerShard() = 0;

// Shuffle the keys in place using a generator seeded with the given seed.
// @param i_keys - keys to shuffle.
// @param seed - seed of the random number generator.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);

// Fill the IOBlockQueue.
// @param i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;

// Index of the shard read by this operator.
int32_t device_id_;
// Number of shards.
int32_t num_devices_;
// While false, workers skip the file blocks they pop instead of loading them.
// NOTE(review): plain bool shared between the master and worker threads without synchronization.
bool load_jagged_connector_;
// Maps integer keys to dataset file names.
std::unique_ptr<StringIndex> filename_index_;

// One queue of filename blocks per worker.
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
// Number of rows in each file, keyed by file name.
std::map<std::string, int64_t> filename_numrows_;
// Set by the master loop after the last iteration; makes the fill thread exit.
bool finished_reading_dataset_;
// Maximum number of rows to emit per epoch; 0 means no limit.
int64_t total_rows_;

int64_t rows_per_buffer_;
// Used to wake up the thread running WaitToFillIOBlockQueue.
WaitPost io_block_queue_wait_post_;
// Written under load_io_block_queue_mutex_ by the master loop and Reset.
bool load_io_block_queue_;
std::mutex load_io_block_queue_mutex_;
// Connector through which the workers hand buffers to the master loop.
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
bool shuffle_files_;
int64_t num_rows_per_shard_;
// Total row count; presumably filled in by CalculateNumRowsPerShard of the derived class - TODO confirm.
int64_t num_rows_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_

+ 8
- 226
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc View File

@@ -77,23 +77,11 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {


TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_rows),
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_rows, op_connector_size,
shuffle_files, num_devices, device_id),
text_files_list_(std::move(text_files_list)), text_files_list_(std::move(text_files_list)),
shuffle_files_(shuffle_files),
data_schema_(std::move(schema)),
all_num_rows_(0),
num_rows_per_shard_(0),
filename_index_(std::make_unique<StringIndex>()),
finished_reading_dataset_(false),
load_io_block_queue_(true),
load_jagged_connector_(true) {
worker_connector_size_ = worker_connector_size;
}
data_schema_(std::move(schema)) {}


// A print method typically used for debugging // A print method typically used for debugging
void TextFileOp::Print(std::ostream &out, bool show_all) const { void TextFileOp::Print(std::ostream &out, bool show_all) const {
@@ -129,16 +117,6 @@ Status TextFileOp::Init() {
return Status::OK(); return Status::OK();
} }


// Self-reset used between iterations: turns both load flags back on, resets the ParallelOp base and
// then signals the thread waiting in WaitToFillIOBlockQueue() through NotifyToFillIOBlockQueue().
// @return Status - the error code returned (any failure from ParallelOp::Reset() is propagated).
Status TextFileOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// Both flags are re-enabled before the filler thread is notified below.
load_jagged_connector_ = true;
load_io_block_queue_ = true;

RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}

Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) { Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor)); RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
@@ -146,8 +124,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTa
return Status::OK(); return Status::OK();
} }


Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream handle(file); std::ifstream handle(file);
if (!handle.is_open()) { if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
@@ -197,106 +174,6 @@ Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset,
return Status::OK(); return Status::OK();
} }


// Worker loop: pops IO blocks from this worker's queue until a block flagged eof is popped.
// A regular block names a file (plus start/end offsets) that is loaded through LoadFile(); a block
// flagged eoe makes the worker add an EOE DataBuffer to the jagged connector.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status TextFileOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();

std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
// While load_jagged_connector_ is false the block is popped but its file is not loaded.
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
// Forward the end-of-epoch marker; the master loop in operator()() counts these per worker.
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}

RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}

// Removes the front element of the IO block queue selected by `index` and returns it via out_block.
// The status reported by the queue is handed straight back to the caller.
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Appends io_block (ownership is moved) to the IO block queue selected by `index`.
// The status reported by the queue is handed straight back to the caller.
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Queues one block flagged kDeIoBlockFlagEof on every worker's IO block queue.
// WorkerEntry() leaves its loop once it pops such a block.
Status TextFileOp::PostEndOfData() {
  int worker = 0;
  while (worker < num_workers_) {
    auto eof_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::move(eof_block)));
    ++worker;
  }

  return Status::OK();
}

// Queues one block flagged kDeIoBlockFlagEoe on every worker's IO block queue, visiting the queues
// round-robin starting at queue_index.
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
  for (int offset = 0; offset < num_workers_; ++offset) {
    auto eoe_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + offset) % num_workers_, std::move(eoe_block)));
  }

  return Status::OK();
}

// Shuffles *i_keys in place using a std::mt19937 engine constructed from `seed`.
// The same seed and input always produce the same ordering.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::shuffle(i_keys->begin(), i_keys->end(), std::mt19937(seed));
}

bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}

int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}

if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}

return push;
}

Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) { Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0; int32_t queue_index = 0;
int64_t pre_count = 0; int64_t pre_count = 0;
@@ -346,101 +223,6 @@ Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK(); return Status::OK();
} }


// Body of the thread launched by operator()() to fill the IO block queues.
// Each pass waits on io_block_queue_wait_post_, clears it, and then either exits (when
// finished_reading_dataset_ is set) or calls FillIOBlockQueue() once.
// @return Status - the error code returned.
Status TextFileOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();

// The key list is only populated when files are shuffled; otherwise an empty list is passed on.
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();

if (finished_reading_dataset_) {
break;
}

if (shuffle_files_) {
// With a single device the seed comes from GetSeed(); otherwise a local counter that is
// incremented on every pass is used.
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}

// Signals io_block_queue_wait_post_, on which WaitToFillIOBlockQueue() waits.
void TextFileOp::NotifyToFillIOBlockQueue() {
  io_block_queue_wait_post_.Set();
}

// Master loop of the operator.
// Computes the rows per shard, launches one IO-block-queue filler thread and num_workers_ worker threads,
// then for every iteration moves buffers from the jagged connector to the output connector until each
// worker has delivered an EOE buffer. An EOE buffer is emitted after every iteration and an EOF buffer at
// the very end, after which PostEndOfData() is called.
// @return Status - the error code returned.
Status TextFileOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

// Register the wait post before launching any thread; this avoids the occasional registration
// failure seen when a thread exits abnormally.
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));

// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this), Name(), id()));

// Read data from disk into buffers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));

// must be called after launching workers.
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;

// One EOE buffer is expected from every worker per iteration.
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
// total_rows_ == 0 means no row limit is applied here; otherwise the buffer that crosses the
// limit is cut down with SliceOff() so that exactly total_rows_ rows are forwarded.
if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch: enough rows were forwarded, so this buffer is dropped and both load flags are
// switched off while the loop keeps popping until every worker has sent its EOE.
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}

std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

if (IsLastIteration()) {
// Setting the flag before notifying lets WaitToFillIOBlockQueue() observe it and exit.
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}

std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

// Queue an eof block for every worker so that WorkerEntry() returns.
RETURN_IF_NOT_OK(PostEndOfData());

return Status::OK();
}

int64_t TextFileOp::CountTotalRows(const std::string &file) { int64_t TextFileOp::CountTotalRows(const std::string &file) {
std::ifstream handle(file); std::ifstream handle(file);
if (!handle.is_open()) { if (!handle.is_open()) {
@@ -463,14 +245,14 @@ Status TextFileOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value()); int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count; filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
} }
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED( RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API."); "Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API.");
} }


num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK(); return Status::OK();
} }


+ 6
- 75
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h View File

@@ -27,6 +27,7 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/wait_post.h" #include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/jagged_connector.h"
@@ -35,7 +36,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;


class TextFileOp : public ParallelOp {
class TextFileOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@@ -150,18 +151,7 @@ class TextFileOp : public ParallelOp {


// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init();

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;

// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;


// Get total rows in files. // Get total rows in files.
// @param files - all text files. // @param files - all text files.
@@ -178,11 +168,6 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> FileNames() { return text_files_list_; } std::vector<std::string> FileNames() { return text_files_list_; }


private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;

// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param line - the content of the row. // @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in. // @param tensor_table - the tensor table to put the parsed data in.
@@ -196,82 +181,28 @@ class TextFileOp : public ParallelOp {
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;


// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;


// Count number of rows in each file. // Count number of rows in each file.
// @param filename - text file name. // @param filename - text file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);


// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();

// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();

// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @param i_keys - keys of file to fill to the IOBlockQueue // @param i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

// Decide whether a file should be pushed to the block queue and compute the row range to read from it.
// @param file_name - File name.
// @param start_offset - Output: offset of the first row to read from the file.
// @param end_offset - Output: end offset of the rows to read from the file.
// @param pre_count - Total rows of previous files.
// @return bool - true if the file should be pushed to the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);

// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;


int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> text_files_list_; std::vector<std::string> text_files_list_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
int64_t all_num_rows_;
int64_t num_rows_per_shard_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
bool finished_reading_dataset_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore


+ 11
- 273
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc View File

@@ -126,26 +126,14 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
int64_t total_num_rows, std::vector<std::string> dataset_files_list, int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size, std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_devices,
int32_t device_id, bool equal_rows_per_shard) int32_t device_id, bool equal_rows_per_shard)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_num_rows),
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_num_rows, op_connector_size,
shuffle_files, num_devices, device_id),
dataset_files_list_(std::move(dataset_files_list)), dataset_files_list_(std::move(dataset_files_list)),
columns_to_load_(std::move(columns_to_load)), columns_to_load_(std::move(columns_to_load)),
finished_reading_dataset_(false),
shuffle_files_(shuffle_files),
data_schema_(std::move(data_schema)), data_schema_(std::move(data_schema)),
filename_index_(std::make_unique<StringIndex>()),
load_io_block_queue_(true),
load_jagged_connector_(true),
num_rows_(0),
num_rows_per_shard_(0),
equal_rows_per_shard_(equal_rows_per_shard) {
worker_connector_size_ = worker_connector_size;
}
equal_rows_per_shard_(equal_rows_per_shard) {}


// A print method typically used for debugging // A print method typically used for debugging
void TFReaderOp::Print(std::ostream &out, bool show_all) const { void TFReaderOp::Print(std::ostream &out, bool show_all) const {
@@ -222,194 +210,6 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
} }
return Status::OK(); return Status::OK();
} }
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
// The loop computes the rows per shard, launches one IO-block-queue filler thread and num_workers_ worker
// threads, then for every iteration forwards buffers from the jagged connector to the output connector
// until each worker has delivered an EOE buffer. An EOE buffer is emitted after every iteration and an
// EOF buffer at the very end, after which PostEndOfData() is called.
// @return Status - the error code returned.
Status TFReaderOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

// Registered here, before any thread is launched, to avoid a registration failure when a
// WorkerEntry thread exits unexpectedly
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));

// launch one thread, responsible for filling mIOBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this), "", id()));

// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
// data from disk into buffers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1), "", id()));

// must be called after launching workers. workers can't be spawned after this post,
// so workers have to be kept alive until the end of the program
TaskManager::FindMe()->Post();

NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
{
// load_io_block_queue_ is only written while holding load_io_block_queue_mutex_.
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}

// One EOE buffer is expected from every worker per iteration.
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> fetched_buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
if (fetched_buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
// we need to push a buffer
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
// this is last buffer we need, and we only need a part of it
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
}

rows_read += fetched_buffer->NumRows();
fetched_buffer->set_id(buffer_id);
buffer_id++;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
} else {
// user specified number of rows they want, and we read enough rows
//
// IOBlockQueue thread needs to:
// -stop pushing stuff to IOBlockQueue
// -call PostEndOfEpoch (will send EOE)
// -wait for reset
//
// Worker threads need to:
// -stop reading the file they are currently reading and throw it away
// -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
//
// Master thread needs to:
// -tell IOBlockQueue thread to stop pushing
// -tell worker threads to stop reading the file they are currently reading
// -keep pulling until EOE

// don't think we need a lock for now
load_jagged_connector_ = false;

std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = false;
}
}

// all workers finished reading for this epoch, and we have read all the data from all workers
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

if (IsLastIteration()) {
// Setting the flag before notifying lets WaitToFillIOBlockQueue() observe it and exit.
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}

std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

// Queue an eof block for every worker so that WorkerEntry() returns.
RETURN_IF_NOT_OK(PostEndOfData());

return Status::OK();
}

// File-local helper: shuffles *i_keys in place using a std::mt19937 engine constructed from `seed`.
// The same seed and input always produce the same ordering.
static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::shuffle(i_keys->begin(), i_keys->end(), std::mt19937(seed));
}

// The entry point for when workers are launched.
// Pops IO blocks from this worker's queue until a block flagged eof is popped. A regular block names a
// file (plus start/end offsets) that is loaded through LoadFile(); a block flagged eoe makes the worker
// add an EOE DataBuffer to the jagged connector.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status TFReaderOp::WorkerEntry(int32_t worker_id) {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();

std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));

while (!io_block->eof()) {
if (!io_block->eoe()) {
// While load_jagged_connector_ is false the block is popped but its file is not loaded.
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
}
} else {
// Forward the end-of-epoch marker; the master loop in operator()() counts these per worker.
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}

RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}

return Status::OK();
}

// Queues one block flagged kDeIoBlockFlagEof on every worker's IO block queue.
// WorkerEntry() leaves its loop once it pops such a block.
Status TFReaderOp::PostEndOfData() {
  int worker = 0;
  while (worker < num_workers_) {
    auto eof_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(worker, std::move(eof_block)));
    ++worker;
  }

  return Status::OK();
}

// Queues one block flagged kDeIoBlockFlagEoe on every worker's IO block queue, visiting the queues
// round-robin starting at queue_index.
Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) {
  for (int offset = 0; offset < num_workers_; ++offset) {
    auto eoe_block = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + offset) % num_workers_, std::move(eoe_block)));
  }

  return Status::OK();
}

bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid.";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;

if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}

if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}

return push;
}


Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0; int32_t queue_index = 0;
@@ -506,58 +306,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
return Status::OK(); return Status::OK();
} }


// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// Each pass waits on io_block_queue_wait_post_, clears it, and then either exits (when
// finished_reading_dataset_ is set) or fills the queue once via FillIOBlockShuffle() or
// FillIOBlockNoShuffle(), depending on shuffle_files_.
// @return Status - the error code returned.
Status TFReaderOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();

std::vector<int64_t> i_keys;
// Generate a vector of keys that we can shuffle
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();

if (finished_reading_dataset_) {
break;
}

if (shuffle_files_) {
// With a single device the seed comes from GetSeed(); otherwise a local counter that is
// incremented on every pass is used.
shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys));
} else { // shuffle_files_ == false
RETURN_IF_NOT_OK(FillIOBlockNoShuffle());
}
}

return Status::OK();
}

// Signals io_block_queue_wait_post_, on which WaitToFillIOBlockQueue() waits, so that thread resumes.
void TFReaderOp::NotifyToFillIOBlockQueue() {
  io_block_queue_wait_post_.Set();
}

// Removes the front element of the IO block queue selected by `index` and returns it via out_block.
// The status reported by the queue is handed straight back to the caller.
Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  return io_block_queues_[index]->PopFront(out_block);
}

// Appends io_block (ownership is moved) to the IO block queue selected by `index`.
// The status reported by the queue is handed straight back to the caller.
Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  return io_block_queues_[index]->Add(std::move(io_block));
}

// Reads a tf_file file and loads the data into multiple buffers. // Reads a tf_file file and loads the data into multiple buffers.
Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
const int32_t &worker_id) {
Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream reader; std::ifstream reader;
reader.open(filename); reader.open(filename);
if (!reader) { if (!reader) {
@@ -698,24 +448,6 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
return Status::OK(); return Status::OK();
} }


// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// Concretely: re-enables both load flags, resets the ParallelOp base and then signals the thread
// waiting in WaitToFillIOBlockQueue().
// @return Status - the error code returned (any failure from ParallelOp::Reset() is propagated).
Status TFReaderOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
load_jagged_connector_ = true;

{
// The lock is scoped so that it is released before ParallelOp::Reset() runs.
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}

RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();

return Status::OK();
}

Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list, Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
int32_t *num_elements, std::shared_ptr<Tensor> *tensor) { int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
// kBytesList can map to the following DE types ONLY! // kBytesList can map to the following DE types ONLY!
@@ -1029,6 +761,12 @@ Status TFReaderOp::ComputeColMap() {
} }
return Status::OK(); return Status::OK();
} }
// Fills the IO block queues for one pass: delegates to FillIOBlockShuffle(i_keys) when file shuffling
// is enabled and to FillIOBlockNoShuffle() otherwise (i_keys is not used in that case).
// @param i_keys - keys of the files, already in the order to use when shuffling.
// @return Status - the error code returned by the selected fill routine.
Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  return shuffle_files_ ? FillIOBlockShuffle(i_keys) : FillIOBlockNoShuffle();
}


} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

+ 11
- 79
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h View File

@@ -31,6 +31,7 @@
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"


namespace dataengine { namespace dataengine {
class Example; class Example;
@@ -51,7 +52,7 @@ class FilenameBlock;


using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;


class TFReaderOp : public ParallelOp {
class TFReaderOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@@ -195,21 +196,7 @@ class TFReaderOp : public ParallelOp {


// Instantiates the internal queues and connectors. // Instantiates the internal queues and connectors.
// @return Status - the error code returned. // @return Status - the error code returned.
Status Init();

// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;

// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;

// Getter method
int64_t rows_per_buffer() const { return rows_per_buffer_; }
Status Init() override;


// Reads all the provided tf_file files and counts the total number of rows. filenames will // Reads all the provided tf_file files and counts the total number of rows. filenames will
// first be sectioned into equal parts, then sections are read in parallel. If threads is // first be sectioned into equal parts, then sections are read in parallel. If threads is
@@ -233,48 +220,13 @@ class TFReaderOp : public ParallelOp {
static bool ValidateFirstRowCrc(const std::string &filename); static bool ValidateFirstRowCrc(const std::string &filename);


private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);

// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();

// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();

// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

// Reads a tf_file file and loads the data into multiple buffers. // Reads a tf_file file and loads the data into multiple buffers.
// @param filename - the tf_file file to read. // @param filename - the tf_file file to read.
// @param start_offset - the start offset of file. // @param start_offset - the start offset of file.
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
const int32_t &worker_id);
Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;


// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param tf_file - the row to be parsed. // @param tf_file - the row to be parsed.
@@ -339,6 +291,11 @@ class TFReaderOp : public ParallelOp {
// @return int64_t - the total number of rows of files read. // @return int64_t - the total number of rows of files read.
static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin, static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin,
const int64_t end); const int64_t end);

protected:
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;

private:
// Fill IO block queue if shuffle is true // Fill IO block queue if shuffle is true
// @param i_keys - shuffle keys. // @param i_keys - shuffle keys.
// @return Status - the error code returned. // @return Status - the error code returned.
@@ -351,43 +308,18 @@ class TFReaderOp : public ParallelOp {
*/ */
Status FillIOBlockNoShuffle(); Status FillIOBlockNoShuffle();


// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);

// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;


// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;


int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> dataset_files_list_; std::vector<std::string> dataset_files_list_;
std::vector<std::string> columns_to_load_; std::vector<std::string> columns_to_load_;
bool finished_reading_dataset_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
std::unique_ptr<StringIndex> filename_index_;
bool load_io_block_queue_;
bool load_jagged_connector_;

std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
std::mutex load_io_block_queue_mutex_;
std::map<std::string, int64_t> filename_numrows_;
int64_t num_rows_;
int64_t num_rows_per_shard_;

bool equal_rows_per_shard_; bool equal_rows_per_shard_;
}; };
} // namespace dataset } // namespace dataset


Loading…
Cancel
Save