From: @zhangbuxue Reviewed-by: @zhaizhiqiang,@guoqi1024 Signed-off-by: @zhaizhiqiangpull/15663/MERGE
| @@ -226,27 +226,14 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu | |||||
| {"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad}, | {"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad}, | ||||
| {"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad}, | {"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad}, | ||||
| {"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}}; | {"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}}; | ||||
| T *input1 = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| T *input2 = reinterpret_cast<T *>(inputs[1]->addr); | |||||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr); | |||||
| auto *output = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| size_t count = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | size_t count = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | ||||
| auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||||
| const float block_size = 128.0; | |||||
| size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; | |||||
| std::vector<common::Task> tasks; | |||||
| size_t start = 0; | |||||
| size_t once_compute_size = (count + thread_num - 1) / thread_num; | |||||
| while (start < count) { | |||||
| size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size); | |||||
| auto block = [&, start, end]() { | |||||
| elt_map.at(kernel_name_)(this, input1, input2, output, start, end); | |||||
| return common::SUCCESS; | |||||
| }; | |||||
| tasks.emplace_back(block); | |||||
| start += once_compute_size; | |||||
| } | |||||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||||
| CPUKernelUtils::ParallelFor( | |||||
| std::bind(elt_map.at(kernel_name_), this, input1, input2, output, std::placeholders::_1, std::placeholders::_2), | |||||
| count); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -72,7 +72,7 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p | |||||
| if (in_data == NULL || out_data == NULL || param == NULL) { | if (in_data == NULL || out_data == NULL || param == NULL) { | ||||
| return NNACL_NULL_PTR; | return NNACL_NULL_PTR; | ||||
| } | } | ||||
| if (param->num_axes_ > DIMENSION_6D) { | |||||
| if (param->num_axes_ > DIMENSION_8D) { | |||||
| return NNACL_PARAM_INVALID; | return NNACL_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -107,6 +107,10 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p | |||||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | ||||
| } else if (param->data_type == kDataTypeInt) { | } else if (param->data_type == kDataTypeInt) { | ||||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | ||||
| } else if (param->data_type == kDataTypeFloat64) { | |||||
| *((double *)out_data + out_offset) = *((double *)in_data + in_offset); | |||||
| } else if (param->data_type == kDataTypeBool) { | |||||
| *((bool *)out_data + out_offset) = *((bool *)in_data + in_offset); | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| } else if (param->data_type == kDataTypeFloat16) { | } else if (param->data_type == kDataTypeFloat16) { | ||||
| *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); | *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); | ||||
| @@ -69,7 +69,8 @@ typedef enum LiteDataType { | |||||
| kDataTypeFloat16, | kDataTypeFloat16, | ||||
| kDataTypeInt, | kDataTypeInt, | ||||
| kDataTypeInt8, | kDataTypeInt8, | ||||
| KDataTypeBool, | |||||
| kDataTypeBool, | |||||
| kDataTypeFloat64 | |||||
| } LiteDataType; | } LiteDataType; | ||||
| typedef enum DataOrder { | typedef enum DataOrder { | ||||
| @@ -46,26 +46,26 @@ void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||||
| if constexpr (std::is_same<T, bool>::value) { | if constexpr (std::is_same<T, bool>::value) { | ||||
| if (kernel_name == "ReduceAll") { | if (kernel_name == "ReduceAll") { | ||||
| reduce_type_ = ReduceType::ReduceAll; | |||||
| reduce_type_ = kReduceAll; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; }; | ||||
| } else if (kernel_name == "ReduceAny") { | } else if (kernel_name == "ReduceAny") { | ||||
| reduce_type_ = ReduceType::ReduceAny; | |||||
| reduce_type_ = kReduceAny; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; }; | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name_ << " for bool."; | MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name_ << " for bool."; | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (kernel_name == "ReduceMax") { | if (kernel_name == "ReduceMax") { | ||||
| reduce_type_ = ReduceType::ReduceMax; | |||||
| reduce_type_ = kReduceMax; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); }; | ||||
| } else if (kernel_name == "ReduceMin") { | } else if (kernel_name == "ReduceMin") { | ||||
| reduce_type_ = ReduceType::ReduceMin; | |||||
| reduce_type_ = kReduceMin; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); }; | ||||
| } else if (kernel_name == "ReduceSum") { | } else if (kernel_name == "ReduceSum") { | ||||
| reduce_type_ = ReduceType::ReduceSum; | |||||
| reduce_type_ = kReduceSum; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; | ||||
| } else if (kernel_name == "ReduceMean") { | } else if (kernel_name == "ReduceMean") { | ||||
| reduce_type_ = ReduceType::ReduceMean; | |||||
| reduce_type_ = kReduceMean; | |||||
| reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; | reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name; | MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name; | ||||
| @@ -86,7 +86,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| for (size_t i = 1; i < input_size; ++i) { | for (size_t i = 1; i < input_size; ++i) { | ||||
| reduce_func_(input_addr, i, output_addr); | reduce_func_(input_addr, i, output_addr); | ||||
| } | } | ||||
| if (reduce_type_ == ReduceType::ReduceMean) { | |||||
| if (reduce_type_ == kReduceMean) { | |||||
| *output_addr /= input_size; | *output_addr /= input_size; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -126,7 +126,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| reduce_func_(input_addr, iter.GetPos(), &output_addr[i]); | reduce_func_(input_addr, iter.GetPos(), &output_addr[i]); | ||||
| iter.GenNextPos(); | iter.GenNextPos(); | ||||
| } | } | ||||
| if (reduce_type_ == ReduceType::ReduceMean) { | |||||
| if (reduce_type_ == kReduceMean) { | |||||
| output_addr[i] /= stride; | output_addr[i] /= stride; | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,8 +24,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| enum class ReduceType { ReduceAll, ReduceAny, ReduceMax, ReduceMin, ReduceSum, ReduceMean }; | |||||
| template <typename T> | template <typename T> | ||||
| class ReduceCPUKernel : public CPUKernel { | class ReduceCPUKernel : public CPUKernel { | ||||
| public: | public: | ||||
| @@ -36,6 +34,7 @@ class ReduceCPUKernel : public CPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean }; | |||||
| std::vector<size_t> input_shape_; | std::vector<size_t> input_shape_; | ||||
| std::vector<int64_t> axis_; | std::vector<int64_t> axis_; | ||||
| ReduceType reduce_type_; | ReduceType reduce_type_; | ||||
| @@ -13,231 +13,107 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/cpu/slice_cpu_kernel.h" | #include "backend/kernel_compiler/cpu/slice_cpu_kernel.h" | ||||
| #include <algorithm> | |||||
| #include <unordered_map> | |||||
| #include "common/thread_pool.h" | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | #include "runtime/device/cpu/cpu_device_address.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| constexpr int MAX_DIMS = 8; | |||||
| void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| CheckParam(kernel_node); | |||||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||||
| (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto strides = prim->GetAttr(STRIDES); | |||||
| if (strides != nullptr) { | |||||
| std::vector<int64_t> strides_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES); | |||||
| std::vector<int64_t> end_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END); | |||||
| (void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| (void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| TransArg(); | |||||
| ClipBegin(); | |||||
| } else { | |||||
| std::vector<int> sizes; | |||||
| std::vector<int64_t> sizes_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE); | |||||
| (void)std::transform(sizes_me.begin(), sizes_me.end(), std::back_inserter(sizes), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; | |||||
| } | |||||
| ClipBegin(); | |||||
| for (size_t i = 0; i < sizes.size(); ++i) { | |||||
| while (sizes[i] < 0) { | |||||
| sizes[i] = sizes[i] + SizeToInt(input_shape_[i]); | |||||
| } | |||||
| strides_.emplace_back(1); | |||||
| end_.emplace_back(begin_[i] + sizes[i]); | |||||
| } | |||||
| int NormalizeBeginPos(int begin_pos, int dim_len) { | |||||
| if (begin_pos < 0) { | |||||
| int normal_pos = begin_pos + dim_len; | |||||
| return std::max(normal_pos, 0); | |||||
| } | } | ||||
| ExpandAllMemberDims(); | |||||
| CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); | |||||
| CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); | |||||
| return std::min(begin_pos, dim_len - 1); | |||||
| } | } | ||||
| void SliceCPUKernel::ClipBegin() { | |||||
| for (size_t i = 0; i < begin_.size(); i++) { | |||||
| if (begin_[i] < 0) { | |||||
| auto k = begin_[i] + SizeToInt(input_shape_[i]); | |||||
| begin_[i] = k < 0 ? 0 : k; | |||||
| } | |||||
| if (begin_[i] > SizeToInt(input_shape_[i])) { | |||||
| begin_[i] = SizeToInt(input_shape_[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| void SliceCPUKernel::ExpandAllMemberDims() { | |||||
| auto input_len = input_shape_.size(); | |||||
| if (input_len < 4) { | |||||
| for (size_t i = 0; i < 4 - input_len; ++i) { | |||||
| input_shape_.insert(input_shape_.begin(), 1); | |||||
| begin_.insert(begin_.begin(), 0); | |||||
| strides_.insert(strides_.begin(), 1); | |||||
| end_.insert(end_.begin(), 1); | |||||
| } | |||||
| void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)}, | |||||
| {kNumberTypeInt32, sizeof(int)}, | |||||
| {kNumberTypeFloat32, sizeof(float)}, | |||||
| {kNumberTypeFloat64, sizeof(double)}}; | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (input_shape.size() > DIMENSION_8D || input_shape.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Slice only support 1D to 8D input tensor, but got " << input_shape.size() << "D."; | |||||
| } | } | ||||
| for (size_t i = 0; i < 4; ++i) { | |||||
| if (SignOfStride(i)) { | |||||
| int ax = (end_[i] - begin_[i]) * SignOfStride(i); | |||||
| if (ax < 0) { | |||||
| ax = 0; | |||||
| } | |||||
| output_shape_.push_back(IntToSize(ax)); | |||||
| } | |||||
| auto size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE); | |||||
| auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||||
| if (begin.size() != input_shape.size() || size.size() != input_shape.size()) { | |||||
| MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension."; | |||||
| } | } | ||||
| } | |||||
| InitSliceParam(input_shape, begin, size); | |||||
| bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| bool ret{true}; | |||||
| if (dtype_ == kNumberTypeInt32) { | |||||
| ret = LaunchKernel<int>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| ret = LaunchKernel<float>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeBool) { | |||||
| ret = LaunchKernel<bool>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat64) { | |||||
| ret = LaunchKernel<double>(inputs, outputs); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64"; | |||||
| return false; | |||||
| TypeId dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| auto size_pair = type_size_map.find(dtype); | |||||
| if (size_pair == type_size_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "Slice supports bool, int32, float32 and float64 input tensor, but got " | |||||
| << TypeIdToType(dtype)->ToString(); | |||||
| } | } | ||||
| return ret; | |||||
| data_size_ = size_pair->second; | |||||
| } | } | ||||
| template <typename T> | |||||
| bool SliceCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| T *input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| T *output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; | |||||
| int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)}; | |||||
| size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], | |||||
| begin_[2] * input_element_num_[2]}; | |||||
| size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1], | |||||
| strides_[2] * input_element_num_[2]}; | |||||
| void SliceCPUKernel::ParallelRun(void *input_addr, void *output_addr, int thread_num) { | |||||
| std::vector<common::Task> tasks; | |||||
| int thread_index = 0; | |||||
| while (thread_index < thread_num) { | |||||
| auto block = [&, thread_index]() { | |||||
| DoSlice(input_addr, output_addr, &slice_param_, thread_index, data_size_); | |||||
| return common::SUCCESS; | |||||
| }; | |||||
| tasks.emplace_back(block); | |||||
| thread_index++; | |||||
| } | |||||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||||
| } | |||||
| auto in_n_offset = in_start_offset[0]; | |||||
| auto out_n_offset = 0; | |||||
| for (int i = begin_[0]; signstride[0] * i < signstride[0] * end_[0]; | |||||
| i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) { | |||||
| if (can_copy_memory[0]) { | |||||
| CopyDataToOutput<T>(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0); | |||||
| continue; | |||||
| } | |||||
| auto in_c_offset = in_start_offset[1]; | |||||
| auto out_c_offset = 0; | |||||
| for (int j = begin_[1]; signstride[1] * j < signstride[1] * end_[1]; | |||||
| j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) { | |||||
| if (can_copy_memory[1]) { | |||||
| CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, | |||||
| input_element_num_[1], 1); | |||||
| continue; | |||||
| } | |||||
| auto in_h_offset = in_start_offset[2]; | |||||
| auto out_h_offset = 0; | |||||
| for (int k = begin_[2]; signstride[2] * k < signstride[2] * end_[2]; | |||||
| k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) { | |||||
| if (can_copy_memory[2]) { | |||||
| CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, | |||||
| out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2); | |||||
| continue; | |||||
| } | |||||
| for (int m = begin_[3]; signstride[3] * m < signstride[3] * end_[3]; m += strides_[3]) { | |||||
| *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m]; | |||||
| } | |||||
| void SliceCPUKernel::InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin, | |||||
| const std::vector<int64_t> &size) { | |||||
| for (size_t i = 0; i < DIMENSION_8D; i++) { | |||||
| if (i < input_shape.size()) { | |||||
| int dim_len = SizeToInt(input_shape[i]); | |||||
| int begin_pos = LongToInt(begin[i]); | |||||
| int slice_size = LongToInt(size[i]); | |||||
| if (slice_size <= 0) { | |||||
| MS_LOG(EXCEPTION) << "Slice requires the each dimension slice size must be greater than 0."; | |||||
| } | } | ||||
| slice_param_.shape_[i] = dim_len; | |||||
| slice_param_.size_[i] = slice_size; | |||||
| slice_param_.begin_[i] = NormalizeBeginPos(begin_pos, dim_len); | |||||
| int end = slice_param_.begin_[i] + slice_param_.size_[i]; | |||||
| slice_param_.end_[i] = std::min(end, dim_len); | |||||
| } else { | |||||
| slice_param_.shape_[i] = 1; | |||||
| slice_param_.begin_[i] = 0; | |||||
| slice_param_.size_[i] = 1; | |||||
| slice_param_.end_[i] = 1; | |||||
| } | } | ||||
| } | } | ||||
| slice_param_.param_length_ = DIMENSION_8D; | |||||
| return true; | |||||
| size_t max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||||
| slice_param_.op_parameter_.thread_num_ = std::min(slice_param_.size_[1], SizeToInt(max_thread_num)); | |||||
| } | } | ||||
| bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { | |||||
| for (size_t i = dim + 1; i < 4; ++i) { | |||||
| if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) { | |||||
| return false; | |||||
| } | |||||
| bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (outputs[0]->size == 0) { | |||||
| return true; | |||||
| } | |||||
| auto input_addr = inputs[0]->addr; | |||||
| auto output_addr = outputs[0]->addr; | |||||
| int thread_num = slice_param_.op_parameter_.thread_num_; | |||||
| if (parallel_ && thread_num >= 2) { | |||||
| ParallelRun(input_addr, output_addr, thread_num); | |||||
| } else { | |||||
| DoSliceNoParallel(input_addr, output_addr, &slice_param_, data_size_); | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| int SliceCPUKernel::SignOfStride(size_t axis) const { | |||||
| if (strides_[axis] > 0) { | |||||
| return 1; | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| template <typename T> | |||||
| void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset, | |||||
| const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, | |||||
| size_t copy_num, int id) const { | |||||
| T *input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||||
| auto in_buff_size = inputs[0]->size; | |||||
| T *output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||||
| auto out_buff_size = outputs[0]->size; | |||||
| if ((in_offset + copy_num) * sizeof(T) > in_buff_size) { | |||||
| MS_LOG(EXCEPTION) << "input memory out of bounds."; | |||||
| } | |||||
| if ((out_offset + copy_num) * sizeof(T) > out_buff_size) { | |||||
| MS_LOG(EXCEPTION) << id << " output memory out of bounds."; | |||||
| } | |||||
| size_t buff_size = out_buff_size - out_offset * sizeof(T); | |||||
| size_t copy_size = copy_num * sizeof(T); | |||||
| if (buff_size < copy_size) { | |||||
| MS_LOG(EXCEPTION) << "output buffer is not enough. memcpy failed!"; | |||||
| } | |||||
| auto ret = memcpy_s(output_addr + out_offset, copy_size, input_addr + in_offset, copy_size); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; | |||||
| } | |||||
| } | |||||
| void SliceCPUKernel::TransArg() { | |||||
| if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; | |||||
| } | |||||
| for (size_t i = 0; i < strides_.size(); ++i) { | |||||
| if (strides_[i] == 0) { | |||||
| MS_LOG(EXCEPTION) << "slice stride cannot be zero"; | |||||
| } | |||||
| if (end_[i] == 0 && begin_[i] < 0) { | |||||
| end_[i] = end_[i] + SizeToInt(input_shape_[i]); | |||||
| } | |||||
| if (end_[i] < 0) { | |||||
| end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]); | |||||
| } | |||||
| if (end_[i] > SizeToInt(input_shape_[i])) { | |||||
| end_[i] = SizeToInt(input_shape_[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 1) { | |||||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs."; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (input_shape.size() > MAX_DIMS) { | |||||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; | |||||
| } | |||||
| if (input_shape.size() == 0) { | |||||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; | |||||
| } | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,12 +13,16 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | ||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | #include "backend/kernel_compiler/cpu/cpu_kernel.h" | ||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | ||||
| #include "nnacl/base/slice_base.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| @@ -33,41 +37,20 @@ class SliceCPUKernel : public CPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| template <typename T> | |||||
| bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs); | |||||
| template <typename T> | |||||
| void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset, | |||||
| const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num, | |||||
| int id) const; | |||||
| void ExpandAllMemberDims(); | |||||
| bool CanCopyMemoryOnAxis(size_t dim) const; | |||||
| int SignOfStride(size_t axis) const; | |||||
| void CheckParam(const CNodePtr &kernel_node) const; | |||||
| void TransArg(); | |||||
| void ClipBegin(); | |||||
| std::vector<int> begin_; | |||||
| std::vector<int> end_; | |||||
| std::vector<int> strides_; | |||||
| std::vector<size_t> input_shape_; | |||||
| std::vector<size_t> input_element_num_; | |||||
| std::vector<size_t> output_shape_; | |||||
| std::vector<size_t> output_element_num_; | |||||
| TypeId dtype_{kTypeUnknown}; | |||||
| void InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin, | |||||
| const std::vector<int64_t> &size); | |||||
| void ParallelRun(void *input_addr, void *output_addr, int thread_num); | |||||
| bool parallel_{true}; | |||||
| int data_size_{4}; | |||||
| SliceParameter slice_param_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel); | MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SliceCPUKernel); | SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceCPUKernel); | SliceCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,226 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/cpu/stridedslice_cpu_kernel.h" | |||||
| #include <utility> | |||||
| #include <functional> | |||||
| #include <algorithm> | |||||
| #include <unordered_map> | |||||
| #include "common/thread_pool.h" | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| enum PosType { kBegin, kEnd }; | |||||
| int NormalizePos(int pos, int dim_len, PosType pos_type) { | |||||
| if (pos < 0) { | |||||
| int normal_pos = pos + dim_len; | |||||
| normal_pos = std::max(normal_pos, 0); | |||||
| return normal_pos; | |||||
| } | |||||
| int max_pos = pos_type == kBegin ? dim_len - 1 : dim_len; | |||||
| return std::min(pos, max_pos); | |||||
| } | |||||
| void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) { | |||||
| MS_LOG(EXCEPTION) << "StridedSlice only support 1D to 8D input tensor, but got " << input_shape_.size() << "D."; | |||||
| } | |||||
| auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN); | |||||
| auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END); | |||||
| auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES); | |||||
| if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) | |||||
| << "StridedSLice requires the length of begin, stride and end must be equal and less than input dimension."; | |||||
| } | |||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| InitSliceParam(begin, end, stride); | |||||
| parallel_ = MatchParallelPattern(); | |||||
| if (parallel_) { | |||||
| InitParallelParam(); | |||||
| } | |||||
| } | |||||
| bool StridedSliceCPUKernel::MatchParallelPattern() { | |||||
| // This function is seeking if that the number of only one dimension | |||||
| // is different between input and output. If so, we can do some trick. | |||||
| // Example 1: | |||||
| // input shape info: [1, 80, 46, 40] | |||||
| // output shape info: [1, 80, 20, 40] | |||||
| // Example 2: | |||||
| // input shape info: [1, 46, 40] | |||||
| // output shape info: [1, 20, 40] | |||||
| if (input_shape_.size() != output_shape_.size()) { | |||||
| return false; | |||||
| } | |||||
| std::vector<int> axis_list; | |||||
| for (size_t i = 0; i < input_shape_.size(); ++i) { | |||||
| if (input_shape_[i] != output_shape_[i]) { | |||||
| axis_list.emplace_back(i); | |||||
| } | |||||
| } | |||||
| if (axis_list.size() == 1) { | |||||
| split_axis_ = axis_list.front(); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void StridedSliceCPUKernel::InitParallelParam() { | |||||
| outer_ = SizeToInt( | |||||
| std::accumulate(input_shape_.begin(), input_shape_.begin() + split_axis_, size_t(1), std::multiplies<size_t>())); | |||||
| inner_ = SizeToInt( | |||||
| std::accumulate(input_shape_.begin() + split_axis_ + 1, input_shape_.end(), size_t(1), std::multiplies<size_t>())); | |||||
| int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum()); | |||||
| int thread_num = 1; | |||||
| if (outer_ == 1) { | |||||
| parallel_strategy_ = kOnSplitAxis; | |||||
| thread_num = std::min(SizeToInt(output_shape_[split_axis_]), max_thread_num); | |||||
| cal_num_per_thread_ = UP_DIV(output_shape_[split_axis_], thread_num); | |||||
| } else { | |||||
| parallel_strategy_ = kOnOuter; | |||||
| thread_num = std::min(outer_, max_thread_num); | |||||
| cal_num_per_thread_ = UP_DIV(outer_, thread_num); | |||||
| } | |||||
| slice_param_.op_parameter_.thread_num_ = thread_num; | |||||
| } | |||||
| void StridedSliceCPUKernel::InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end, | |||||
| const std::vector<int64_t> &stride) { | |||||
| static const std::unordered_map<TypeId, std::pair<LiteDataType, int>> type_convert_map = { | |||||
| {kNumberTypeBool, {kDataTypeBool, sizeof(bool)}}, | |||||
| {kNumberTypeInt32, {kDataTypeInt, sizeof(int)}}, | |||||
| {kNumberTypeFloat32, {kDataTypeFloat, sizeof(float)}}, | |||||
| {kNumberTypeFloat64, {kDataTypeFloat64, sizeof(double)}}}; | |||||
| auto type_pair = type_convert_map.find(dtype_); | |||||
| if (type_pair == type_convert_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "StridedSlice supports bool, int32, float32 and float64 input tensor, but got " | |||||
| << TypeIdToType(dtype_)->ToString(); | |||||
| } | |||||
| data_size_ = type_pair->second.second; | |||||
| slice_param_.data_type = type_pair->second.first; | |||||
| for (size_t i = 0; i < DIMENSION_8D; i++) { | |||||
| if (i < begin.size()) { | |||||
| int dim_len = SizeToInt(input_shape_[i]); | |||||
| int begin_pos = LongToInt(begin[i]); | |||||
| int end_pos = LongToInt(end[i]); | |||||
| int stride_size = LongToInt(stride[i]); | |||||
| if (stride_size == 0) { | |||||
| MS_LOG(EXCEPTION) << "StridedSlice requires the each dimension slice stride can't be 0."; | |||||
| } | |||||
| slice_param_.in_shape_[i] = dim_len; | |||||
| slice_param_.strides_[i] = stride_size; | |||||
| slice_param_.begins_[i] = NormalizePos(begin_pos, dim_len, kBegin); | |||||
| slice_param_.ends_[i] = NormalizePos(end_pos, dim_len, kEnd); | |||||
| if (slice_param_.ends_[i] <= slice_param_.begins_[i] && slice_param_.strides_[i] > 0) { | |||||
| slice_param_.ends_[i] = slice_param_.begins_[i] + 1; | |||||
| } | |||||
| if (slice_param_.ends_[i] >= slice_param_.begins_[i] && slice_param_.strides_[i] < 0) { | |||||
| slice_param_.ends_[i] = slice_param_.begins_[i] - 1; | |||||
| } | |||||
| } else if (i < input_shape_.size()) { | |||||
| int dim_len = SizeToInt(input_shape_[i]); | |||||
| slice_param_.in_shape_[i] = dim_len; | |||||
| slice_param_.begins_[i] = 0; | |||||
| slice_param_.ends_[i] = dim_len; | |||||
| slice_param_.strides_[i] = 1; | |||||
| } else { | |||||
| slice_param_.in_shape_[i] = 1; | |||||
| slice_param_.begins_[i] = 0; | |||||
| slice_param_.ends_[i] = 1; | |||||
| slice_param_.strides_[i] = 1; | |||||
| } | |||||
| } | |||||
| slice_param_.in_shape_length_ = DIMENSION_8D; | |||||
| slice_param_.num_axes_ = DIMENSION_8D; | |||||
| } | |||||
| int StridedSliceCPUKernel::RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos) { | |||||
| int begin_index = slice_param_.begins_[split_axis_]; | |||||
| int inner_size = inner_ * data_size_; | |||||
| uint8_t *cur_in_ptr = input_addr + (start_pos * input_shape_[split_axis_] + begin_index) * inner_size; | |||||
| uint8_t *cur_out_ptr = output_addr + start_pos * output_shape_[split_axis_] * inner_size; | |||||
| int cur_outer = outer_ - start_pos; | |||||
| if (cur_outer <= 0) { | |||||
| return common::SUCCESS; | |||||
| } | |||||
| cur_outer = cur_outer > cal_num_per_thread_ ? cal_num_per_thread_ : cur_outer; | |||||
| FastStride(cur_in_ptr, cur_out_ptr, output_shape_[split_axis_], slice_param_.strides_[split_axis_], cur_outer, | |||||
| inner_size, input_shape_[split_axis_] * inner_size); | |||||
| return common::SUCCESS; | |||||
| } | |||||
| int StridedSliceCPUKernel::RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos) { | |||||
| int begin_index = slice_param_.begins_[split_axis_]; | |||||
| int inner_size = inner_ * data_size_; | |||||
| uint8_t *cur_in_ptr = input_addr + (start_pos * slice_param_.strides_[split_axis_] + begin_index) * inner_size; | |||||
| uint8_t *cur_out_ptr = output_addr + start_pos * inner_size; | |||||
| int cal_axis_num = output_shape_[split_axis_] - start_pos; | |||||
| if (cal_axis_num <= 0) { | |||||
| return common::SUCCESS; | |||||
| } | |||||
| cal_axis_num = cal_axis_num > cal_num_per_thread_ ? cal_num_per_thread_ : cal_axis_num; | |||||
| FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, slice_param_.strides_[split_axis_], 1, inner_size, 0); | |||||
| return common::SUCCESS; | |||||
| } | |||||
| void StridedSliceCPUKernel::ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num) { | |||||
| int thread_index = 0; | |||||
| std::vector<common::Task> tasks; | |||||
| std::function<int(StridedSliceCPUKernel *, uint8_t *, uint8_t *, int)> execute_func; | |||||
| if (parallel_strategy_ == kOnOuter) { | |||||
| execute_func = &StridedSliceCPUKernel::RunTaskOnOuter; | |||||
| } else if (parallel_strategy_ == kOnSplitAxis) { | |||||
| execute_func = &StridedSliceCPUKernel::RunTaskOnSplitAxis; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Not supported parallel execute strategy for StridedSlice."; | |||||
| } | |||||
| while (thread_index < thread_num) { | |||||
| tasks.emplace_back(std::bind(execute_func, this, input_addr, output_addr, thread_index * cal_num_per_thread_)); | |||||
| thread_index++; | |||||
| } | |||||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||||
| } | |||||
| bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | |||||
| if (outputs[0]->size == 0) { | |||||
| return true; | |||||
| } | |||||
| auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr); | |||||
| auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr); | |||||
| int thread_num = slice_param_.op_parameter_.thread_num_; | |||||
| if (parallel_ && thread_num >= 2) { | |||||
| ParallelRun(input_addr, output_addr, thread_num); | |||||
| } else { | |||||
| DoStridedSlice(input_addr, output_addr, &slice_param_); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| #include "nnacl/fp32/strided_slice_fp32.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class StridedSliceCPUKernel : public CPUKernel { | |||||
| public: | |||||
| StridedSliceCPUKernel() = default; | |||||
| ~StridedSliceCPUKernel() override = default; | |||||
| void InitKernel(const CNodePtr &kernel_node) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| private: | |||||
| enum ParallelStrategy { kOnSplitAxis, kOnOuter }; | |||||
| void InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end, | |||||
| const std::vector<int64_t> &stride); | |||||
| bool MatchParallelPattern(); | |||||
| void InitParallelParam(); | |||||
| void ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num); | |||||
| int RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos); | |||||
| int RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos); | |||||
| TypeId dtype_; | |||||
| int data_size_{4}; | |||||
| int split_axis_{-1}; | |||||
| int inner_{1}; | |||||
| int outer_{1}; | |||||
| int cal_num_per_thread_{1}; | |||||
| bool parallel_{false}; | |||||
| ParallelStrategy parallel_strategy_{kOnSplitAxis}; | |||||
| std::vector<size_t> input_shape_; | |||||
| std::vector<size_t> output_shape_; | |||||
| StridedSliceParameter slice_param_; | |||||
| }; | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| StridedSliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| StridedSliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| StridedSliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| StridedSliceCPUKernel); | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_ | |||||
| @@ -64,7 +64,7 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) { | |||||
| } | } | ||||
| reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_), | reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_), | ||||
| static_cast<float *>(dst_data_), task_id, context_->thread_num_); | static_cast<float *>(dst_data_), task_id, context_->thread_num_); | ||||
| } else if (data_type_ == KDataTypeBool) { | |||||
| } else if (data_type_ == kDataTypeBool) { | |||||
| if (!bool_reducer_) { | if (!bool_reducer_) { | ||||
| MS_LOG(ERROR) << "function bool_reducer_ is null."; | MS_LOG(ERROR) << "function bool_reducer_ is null."; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -96,7 +96,7 @@ int ReduceCPUKernel::Run() { | |||||
| if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) { | if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) { | ||||
| data_type_ = kDataTypeFloat; | data_type_ = kDataTypeFloat; | ||||
| } else if (in_tensors().at(0)->data_type() == kNumberTypeBool) { | } else if (in_tensors().at(0)->data_type() == kNumberTypeBool) { | ||||
| data_type_ = KDataTypeBool; | |||||
| data_type_ = kDataTypeBool; | |||||
| } else { | } else { | ||||
| data_type_ = kDataTypeInt; | data_type_ = kDataTypeInt; | ||||
| } | } | ||||
| @@ -183,7 +183,7 @@ int ReduceCPUKernel::MallocTmpBuffer() { | |||||
| void *buffer = nullptr; | void *buffer = nullptr; | ||||
| if (data_type_ == kDataTypeFloat) { | if (data_type_ == kDataTypeFloat) { | ||||
| buffer = context_->allocator->Malloc(size * sizeof(float)); | buffer = context_->allocator->Malloc(size * sizeof(float)); | ||||
| } else if (data_type_ == KDataTypeBool) { | |||||
| } else if (data_type_ == kDataTypeBool) { | |||||
| buffer = context_->allocator->Malloc(size * sizeof(bool)); | buffer = context_->allocator->Malloc(size * sizeof(bool)); | ||||
| } else { | } else { | ||||
| buffer = context_->allocator->Malloc(size * sizeof(int)); | buffer = context_->allocator->Malloc(size * sizeof(int)); | ||||