|
|
|
@@ -34,13 +34,13 @@ ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler, |
|
|
|
children_start_end_index_ = children_start_end_index; |
|
|
|
std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler); |
|
|
|
if (distribute_sampler != nullptr) { |
|
|
|
num_shard_ = distribute_sampler->GetDeviceNum(); |
|
|
|
shard_index_ = distribute_sampler->GetDeviceID(); |
|
|
|
num_shard_ = static_cast<int32_t>(distribute_sampler->GetDeviceNum()); |
|
|
|
shard_index_ = static_cast<int32_t>(distribute_sampler->GetDeviceID()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ConcatOp::ConcatOp() |
|
|
|
: PipelineOp(0), cur_child_(0), verified_(false), num_shard_(1), shard_index_(0), sample_number_(0) {} |
|
|
|
: PipelineOp(0), cur_child_(0), verified_(false), sample_number_(0), num_shard_(1), shard_index_(0) {} |
|
|
|
|
|
|
|
// A function that prints info about the Operator |
|
|
|
void ConcatOp::Print(std::ostream &out, bool show_all) const { |
|
|
|
@@ -124,7 +124,7 @@ bool ConcatOp::IgnoreSample() { |
|
|
|
bool is_not_mappable_or_second_ne_zero = true; |
|
|
|
|
|
|
|
if (!children_flag_and_nums_.empty()) { |
|
|
|
bool is_not_mappable = children_flag_and_nums_[cur_child_].first; |
|
|
|
bool is_not_mappable = static_cast<bool>(children_flag_and_nums_[cur_child_].first); |
|
|
|
is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second); |
|
|
|
} |
|
|
|
bool ret = true; |
|
|
|
@@ -151,7 +151,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe |
|
|
|
bool is_not_mappable_or_second_ne_zero = true; |
|
|
|
|
|
|
|
if (!children_flag_and_nums_.empty()) { |
|
|
|
bool is_not_mappable = children_flag_and_nums_[cur_child_].first; |
|
|
|
bool is_not_mappable = static_cast<bool>(children_flag_and_nums_[cur_child_].first); |
|
|
|
is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second); |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(child_[cur_child_]->GetNextRow(row, worker_id, retry_if_eoe)); |
|
|
|
|