!11089 General reduction with hybrid mode

From: @jonwe
Reviewed-by: @robingrosman, @tom__chen
Signed-off-by: @tom__chen
tags/v1.2.0-rc1
mindspore-ci-bot (Gitee) committed 5 years ago
commit f07bb3bd04
5 changed files with 416 additions and 72 deletions
1. +3 -3 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
2. +0 -55 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
3. +321 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu
4. +5 -5 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh
5. +87 -9 tests/st/ops/gpu/test_argmaxwithvalue_op.py

+3 -3 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h

@@ -20,7 +20,7 @@
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
@@ -38,8 +38,8 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *output = GetDeviceAddress<T>(outputs, 1);
    S *index = GetDeviceAddress<S>(outputs, 0);
-    CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output,
+                        reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }



+0 -55 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu

@@ -1,55 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "argmaxwithvalue_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T, typename S>
__global__ void ArgmaxWithValue(const T *input, const S bound, size_t outerSize,
size_t innerSize, S *index, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
pos += gridDim.x * blockDim.x) {
size_t x = pos / innerSize % outerSize;
size_t y = pos % innerSize;
S idx = 0;
size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
T maxData = input[InputOffset];
for (S i = 0; i < bound; i++) {
InputOffset = x * bound * innerSize + i * innerSize + y;
auto inputData = input[InputOffset];
idx = inputData > maxData ? i : idx;
maxData = inputData > maxData ? inputData : maxData;
}
output[pos] = maxData;
index[pos] = idx;
}
return;
}

template <typename T, typename S>
void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_,
S *index, T *output, cudaStream_t cuda_stream) {
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
index, output);
return;
}

template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const size_t outerSize_,
const size_t innerSize_, int *index, float *output,
cudaStream_t cuda_stream);
template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const size_t outerSize_,
const size_t innerSize_, int *index, half *output,
cudaStream_t cuda_stream);

+321 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu

@@ -0,0 +1,321 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include <limits>
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"

const int kWarpSize = 32;
const int kBlockSize = 512;
const int kWarpGroup = 4;
const int kNumWarps = kBlockSize / kWarpSize; // 16
const int kGroupSize = kWarpGroup * kWarpSize; // 128

// Mode selection constants
const int kMaxThreadLoop = 4;
const int kMaxWarpLoop = kWarpSize * 3;    // 32 * 3 = 96
const int kMaxGroupLoop = kGroupSize * 3;  // 128 * 3 = 384
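These three constants drive the hybrid mode: the host-side dispatcher (GeneralReduction, at the bottom of this file) compares bound, the length of the reduced dimension, against them to choose a kernel. A minimal Python sketch of the selection policy, using the constant values above; pick_reduction_kernel is a hypothetical name for illustration only:

MAX_THREAD_LOOP = 4        # kMaxThreadLoop
MAX_WARP_LOOP = 32 * 3     # kMaxWarpLoop, one warp strides the dimension
MAX_GROUP_LOOP = 128 * 3   # kMaxGroupLoop, four warps cooperate

def pick_reduction_kernel(bound):
    # Mirrors the if/else ladder in GeneralReduction below.
    if bound <= MAX_THREAD_LOOP:
        return "ThreadReduction"   # one thread scans the whole dimension
    if bound <= MAX_WARP_LOOP:
        return "WarpReduction"     # 32 lanes stride over the dimension
    if bound <= MAX_GROUP_LOOP:
        return "Warp4Reduction"    # a 4-warp group (128 threads) cooperates
    return "BlockReduction"        # a full 512-thread block cooperates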

template <typename T>
struct Cmp {
__device__ static inline bool lt(T a, T b) { return a <= b; }
__device__ static inline bool gt(T a, T b) { return a >= b; }
};

template <typename T>
inline __device__ void ConditionAssign(bool is_assign, T *x, const T &y) {
(*x) = is_assign ? y : (*x);
}
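Cmp and ConditionAssign together form a branchless compare-and-select. Note that the comparators are non-strict (a <= b, a >= b), so an equal candidate counts as a winner: on a repeated extremum the reduction keeps the last index it sees, unlike np.argmax, which keeps the first (the tests below sidestep this with random floats). A small Python sketch of the max path (small == false), with hypothetical names:

def condition_assign(is_assign, x, y):
    # Branchless select, mirroring: (*x) = is_assign ? y : (*x);
    return y if is_assign else x

best_k, best_v = float("-inf"), -1
for i, other_k in enumerate([3.0, 7.0, 7.0]):
    is_winner = best_k <= other_k          # Cmp<T>::lt(threadK, other_K)
    best_k = condition_assign(is_winner, best_k, other_k)
    best_v = condition_assign(is_winner, best_v, i)
print(best_k, best_v)                      # 7.0 2 -- last occurrence wins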

template <typename T, typename S>
__global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}

const S init_V = static_cast<S>(-1);

for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < outer_size * inner_size;
t_idx += blockDim.x * gridDim.x) {
int outer_id = t_idx / inner_size;
int inner_id = t_idx % inner_size;

T threadK = init_K;
S threadV = init_V;

for (int i = 0; i < bound; i++) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}

output[outer_id * inner_size + inner_id] = threadK;
output_index[outer_id * inner_size + inner_id] = threadV;
}
}

template <typename T, typename S>
__global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
S *output_index, bool fp16_flag, T init_K) {
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
const S init_V = static_cast<S>(-1);

for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kWarpSize * outer_size * inner_size;
t_idx += blockDim.x * gridDim.x) {
int outer_id = t_idx / kWarpSize / inner_size;
int inner_id = t_idx / kWarpSize % inner_size;

int laneId = threadIdx.x % kWarpSize;

T threadK = init_K;
S threadV = init_V;

for (int i = laneId; i < bound; i += kWarpSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
__syncwarp();

for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}

__syncwarp();

if (laneId == 0) {
output[outer_id * inner_size + inner_id] = threadK;
output_index[outer_id * inner_size + inner_id] = threadV;
}
__syncthreads();
}
}
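The __shfl_down_sync loop above is a standard shuffle-down reduction tree: at offsets 16, 8, 4, 2, 1 each lane compares its candidate with the one held offset lanes above it, so after log2(32) = 5 steps lane 0 holds the warp-wide winner. A Python simulation of that tree, assuming one candidate per lane (warp_reduce is a hypothetical helper; the non-strict <= mirrors Cmp<T>::lt):

WARP = 32

def warp_reduce(vals, idxs):
    # vals/idxs hold one (value, index) candidate per lane.
    vals, idxs = list(vals), list(idxs)
    offset = WARP // 2
    while offset > 0:
        for lane in range(WARP - offset):  # lanes past the edge keep their value
            other_k, other_v = vals[lane + offset], idxs[lane + offset]
            if vals[lane] <= other_k:      # non-strict: ties take the new candidate
                vals[lane], idxs[lane] = other_k, other_v
        offset //= 2
    return vals[0], idxs[0]                # lane 0 ends up with the winner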

template <typename T, typename S>
__global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
__shared__ T shared_K[kNumWarps];
__shared__ S shared_V[kNumWarps];
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
const S init_V = static_cast<S>(-1);

for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kGroupSize * outer_size * inner_size;
t_idx += blockDim.x * gridDim.x) {
int outer_id = t_idx / kGroupSize / inner_size;
int inner_id = t_idx / kGroupSize % inner_size;

int groupId = threadIdx.x / kGroupSize;
int tgId = threadIdx.x % kGroupSize;
int warpId = threadIdx.x / kWarpSize;
int laneId = threadIdx.x % kWarpSize;

T threadK = init_K;
S threadV = init_V;

if (laneId == 0) {
shared_K[warpId] = init_K;
shared_V[warpId] = init_V;
}
__syncthreads();

for (int i = tgId; i < bound; i += kGroupSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
__syncwarp();

for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}

__syncwarp();

if (laneId == 0) {
shared_K[warpId] = threadK;
shared_V[warpId] = threadV;
}
__syncthreads();

if (tgId < 2) {
bool is_winner =
small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2])
: Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2]);
ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
(shared_K[(groupId * kWarpGroup) + tgId + 2]));
ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
(shared_V[(groupId * kWarpGroup) + tgId + 2]));
}
__syncwarp();

if (tgId == 0) {
bool is_winner =
small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1])
: Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1]);
ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
(shared_K[(groupId * kWarpGroup) + tgId + 1]));
ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
(shared_V[(groupId * kWarpGroup) + tgId + 1]));

// The first thread of each group writes the output
output[outer_id * inner_size + inner_id] = shared_K[groupId * kWarpGroup];
output_index[outer_id * inner_size + inner_id] = shared_V[groupId * kWarpGroup];
}
__syncthreads();
}
}

template <typename T, typename S>
__global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
__shared__ T shared_K[kNumWarps];
__shared__ S shared_V[kNumWarps];
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
const S init_V = static_cast<S>(-1);

for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kBlockSize * outer_size * inner_size;
t_idx += blockDim.x * gridDim.x) {
int outer_id = t_idx / kBlockSize / inner_size;
int inner_id = t_idx / kBlockSize % inner_size;

int tgId = threadIdx.x % kBlockSize;
int warpId = threadIdx.x / kWarpSize;
int laneId = threadIdx.x % kWarpSize;

T threadK = init_K;
S threadV = init_V;

if (laneId == 0) {
shared_K[warpId] = init_K;
shared_V[warpId] = init_V;
}
__syncthreads();

for (int i = tgId; i < bound; i += kBlockSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
__syncwarp();

for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}

__syncwarp();

if (laneId == 0) {
shared_K[warpId] = threadK;
shared_V[warpId] = threadV;
}
__syncthreads();

// Shared memory reduction
// There are 16 items in shared memory, which can be reduced within one warp.
if (warpId == 0) {
threadK = laneId < kNumWarps ? shared_K[laneId] : init_K;
threadV = laneId < kNumWarps ? shared_V[laneId] : init_V;
}
__syncwarp();

if (warpId == 0) {
for (int offset = kWarpSize / 4; offset > 0; offset /= 2) {
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
}
__syncwarp();

if (warpId == 0 && laneId == 0) {
output[outer_id * inner_size + inner_id] = threadK;
output_index[outer_id * inner_size + inner_id] = threadV;
}
}
}
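BlockReduction runs the same tree in two stages: each of the kNumWarps = 16 warps reduces its stride of the bound dimension and its lane 0 publishes a partial winner to shared memory; warp 0 then reduces those 16 partials, which is why its shuffle tree starts at kWarpSize / 4 = 8 rather than 16. A compact Python sketch of the two stages, assuming 512 candidates and ignoring the kernel's last-index tie-breaking (Python's max keeps the first):

def block_reduce(candidates):
    # candidates: 512 (value, index) pairs, one per thread in the block.
    warps = [candidates[w * 32:(w + 1) * 32] for w in range(16)]
    partials = [max(w, key=lambda kv: kv[0]) for w in warps]  # stage 1: per-warp winners
    return max(partials, key=lambda kv: kv[0])                # stage 2: warp 0 over 16 partials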

template <typename T, typename S>
void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
S *output_index, cudaStream_t stream) {
int block_num_limit = outer_size * inner_size;
bool fp16_flag = false;
if (std::is_same<T, half>::value) {
fp16_flag = true;
}
// Reduction identity: a min reduction starts from max(), a max reduction from lowest();
// the kernels override this with +/-65504 for half, which lacks numeric_limits support here.
T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

if (bound <= kMaxThreadLoop) {
ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
} else if (bound <= kMaxWarpLoop) {
WarpReduction<T, S><<<GET_BLOCKS(block_num_limit * kWarpSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
} else if (bound <= kMaxGroupLoop) {
Warp4Reduction<T, S><<<GET_BLOCKS(block_num_limit * kGroupSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
} else {
BlockReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
}
}

template <typename T, typename S>
void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, const size_t innerSize,
S *index, T *output, cudaStream_t cuda_stream) {
GeneralReduction(small, outerSize, bound, innerSize, input, output, index, cuda_stream);
return;
}

template void CalGeneralReduction(bool small, const float *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, int *index, float *output, cudaStream_t cuda_stream);
template void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, int *index, half *output, cudaStream_t cuda_stream);
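For review, a NumPy reference of the contract implemented above: CalGeneralReduction views the input as (outerSize, bound, innerSize), reduces along the middle axis, and writes winning indices to index and extrema to output, with small toggling min versus max. This is a sketch inferred from the kernels (general_reduction_ref is a hypothetical name), and its tie-breaking follows NumPy rather than the kernel:

import numpy as np

def general_reduction_ref(small, x, outer_size, bound, inner_size):
    x = x.reshape(outer_size, bound, inner_size)
    if small:
        return np.argmin(x, axis=1), np.min(x, axis=1)   # small: min reduction
    return np.argmax(x, axis=1), np.max(x, axis=1)       # default: max, i.e. argmax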

+5 -5 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh → mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh

@@ -14,9 +14,9 @@
* limitations under the License.
*/

-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
template <typename T, typename S>
-void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_, S *index,
-                        T *output, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
+void CalGeneralReduction(bool small, const T *input, const size_t bound_, const size_t outerSize_,
+                         const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_

+87 -9 tests/st/ops/gpu/test_argmaxwithvalue_op.py

@@ -35,18 +35,24 @@ class NetArgmaxWithValue(nn.Cell):
return (self.argmax1(x), self.argmax2(x), self.argmax3(x))


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_argmaxwithvalue():
+class NetArgmaxWithValueBig(nn.Cell):
+    def __init__(self, axis=0):
+        super(NetArgmaxWithValueBig, self).__init__()
+        self.argmax = P.ArgMaxWithValue(axis)
+
+    def construct(self, x):
+        return self.argmax(x)
+
+
+def argmaxwithvalue_base(data_type):
    x = Tensor(np.array([[1., 20., 5.],
                         [67., 8., 9.],
                         [130., 24., 15.],
-                        [0.3, -0.4, -15.]]).astype(np.float32))
-    expect1 = np.array([2, 2, 2]).astype(np.float32)
-    expect2 = np.array([1, 0, 0, 0]).astype(np.float32)
-    expect11 = np.array([130, 24, 15]).astype(np.float32)
-    expect22 = np.array([20, 67, 130, 0.3]).astype(np.float32)
+                        [0.3, -0.4, -15.]]).astype(data_type))
+    expect1 = np.array([2, 2, 2]).astype(data_type)
+    expect2 = np.array([1, 0, 0, 0]).astype(data_type)
+    expect11 = np.array([130, 24, 15]).astype(data_type)
+    expect22 = np.array([20, 67, 130, 0.3]).astype(data_type)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    argmax = NetArgmaxWithValue()
    output = argmax(x)
@@ -66,3 +72,75 @@ def test_argmaxwithvalue():
    assert (output[1][1].asnumpy() == expect22).all()
    assert (output[2][0].asnumpy() == expect1).all()
    assert (output[2][1].asnumpy() == expect11).all()


def argmaxwithvalue_3d(data_type, shape_x):
    np.random.seed(876)
    x_np = np.random.random(shape_x).astype(data_type)
    x = Tensor(x_np)

    argmax = NetArgmaxWithValueBig(0)
    output = argmax(x)
    expect1 = np.argmax(x_np, axis=0)
    expect2 = np.maximum.reduce(x_np, 0)
    assert (output[0].asnumpy() == expect1).all()
    assert (output[1].asnumpy() == expect2).all()

    argmax = NetArgmaxWithValueBig(1)
    output = argmax(x)
    expect1 = np.argmax(x_np, axis=1)
    expect2 = np.maximum.reduce(x_np, 1)
    assert (output[0].asnumpy() == expect1).all()
    assert (output[1].asnumpy() == expect2).all()

    argmax = NetArgmaxWithValueBig(2)
    output = argmax(x)
    expect1 = np.argmax(x_np, axis=2)
    expect2 = np.maximum.reduce(x_np, 2)
    assert (output[0].asnumpy() == expect1).all()
    assert (output[1].asnumpy() == expect2).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float32():
    argmaxwithvalue_base(np.float32)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float16():
    argmaxwithvalue_base(np.float16)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float32():
    shape_x = (2, 32, 256)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    argmaxwithvalue_3d(np.float32, shape_x)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    argmaxwithvalue_3d(np.float32, shape_x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float16():
    shape_x = (2, 32, 16)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    argmaxwithvalue_3d(np.float16, shape_x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_big_float32():
    shape_x = (128, 1024, 1)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    argmaxwithvalue_3d(np.float32, shape_x)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    argmaxwithvalue_3d(np.float32, shape_x)
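For context, a minimal usage sketch of the op these tests exercise: ArgMaxWithValue returns the (index, value) pair of the maximum along axis. The import paths are assumed from standard MindSpore 1.x conventions rather than shown in this diff:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[1., 20., 5.],
                     [67., 8., 9.]]).astype(np.float32))
index, value = P.ArgMaxWithValue(0)(x)
print(index.asnumpy())   # [1 0 1]
print(value.asnumpy())   # [67. 20.  9.]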
