|
|
@@ -153,7 +153,7 @@ Status CocoOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std: |
|
|
Status CocoOp::operator()() { |
|
|
Status CocoOp::operator()() { |
|
|
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); |
|
|
RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); |
|
|
std::unique_ptr<DataBuffer> sampler_buffer; |
|
|
std::unique_ptr<DataBuffer> sampler_buffer; |
|
|
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); |
|
|
while (true) { |
|
|
while (true) { |
|
|
std::vector<int64_t> keys; |
|
|
std::vector<int64_t> keys; |
|
|
keys.reserve(rows_per_buffer_); |
|
|
keys.reserve(rows_per_buffer_); |
|
|
@@ -164,7 +164,7 @@ Status CocoOp::operator()() { |
|
|
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); |
|
|
RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); |
|
|
} |
|
|
} |
|
|
RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); |
|
|
RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); |
|
|
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); |
|
|
} |
|
|
} |
|
|
if (keys.empty() == false) { |
|
|
if (keys.empty() == false) { |
|
|
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( |
|
|
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( |
|
|
@@ -185,7 +185,7 @@ Status CocoOp::operator()() { |
|
|
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); |
|
|
io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe))); |
|
|
RETURN_IF_NOT_OK(wp_.Wait()); |
|
|
RETURN_IF_NOT_OK(wp_.Wait()); |
|
|
wp_.Clear(); |
|
|
wp_.Clear(); |
|
|
RETURN_IF_NOT_OK(sampler_->GetNextBuffer(&sampler_buffer)); |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|