Browse Source

Return the minimal index for ArgMaxWithValue when equal values tie

tags/v1.2.0-rc1
Jonathan Yan 4 years ago
parent
commit
18f67c61c2
2 changed files with 29 additions and 17 deletions
  1. +27
    -15
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu
  2. +2
    -2
      tests/st/ops/gpu/test_argmaxwithvalue_op.py

+ 27
- 15
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu View File

@@ -32,10 +32,10 @@ const int kMaxWarpLoop = kWarpSize * 3; // 32 * 3 = 96
const int kMaxGroupLoop = kGroupSize * 3; // 128 * 3 = 384

template <typename T>
template <typename T, typename S>
struct Cmp {
__device__ static inline bool lt(T a, T b) { return a <= b; }
__device__ static inline bool gt(T a, T b) { return a >= b; }
__device__ static inline bool lt(T a, T b, S i, S j) { return (a < b) || ((a == b) && (i < 0 || j < i)); }
__device__ static inline bool gt(T a, T b, S i, S j) { return (a > b) || ((a == b) && (i < 0 || j < i)); }
};

template <typename T>
@@ -63,7 +63,8 @@ __global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, siz
for (int i = 0; i < bound; i++) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -94,7 +95,8 @@ __global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_
for (int i = laneId; i < bound; i += kWarpSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -104,7 +106,8 @@ __global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -151,7 +154,8 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size
for (int i = tgId; i < bound; i += kGroupSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -161,7 +165,8 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -176,8 +181,10 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size

if (tgId < 2) {
bool is_winner =
small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2])
: Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2]);
small ? Cmp<T, S>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2],
shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 2])
: Cmp<T, S>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2],
shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 2]);
ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
(shared_K[(groupId * kWarpGroup) + tgId + 2]));
ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
@@ -187,8 +194,10 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size

if (tgId == 0) {
bool is_winner =
small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1])
: Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1]);
small ? Cmp<T, S>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1],
shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 1])
: Cmp<T, S>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1],
shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 1]);
ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
(shared_K[(groupId * kWarpGroup) + tgId + 1]));
ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
@@ -233,7 +242,8 @@ __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size
for (int i = tgId; i < bound; i += kBlockSize) {
T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
S other_V = i;
bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -243,7 +253,8 @@ __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}
@@ -269,7 +280,8 @@ __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size
T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
bool is_winner =
small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
ConditionAssign(is_winner, &threadK, other_K);
ConditionAssign(is_winner, &threadV, other_V);
}


+ 2
- 2
tests/st/ops/gpu/test_argmaxwithvalue_op.py View File

@@ -75,7 +75,7 @@ def argmaxwithvalue_base(data_type):


def argmaxwithvalue_3d(data_type, shape_x):
np.random.seed(876)
np.random.seed(2)
x_np = np.random.random(shape_x).astype(data_type)
x = Tensor(x_np)

@@ -130,7 +130,7 @@ def test_argmaxwithvalue_3d_float32():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float16():
shape_x = (2, 32, 16)
shape_x = (2, 64, 128)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float16, shape_x)



Loading…
Cancel
Save