|
|
|
@@ -26,25 +26,6 @@ void SplitCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
output_num_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "output_num"); |
|
|
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
CheckParam(kernel_node); |
|
|
|
Reshape(); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void SplitCPUKernel<T>::Reshape() { |
|
|
|
param_ = new SplitParameter(); |
|
|
|
param_->num_split_ = output_num_; |
|
|
|
param_->split_dim_ = axis_ >= 0 ? axis_ : input_shape_.size() + axis_; |
|
|
|
|
|
|
|
param_->strides_[input_shape_.size() - 1] = 1; |
|
|
|
for (int i = input_shape_.size() - 2; i >= 0; i--) { |
|
|
|
param_->strides_[i] = param_->strides_[i + 1] * input_shape_[i + 1]; |
|
|
|
} |
|
|
|
|
|
|
|
param_->split_sizes_ = new int[sizeof(int) * param_->num_split_]; |
|
|
|
int split_size = input_shape_[param_->split_dim_] / output_num_; |
|
|
|
for (int i = 0; i < param_->num_split_; i++) { |
|
|
|
param_->split_sizes_[i] = split_size; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
@@ -65,39 +46,27 @@ template <typename T> |
|
|
|
void SplitCPUKernel<T>::LaunchSplit(T *input, T **output, size_t size) { |
|
|
|
(void)std::transform(input_shape_.begin(), input_shape_.end(), std::back_inserter(input_shape_int_), |
|
|
|
[](const int &value) { return static_cast<int>(value); }); |
|
|
|
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); |
|
|
|
const float block_size = 128.0; |
|
|
|
size_t thread_num = size < block_size * max_thread_num ? std::ceil(size / block_size) : max_thread_num; |
|
|
|
|
|
|
|
param_->split_count_ = size / (input_shape_[param_->split_dim_] * param_->strides_[param_->split_dim_]); |
|
|
|
int num_unit = param_->split_count_ * param_->num_split_; |
|
|
|
int thread_n_stride; |
|
|
|
if (thread_num != 0) { |
|
|
|
thread_n_stride = UP_DIV(num_unit, thread_num); |
|
|
|
SplitParameter param; |
|
|
|
param.num_split_ = output_num_; |
|
|
|
param.split_dim_ = axis_; |
|
|
|
param.strides_[input_shape_.size() - 1] = 1; |
|
|
|
for (int i = input_shape_.size() - 2; i >= 0; i--) { |
|
|
|
param.strides_[i] = param.strides_[i + 1] * input_shape_[i + 1]; |
|
|
|
} |
|
|
|
|
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
int task_id = start / (size / thread_num); |
|
|
|
int thread_offset = task_id * thread_n_stride; |
|
|
|
int num_unit_thread = MSMIN(thread_n_stride, num_unit - task_id * thread_n_stride); |
|
|
|
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], thread_offset, num_unit_thread, param_, |
|
|
|
sizeof(T)); |
|
|
|
}; |
|
|
|
CPUKernelUtils::ParallelFor(task, size); |
|
|
|
|
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void SplitCPUKernel<T>::FreeTmpBuff() { |
|
|
|
if (param_->split_sizes_ != nullptr) { |
|
|
|
delete[] param_->split_sizes_; |
|
|
|
param_->split_sizes_ = nullptr; |
|
|
|
auto split_sizes = std::make_unique<int[]>(param.num_split_); |
|
|
|
param.split_sizes_ = split_sizes.get(); |
|
|
|
int split_size = input_shape_[param.split_dim_] / output_num_; |
|
|
|
for (int i = 0; i < param.num_split_; i++) { |
|
|
|
param.split_sizes_[i] = split_size; |
|
|
|
} |
|
|
|
if (param_ != nullptr) { |
|
|
|
delete param_; |
|
|
|
param_ = nullptr; |
|
|
|
param.split_count_ = 1; |
|
|
|
for (int i = 0; i < axis_; ++i) { |
|
|
|
param.split_count_ *= input_shape_[i]; |
|
|
|
} |
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], start, end - start, ¶m, sizeof(T)); |
|
|
|
}; |
|
|
|
CPUKernelUtils::ParallelFor(task, param.split_count_ * param.num_split_); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -112,7 +81,6 @@ void SplitCPUKernel<T>::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
} |
|
|
|
size_t size = static_cast<size_t>(inputs[0]->size / sizeof(T)); |
|
|
|
LaunchSplit(input, output, size); |
|
|
|
FreeTmpBuff(); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|