| @@ -433,11 +433,13 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { | |||||
| int64_t start_offset = 0; | int64_t start_offset = 0; | ||||
| int64_t end_offset = 0; | int64_t end_offset = 0; | ||||
| bool finish = false; | bool finish = false; | ||||
| bool end_of_epoch = false; | |||||
| while (!finish) { | while (!finish) { | ||||
| for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { | ||||
| { | { | ||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | ||||
| if (load_io_block_queue_ == false) { | if (load_io_block_queue_ == false) { | ||||
| end_of_epoch = true; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -461,7 +463,8 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { | |||||
| pre_count += filename_numrows_[file_name]; | pre_count += filename_numrows_[file_name]; | ||||
| } | } | ||||
| } | } | ||||
| if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||||
| if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ && | |||||
| !end_of_epoch) { | |||||
| finish = false; | finish = false; | ||||
| } else { | } else { | ||||
| finish = true; | finish = true; | ||||
| @@ -478,12 +481,14 @@ Status TFReaderOp::FillIOBlockNoShuffle() { | |||||
| int64_t start_offset = 0; | int64_t start_offset = 0; | ||||
| int64_t end_offset = 0; | int64_t end_offset = 0; | ||||
| bool finish = false; | bool finish = false; | ||||
| bool end_of_epoch = true; | |||||
| while (!finish) { | while (!finish) { | ||||
| // Iterate over all the keys and add one key to each block. | // Iterate over all the keys and add one key to each block. | ||||
| for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { | ||||
| { | { | ||||
| std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); | ||||
| if (load_io_block_queue_ == false) { | if (load_io_block_queue_ == false) { | ||||
| end_of_epoch = true; | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -505,7 +510,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() { | |||||
| pre_count += filename_numrows_[file_name]; | pre_count += filename_numrows_[file_name]; | ||||
| } | } | ||||
| } | } | ||||
| if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { | |||||
| if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ && | |||||
| !end_of_epoch) { | |||||
| finish = false; | finish = false; | ||||
| } else { | } else { | ||||
| finish = true; | finish = true; | ||||