Browse Source

[CPU] Fix a bug in the Split op

tags/v1.3.0
zhanyuan 5 years ago
parent
commit
4b45bc9ff9
2 changed files with 18 additions and 55 deletions
  1. +18
    -50
      mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc
  2. +0
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h

+ 18
- 50
mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc View File

@@ -26,25 +26,6 @@ void SplitCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
output_num_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "output_num");
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CheckParam(kernel_node);
Reshape();
}

// Builds the SplitParameter stored in param_ from axis_, output_num_ and input_shape_:
// the number of outputs, the normalized split axis, the per-dimension strides of the
// input, and one equal split size per output.
//
// param_ and param_->split_sizes_ are raw heap allocations (new / new[]); this function
// does not release them.
// NOTE(review): InitKernel calls CheckParam before Reshape; presumably that rejects an
// empty shape, an out-of-range axis and output_num_ == 0 -- confirm, since none of those
// are re-checked here.
template <typename T>
void SplitCPUKernel<T>::Reshape() {
  // Work in signed int so that "rank + axis_" with a negative axis_ is ordinary signed
  // arithmetic instead of relying on unsigned wrap-around of size_t.
  const int rank = static_cast<int>(input_shape_.size());

  param_ = new SplitParameter();
  param_->num_split_ = output_num_;
  // A negative axis counts back from the last dimension.
  param_->split_dim_ = axis_ >= 0 ? axis_ : rank + axis_;

  // The innermost dimension has stride 1; every outer stride is the product of the
  // extents of all dimensions inside it.
  param_->strides_[rank - 1] = 1;
  for (int dim = rank - 2; dim >= 0; --dim) {
    param_->strides_[dim] = param_->strides_[dim + 1] * input_shape_[dim + 1];
  }

  // Array new takes an element count, not a byte count: allocate exactly num_split_
  // ints. (The previous "sizeof(int) * num_split_" over-allocated by a factor of
  // sizeof(int).)
  param_->split_sizes_ = new int[param_->num_split_];
  // Every output receives the same share of the split dimension.
  const int split_size = input_shape_[param_->split_dim_] / output_num_;
  for (int out = 0; out < param_->num_split_; ++out) {
    param_->split_sizes_[out] = split_size;
  }
}

template <typename T>
@@ -65,39 +46,27 @@ template <typename T>
void SplitCPUKernel<T>::LaunchSplit(T *input, T **output, size_t size) {
(void)std::transform(input_shape_.begin(), input_shape_.end(), std::back_inserter(input_shape_int_),
[](const int &value) { return static_cast<int>(value); });
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = size < block_size * max_thread_num ? std::ceil(size / block_size) : max_thread_num;

param_->split_count_ = size / (input_shape_[param_->split_dim_] * param_->strides_[param_->split_dim_]);
int num_unit = param_->split_count_ * param_->num_split_;
int thread_n_stride;
if (thread_num != 0) {
thread_n_stride = UP_DIV(num_unit, thread_num);
SplitParameter param;
param.num_split_ = output_num_;
param.split_dim_ = axis_;
param.strides_[input_shape_.size() - 1] = 1;
for (int i = input_shape_.size() - 2; i >= 0; i--) {
param.strides_[i] = param.strides_[i + 1] * input_shape_[i + 1];
}

auto task = [&](size_t start, size_t end) {
int task_id = start / (size / thread_num);
int thread_offset = task_id * thread_n_stride;
int num_unit_thread = MSMIN(thread_n_stride, num_unit - task_id * thread_n_stride);
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], thread_offset, num_unit_thread, param_,
sizeof(T));
};
CPUKernelUtils::ParallelFor(task, size);

return;
}

template <typename T>
void SplitCPUKernel<T>::FreeTmpBuff() {
if (param_->split_sizes_ != nullptr) {
delete[] param_->split_sizes_;
param_->split_sizes_ = nullptr;
auto split_sizes = std::make_unique<int[]>(param.num_split_);
param.split_sizes_ = split_sizes.get();
int split_size = input_shape_[param.split_dim_] / output_num_;
for (int i = 0; i < param.num_split_; i++) {
param.split_sizes_[i] = split_size;
}
if (param_ != nullptr) {
delete param_;
param_ = nullptr;
param.split_count_ = 1;
for (int i = 0; i < axis_; ++i) {
param.split_count_ *= input_shape_[i];
}
auto task = [&](size_t start, size_t end) {
DoSplit(input, reinterpret_cast<void **>(output), &input_shape_int_[0], start, end - start, &param, sizeof(T));
};
CPUKernelUtils::ParallelFor(task, param.split_count_ * param.num_split_);
return;
}

@@ -112,7 +81,6 @@ void SplitCPUKernel<T>::LaunchKernel(const std::vector<AddressPtr> &inputs,
}
size_t size = static_cast<size_t>(inputs[0]->size / sizeof(T));
LaunchSplit(input, output, size);
FreeTmpBuff();
return;
}



+ 0
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.h View File

@@ -42,9 +42,7 @@ class SplitCPUKernel : public CPUKernel {

private:
void CheckParam(const CNodePtr &kernel_node);
void Reshape();
void LaunchSplit(T *input, T **output, size_t size);
void FreeTmpBuff();
int64_t axis_;
int64_t output_num_;
int64_t axis_step_;
@@ -57,9 +55,6 @@ class SplitCPUKernel : public CPUKernel {
std::vector<size_t> input_shape_;
std::vector<int> input_shape_int_;
TypeId dtype_{kTypeUnknown};

protected:
SplitParameter *param_ = nullptr;
};

MS_REG_CPU_KERNEL_T(


Loading…
Cancel
Save