Browse Source

gpu support UnsortedSegmentSum kernel

tags/v0.3.0-alpha
wilfChen 5 years ago
parent
commit
31f3611f9a
5 changed files with 326 additions and 0 deletions
  1. +42
    -0
      mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc
  2. +90
    -0
      mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h
  3. +56
    -0
      mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu
  4. +27
    -0
      mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh
  5. +111
    -0
      tests/st/ops/gpu/test_unsorted_segment_sum.py

+ 42
- 0
mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentSumGpuKernel, float, int)

MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
UnsortedSegmentSumGpuKernel, float, int64_t)

MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentSumGpuKernel, int, int)

MS_REG_GPU_KERNEL_TWO(
UnsortedSegmentSum,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
UnsortedSegmentSumGpuKernel, int, int64_t)

} // namespace kernel
} // namespace mindspore

+ 90
- 0
mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class UnsortedSegmentSumGpuKernel : public GpuKernel {
public:
UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
~UnsortedSegmentSumGpuKernel() override = default;

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);

CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet Failed");
UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);

input_dim0_ = input_shapes[0];
for (size_t i = 1; i < input_shapes.size(); i++) {
input_dim1_ *= input_shapes[i];
}

output_dim0_ = output_shapes[0];
for (size_t i = 1; i < output_shapes.size(); i++) {
output_dim1_ *= output_shapes[i];
}

InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));
input_size_list_.push_back(output_dim0_ * sizeof(S));
input_size_list_.push_back(output_dim0_ * sizeof(int));
output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(S));
}

private:
size_t input_dim0_;
size_t input_dim1_;
size_t output_dim0_;
size_t output_dim1_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_

+ 56
- 0
mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"

template<typename T, typename S>
__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
T* input_addr, S* ids_addr, T* output_addr) {
for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1;
input_index += blockDim.x * gridDim.x) {
size_t j = input_index / input_dim1;
size_t k = input_index % input_dim1;

S i = ids_addr[j];
if (i < 0 || i >= output_dim0) {
continue;
}
size_t output_index = i * output_dim1 + k;
atomicAdd(output_addr + output_index, input_addr[input_index]);
}
}

template<typename T, typename S>
void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
T* input_addr, S* ids_addr, T* output_addr, cudaStream_t stream) {
int size = input_dim0 * input_dim1;
UnsortedSegmentSum<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input_dim0, input_dim1,
output_dim0, output_dim1, input_addr, ids_addr, output_addr);
return;
}

template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream);
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
float* input_addr, int64_t* ids_addr, float* output_addr, cudaStream_t stream);

template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
int* input_addr, int* ids_addr, int* output_addr, cudaStream_t stream);
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
int* input_addr, int64_t* ids_addr, int* output_addr, cudaStream_t stream);




+ 27
- 0
mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh View File

@@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_

#include <cuda_runtime.h>
#include "device/gpu/cuda_common.h"

template<typename T, typename S>
void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
T* input_addr, S* ids, T* output_addr, cudaStream_t stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_

+ 111
- 0
tests/st/ops/gpu/test_unsorted_segment_sum.py View File

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common import dtype as mstype

context.set_context(device_target='GPU')

class UnsortedSegmentSumNet(nn.Cell):
def __init__(self, num_segments):
super(UnsortedSegmentSumNet, self).__init__()
self.unsorted_segment_sum = P.UnsortedSegmentSum()
self.num_segments = num_segments

def construct(self, data, ids):
return self.unsorted_segment_sum(data, ids, self.num_segments)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_1D():
input_x = Tensor([1, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
num_segments = 4

net = UnsortedSegmentSumNet(num_segments)
output = net(input_x, segment_ids)
expect = [3, 3, 4, 0]
assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_2D():
input_x = Tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]], mstype.float32)
segment_ids = Tensor([2, 1, 1], mstype.int32)
num_segments = 4

net = UnsortedSegmentSumNet(num_segments)
output = net(input_x, segment_ids)
expect = [[ 0, 0, 0, 0],
[14, 16, 18, 20],
[ 1, 2, 3, 4],
[ 0, 0, 0, 0]]
assert (output.asnumpy() == expect).all()



@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3D():
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
num_segments = 5

net = UnsortedSegmentSumNet(num_segments)
output = net(input_x, segment_ids)
expect = [[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]],

[[45., 47., 49.],
[51., 53., 55.],
[57., 59., 61.],
[63., 65., 67.],
[69., 71., 73.]],

[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.],
[12., 13., 14.]],

[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]],

[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]]]
assert (output.asnumpy() == expect).all()

Loading…
Cancel
Save