|
|
|
@@ -27,18 +27,14 @@ |
|
|
|
|
|
|
|
#include "proto/example.pb.h" |
|
|
|
#include "./securec.h" |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "minddata/dataset/core/config_manager.h" |
|
|
|
#include "minddata/dataset/core/global_context.h" |
|
|
|
#include "minddata/dataset/engine/connector.h" |
|
|
|
#include "minddata/dataset/engine/data_schema.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/source/io_block.h" |
|
|
|
#include "minddata/dataset/engine/db_connector.h" |
|
|
|
#include "minddata/dataset/engine/execution_tree.h" |
|
|
|
#include "minddata/dataset/engine/jagged_connector.h" |
|
|
|
#include "minddata/dataset/engine/opt/pass.h" |
|
|
|
#include "minddata/dataset/util/path.h" |
|
|
|
#include "minddata/dataset/util/queue.h" |
|
|
|
#include "minddata/dataset/util/random.h" |
|
|
|
#include "minddata/dataset/util/status.h" |
|
|
|
#include "minddata/dataset/util/task_manager.h" |
|
|
|
@@ -387,14 +383,14 @@ Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, |
|
|
|
bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, |
|
|
|
const int64_t &pre_count) { |
|
|
|
*start_offset = 0; |
|
|
|
*end_offset = 0; |
|
|
|
bool push = false; |
|
|
|
int64_t start_index = device_id_ * num_rows_per_shard_; |
|
|
|
if (device_id_ + 1 < 0) { |
|
|
|
MS_LOG(ERROR) << "Device id is invalid"; |
|
|
|
MS_LOG(ERROR) << "Device id is invalid."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_; |
|
|
|
@@ -448,7 +444,7 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { |
|
|
|
} else { |
|
|
|
// Do an index lookup using that key to get the filename. |
|
|
|
std::string file_name = (*filename_index_)[*it]; |
|
|
|
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { |
|
|
|
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) { |
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone); |
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); |
|
|
|
MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset; |
|
|
|
@@ -496,7 +492,7 @@ Status TFReaderOp::FillIOBlockNoShuffle() { |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::string file_name = it.value(); |
|
|
|
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { |
|
|
|
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) { |
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone); |
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); |
|
|
|
queue_index = (queue_index + 1) % num_workers_; |
|
|
|
@@ -711,7 +707,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table |
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created. |
|
|
|
Status TFReaderOp::Reset() { |
|
|
|
MS_LOG(DEBUG) << Name() << " performing a self-reset."; |
|
|
|
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true |
|
|
|
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true |
|
|
|
load_jagged_connector_ = true; |
|
|
|
|
|
|
|
{ |
|
|
|
@@ -767,6 +763,14 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataeng |
|
|
|
new_pad_size *= cur_shape[i]; |
|
|
|
} |
|
|
|
pad_size = new_pad_size; |
|
|
|
} else { |
|
|
|
if (cur_shape.known() && cur_shape.NumOfElements() != max_size) { |
|
|
|
std::string err_msg = "Shape in schema's column '" + current_col.name() + "' is incorrect." + |
|
|
|
"\nshape received: " + cur_shape.ToString() + |
|
|
|
"\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) + |
|
|
|
"\nexpected total elements in shape: " + std::to_string(max_size); |
|
|
|
RETURN_STATUS_UNEXPECTED(err_msg); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|