diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc new file mode 100644 index 0000000000..a6505b0927 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" +#include +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/thread_pool.h" + +namespace mindspore { +namespace kernel { +void UnsortedSegmentSumCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but UnsortedSegmentSum needs 2 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but UnsortedSegmentSum needs 1 output."; + } + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + segment_ids_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto segment_ids_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + unit_num_ *= input_shape[i]; + if (i >= segment_ids_shape.size()) { + input_dim1_ *= input_shape[i]; + } + } + output_dim0_ = output_shape[0]; + for (size_t j = 1; j < output_shape.size(); j++) { + output_dim1_ *= output_shape[j]; + } +} + +bool UnsortedSegmentSumCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + bool ret{true}; + if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt32) { + ret = LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt32) { + ret = LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt32 && segment_ids_dtype_ == kNumberTypeInt64) { + ret = LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32 && segment_ids_dtype_ == kNumberTypeInt64) { + ret = LaunchKernel(inputs, outputs); + } else { + MS_LOG(ERROR) << "Only support input_x int32 and float32, indices int32 and int64"; + return false; + } + return ret; +} + +template +bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + S *input_addr = reinterpret_cast(inputs[0]->addr); + T *indices_addr = reinterpret_cast(inputs[1]->addr); + S *output_addr = reinterpret_cast(outputs[0]->addr); + auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); + if (ret != EOK) { + MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; + return false; + } + for (size_t i = 0; i < unit_num_; ++i) { + size_t j = i / input_dim1_; + size_t k = i % input_dim1_; + + T index = indices_addr[j]; + if (index < 0 || index >= SizeToInt(output_dim0_)) { + continue; + } + size_t output_index = index * output_dim1_ + k; + output_addr[output_index] += input_addr[i]; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h new file mode 100644 index 0000000000..96ac95a559 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNSORTED_SEGMENT_SUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNSORTED_SEGMENT_SUM_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class UnsortedSegmentSumCPUKernel : public CPUKernel { + public: + UnsortedSegmentSumCPUKernel() = default; + ~UnsortedSegmentSumCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + TypeId dtype_{kTypeUnknown}; + TypeId segment_ids_dtype_{kTypeUnknown}; + size_t unit_num_{1}; + size_t input_dim1_{1}; + size_t output_dim0_{1}; + size_t output_dim1_{1}; +}; +MS_REG_CPU_KERNEL( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumCPUKernel); +MS_REG_CPU_KERNEL( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumCPUKernel); +MS_REG_CPU_KERNEL( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumCPUKernel); +MS_REG_CPU_KERNEL( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNSORTED_SEGMENT_SUM_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_unsorted_segment_sum.py b/tests/st/ops/cpu/test_unsorted_segment_sum.py new file mode 100644 index 0000000000..4f80b4c80b --- /dev/null +++ b/tests/st/ops/cpu/test_unsorted_segment_sum.py @@ -0,0 +1,105 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + +class UnsortedSegmentSumNet(nn.Cell): + def __init__(self, num_segments): + super(UnsortedSegmentSumNet, self).__init__() + self.unsorted_segment_sum = P.UnsortedSegmentSum() + self.num_segments = num_segments + + def construct(self, data, ids): + return self.unsorted_segment_sum(data, ids, self.num_segments) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_1D(): + input_x = Tensor([1, 2, 3, 4], mstype.float32) + segment_ids = Tensor([0, 0, 1, 2], mstype.int32) + num_segments = 4 + + net = UnsortedSegmentSumNet(num_segments) + output = net(input_x, segment_ids) + expect = [3, 3, 4, 0] + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_2D(): + input_x = Tensor([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], mstype.float32) + segment_ids = Tensor([2, 1, 1], mstype.int32) + num_segments = 4 + + net = UnsortedSegmentSumNet(num_segments) + output = net(input_x, segment_ids) + expect = [[0, 0, 0, 0], + [14, 16, 18, 20], + [1, 2, 3, 4], + [0, 0, 0, 0]] + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_3D(): + input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)) + segment_ids = Tensor([2, 1, 1, -1], mstype.int32) + num_segments = 5 + + net = UnsortedSegmentSumNet(num_segments) + output = net(input_x, segment_ids) + expect = [[[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]], + [[45., 47., 49.], + [51., 53., 55.], + [57., 59., 61.], + [63., 65., 67.], + [69., 71., 73.]], + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.], + [9., 10., 11.], + [12., 13., 14.]], + [[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]], + [[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]] + assert (output.asnumpy() == expect).all()