From 03553d9b6a5b1dc4e4562e75689479ada160859b Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Sat, 8 May 2021 14:45:32 +0800 Subject: [PATCH] [CPU] Fix the bug of split op --- .../kernel_compiler/cpu/split_cpu_kernel.cc | 52 ++++++------------- .../kernel_compiler/cpu/split_cpu_kernel.h | 5 -- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc index 26fa98ac5c..bd21debbd6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc @@ -26,25 +26,6 @@ void SplitCPUKernel::InitKernel(const CNodePtr &kernel_node) { output_num_ = AnfAlgo::GetNodeAttr(kernel_node, "output_num"); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); CheckParam(kernel_node); - Reshape(); -} - -template -void SplitCPUKernel::Reshape() { - param_ = new SplitParameter(); - param_->num_split_ = output_num_; - param_->split_dim_ = axis_ >= 0 ? axis_ : input_shape_.size() + axis_; - - param_->strides_[input_shape_.size() - 1] = 1; - for (int i = input_shape_.size() - 2; i >= 0; i--) { - param_->strides_[i] = param_->strides_[i + 1] * input_shape_[i + 1]; - } - - param_->split_sizes_ = new int[sizeof(int) * param_->num_split_]; - int split_size = input_shape_[param_->split_dim_] / output_num_; - for (int i = 0; i < param_->num_split_; i++) { - param_->split_sizes_[i] = split_size; - } } template @@ -69,8 +50,21 @@ void SplitCPUKernel::LaunchSplit(T *input, T **output, size_t size) { const float block_size = 128.0; size_t thread_num = size < block_size * max_thread_num ? std::ceil(size / block_size) : max_thread_num; - param_->split_count_ = size / (input_shape_[param_->split_dim_] * param_->strides_[param_->split_dim_]); - int num_unit = param_->split_count_ * param_->num_split_; + SplitParameter param; + param.num_split_ = output_num_; + param.split_dim_ = axis_ >= 0 ? axis_ : input_shape_.size() + axis_; + param.strides_[input_shape_.size() - 1] = 1; + for (int i = input_shape_.size() - 2; i >= 0; i--) { + param.strides_[i] = param.strides_[i + 1] * input_shape_[i + 1]; + } + auto split_sizes = std::make_unique(param.num_split_); + param.split_sizes_ = split_sizes.get(); + int split_size = input_shape_[param.split_dim_] / output_num_; + for (int i = 0; i < param.num_split_; i++) { + param.split_sizes_[i] = split_size; + } + param.split_count_ = size / (input_shape_[param.split_dim_] * param.strides_[param.split_dim_]); + int num_unit = param.split_count_ * param.num_split_; int thread_n_stride; if (thread_num != 0) { thread_n_stride = UP_DIV(num_unit, thread_num); @@ -80,7 +74,7 @@ void SplitCPUKernel::LaunchSplit(T *input, T **output, size_t size) { int task_id = start / (size / thread_num); int thread_offset = task_id * thread_n_stride; int num_unit_thread = MSMIN(thread_n_stride, num_unit - task_id * thread_n_stride); - DoSplit(input, reinterpret_cast(output), &input_shape_int_[0], thread_offset, num_unit_thread, param_, + DoSplit(input, reinterpret_cast(output), &input_shape_int_[0], thread_offset, num_unit_thread, ¶m, sizeof(T)); }; CPUKernelUtils::ParallelFor(task, size); @@ -88,19 +82,6 @@ void SplitCPUKernel::LaunchSplit(T *input, T **output, size_t size) { return; } -template -void SplitCPUKernel::FreeTmpBuff() { - if (param_->split_sizes_ != nullptr) { - delete[] param_->split_sizes_; - param_->split_sizes_ = nullptr; - } - if (param_ != nullptr) { - delete param_; - param_ = nullptr; - } - return; -} - template void SplitCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &workspace, @@ -112,7 +93,6 @@ void SplitCPUKernel::LaunchKernel(const std::vector &inputs, } size_t size = static_cast(inputs[0]->size / sizeof(T)); LaunchSplit(input, output, size); - FreeTmpBuff(); return; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h index b74def2b96..51e43c8fd1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h @@ -42,9 +42,7 @@ class SplitCPUKernel : public CPUKernel { private: void CheckParam(const CNodePtr &kernel_node); - void Reshape(); void LaunchSplit(T *input, T **output, size_t size); - void FreeTmpBuff(); int64_t axis_; int64_t output_num_; int64_t axis_step_; @@ -57,9 +55,6 @@ class SplitCPUKernel : public CPUKernel { std::vector input_shape_; std::vector input_shape_int_; TypeId dtype_{kTypeUnknown}; - - protected: - SplitParameter *param_ = nullptr; }; MS_REG_CPU_KERNEL_T(