|
|
|
@@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 |
|
|
|
data_schema_(std::move(data_schema)), |
|
|
|
filename_index_(make_unique<StringIndex>()), |
|
|
|
load_io_block_queue_(true), |
|
|
|
load_jagged_connector_(true), |
|
|
|
num_rows_(0), |
|
|
|
num_rows_per_shard_(0), |
|
|
|
equal_rows_per_shard_(equal_rows_per_shard) { |
|
|
|
@@ -203,6 +204,25 @@ Status TFReaderOp::operator()() { |
|
|
|
buffer_id++; |
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); |
|
|
|
} else { |
|
|
|
// user specified number of rows they want, and we read enough rows |
|
|
|
// |
|
|
|
// IOBlockQueue thread needs to: |
|
|
|
// -stop pushing stuff to IOBlockQueue |
|
|
|
// -call PostEndOfEpoch (will send EOE) |
|
|
|
// -wait for reset |
|
|
|
// |
|
|
|
// Worker threads need to: |
|
|
|
// -stop reading the file they are currently reading and throw it away |
|
|
|
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE) |
|
|
|
// |
|
|
|
// Master thread needs to: |
|
|
|
// -tell IOBlockQueue thread to stop pushing |
|
|
|
// -tell worker threads to stop reading the file tey are currently reading |
|
|
|
// -keep pulling until EOE |
|
|
|
|
|
|
|
// don't think we need a lock for now |
|
|
|
load_jagged_connector_ = false; |
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); |
|
|
|
load_io_block_queue_ = false; |
|
|
|
} |
|
|
|
@@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) { |
|
|
|
|
|
|
|
while (!io_block->eof()) { |
|
|
|
if (!io_block->eoe()) { |
|
|
|
std::string filename; |
|
|
|
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); |
|
|
|
int64_t start_offset = io_block->GetStartOffset(); |
|
|
|
int64_t end_offset = io_block->GetEndOffset(); |
|
|
|
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); |
|
|
|
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << "."; |
|
|
|
if (load_jagged_connector_) { |
|
|
|
std::string filename; |
|
|
|
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); |
|
|
|
int64_t start_offset = io_block->GetStartOffset(); |
|
|
|
int64_t end_offset = io_block->GetEndOffset(); |
|
|
|
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); |
|
|
|
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE); |
|
|
|
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); |
|
|
|
@@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off |
|
|
|
std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>(); |
|
|
|
|
|
|
|
while (reader.peek() != EOF) { |
|
|
|
if (!load_jagged_connector_) { |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
// read length |
|
|
|
int64_t record_length = 0; |
|
|
|
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); |
|
|
|
@@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table |
|
|
|
// Overrides base class reset method. Cleans up any state info from it's previous execution and |
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created. |
|
|
|
Status TFReaderOp::Reset() { |
|
|
|
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true |
|
|
|
load_jagged_connector_ = true; |
|
|
|
|
|
|
|
{ |
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); |
|
|
|
load_io_block_queue_ = true; |
|
|
|
|