|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -13,8 +13,10 @@ |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
#include <algorithm> |
|
|
|
#include "backend/kernel_compiler/cpu/split_cpu_kernel.h" |
|
|
|
#include "runtime/device/cpu/cpu_device_address.h" |
|
|
|
#include "common/thread_pool.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
@@ -29,20 +31,19 @@ void SplitCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void SplitCPUKernel<T>::Reshape() { |
|
|
|
input_size_ = 1; |
|
|
|
dims_current_after_axis_ = 1; |
|
|
|
dims_after_axis_ = 1; |
|
|
|
axis_step_ = input_shape_[axis_] / output_num_; |
|
|
|
param_ = new SplitParameter(); |
|
|
|
param_->num_split_ = output_num_; |
|
|
|
param_->split_dim_ = axis_ >= 0 ? axis_ : input_shape_.size() + axis_; |
|
|
|
|
|
|
|
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { |
|
|
|
input_size_ *= input_shape_[i]; |
|
|
|
if (i > axis_) { |
|
|
|
dims_current_after_axis_ *= input_shape_[i]; |
|
|
|
dims_after_axis_ *= input_shape_[i]; |
|
|
|
} |
|
|
|
if (i == axis_) { |
|
|
|
dims_current_after_axis_ *= input_shape_[i]; |
|
|
|
} |
|
|
|
param_->strides_[input_shape_.size() - 1] = 1; |
|
|
|
for (int i = input_shape_.size() - 2; i >= 0; i--) { |
|
|
|
param_->strides_[i] = param_->strides_[i + 1] * input_shape_[i + 1]; |
|
|
|
} |
|
|
|
|
|
|
|
param_->split_sizes_ = new int[sizeof(int) * param_->num_split_]; |
|
|
|
int split_size = input_shape_[param_->split_dim_] / output_num_; |
|
|
|
for (int i = 0; i < param_->num_split_; i++) { |
|
|
|
param_->split_sizes_[i] = split_size; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -61,17 +62,42 @@ bool SplitCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void SplitCPUKernel<T>::LaunchSplit(const T *input, T **output, size_t size) { |
|
|
|
void SplitCPUKernel<T>::LaunchSplit(T *input, T **output, size_t size) { |
|
|
|
(void)std::transform(input_shape_.begin(), input_shape_.end(), std::back_inserter(input_shape_int_), |
|
|
|
[](const int &value) { return static_cast<int>(value); }); |
|
|
|
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); |
|
|
|
const float block_size = 128.0; |
|
|
|
size_t thread_num = size < block_size * max_thread_num ? std::ceil(size / block_size) : max_thread_num; |
|
|
|
|
|
|
|
param_->split_count_ = size / (input_shape_[param_->split_dim_] * param_->strides_[param_->split_dim_]); |
|
|
|
int num_unit = param_->split_count_ * param_->num_split_; |
|
|
|
int thread_n_stride; |
|
|
|
if (thread_num != 0) { |
|
|
|
thread_n_stride = UP_DIV(num_unit, thread_num); |
|
|
|
} |
|
|
|
|
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
for (size_t i = start; i < end; i++) { |
|
|
|
int num = i % dims_current_after_axis_ / dims_after_axis_; |
|
|
|
int block = num / axis_step_; |
|
|
|
int block_pos = i / dims_current_after_axis_ * axis_step_ * dims_after_axis_ + |
|
|
|
num % axis_step_ * dims_after_axis_ + i % dims_after_axis_; |
|
|
|
output[block][block_pos] = input[i]; |
|
|
|
} |
|
|
|
int task_id = start / (size / thread_num); |
|
|
|
int thread_offset = task_id * thread_n_stride; |
|
|
|
int num_unit_thread = MSMIN(thread_n_stride, num_unit - task_id * thread_n_stride); |
|
|
|
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], thread_offset, num_unit_thread, param_, |
|
|
|
sizeof(T)); |
|
|
|
}; |
|
|
|
CPUKernelUtils::ParallelFor(task, size); |
|
|
|
|
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void SplitCPUKernel<T>::FreeTmpBuff() { |
|
|
|
if (param_->split_sizes_ != nullptr) { |
|
|
|
delete[] param_->split_sizes_; |
|
|
|
param_->split_sizes_ = nullptr; |
|
|
|
} |
|
|
|
if (param_ != nullptr) { |
|
|
|
delete param_; |
|
|
|
param_ = nullptr; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -86,6 +112,7 @@ void SplitCPUKernel<T>::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
} |
|
|
|
size_t size = static_cast<size_t>(inputs[0]->size / sizeof(T)); |
|
|
|
LaunchSplit(input, output, size); |
|
|
|
FreeTmpBuff(); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|