Browse Source

!20120 InTopK GPU kernel bugfix

Merge pull request !20120 from Peilin/topk-bugfix-index
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
51cd4215be
4 changed files with 40 additions and 10 deletions
  1. +24
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu
  3. +0
    -1
      mindspore/ops/operations/nn_ops.py
  4. +15
    -5
      tests/st/ops/gpu/test_in_top_k.py

+ 24
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/in_top_k_gpu_kernel.h View File

@@ -45,6 +45,22 @@ class InTopKGpuKernel : public GpuKernel {

bool *output_device = GetDeviceAddress<bool>(outputs, 0);

if (k_ <= 0) {
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_, cudaMemsetAsync(output_device, false, outer_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync failed.");

return true;
}

if (k_ >= static_cast<int64_t>(inner_size_)) {
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_, cudaMemsetAsync(output_device, true, outer_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync failed.");

return true;
}

T *top_k_output_device = GetDeviceAddress<T>(workspace, 0);
int32_t *top_k_indices_device = GetDeviceAddress<int32_t>(workspace, 1);

@@ -76,6 +92,7 @@ class InTopKGpuKernel : public GpuKernel {
}

bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 2) {
MS_LOG(ERROR) << input_count << " inputs were provided, but InTopKGpuKernel expects 2.";
@@ -130,13 +147,17 @@ class InTopKGpuKernel : public GpuKernel {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(input_shape_[0] * sizeof(int32_t));
output_size_list_.push_back(input_shape_[0] * sizeof(bool));
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(T));
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(int32_t));
if (k_ > 0) {
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(T));
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(int32_t));
}

// TODO: remove later. Urgent workaround for a bug: TopK produces incorrect output for float16.
if (std::is_same<T, half>::value) {
workspace_size_list_.push_back(input_size_ * sizeof(float));
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(float));
if (k_ > 0) {
workspace_size_list_.push_back(input_shape_[0] * k_ * sizeof(float));
}
}
}



+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/in_top_k_impl.cu View File

@@ -26,7 +26,7 @@ __global__ void InTopK(const T *predictions, const int32_t *targets, bool *outpu
for (; gt_id < batch_size; gt_id += blockDim.x * gridDim.x) {
int32_t target_index = targets[gt_id];
T predicted_value = predictions[gt_id * class_id_count + target_index];
T top_k_smallest_value = top_k_output[k - 1];
T top_k_smallest_value = top_k_output[gt_id * k + k - 1];

output[gt_id] = predicted_value >= top_k_smallest_value;
}


+ 0
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -7762,7 +7762,6 @@ class InTopK(PrimitiveWithInfer):
"""Initialize InTopK"""
self.init_prim_io_names(inputs=['x1', 'x2', 'k'], outputs=['y'])
validator.check_value_type("k", k, [int], self.name)
validator.check("k", k, "", 0, Rel.GT, self.name)

def infer_dtype(self, x1_dtype, x2_dtype):
validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name)


+ 15
- 5
tests/st/ops/gpu/test_in_top_k.py View File

@@ -32,9 +32,22 @@ class InTopKNet(nn.Cell):
def in_top_k(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

predictions = Tensor(np.array([[9, 3, 8, 0, 0, 0, 0, 0, 0],
predictions = Tensor(np.array([[4, 1, 2, 0, 0, 0, 0, 0, 0],
[7, 9, 9, 0, 0, 0, 0, 0, 0],
[9, 9, 9, 0, 0, 0, 0, 0, 0]]).astype(nptype))
[3, 3, 3, 0, 0, 0, 0, 0, 0]]).astype(nptype))
k = 165
in_top_k_net = InTopKNet(k)
targets = Tensor(np.array([0, 1, 0]).astype(np.int32))
output = in_top_k_net(predictions, targets)
expected_output = np.array([True, True, True])
np.testing.assert_array_equal(output.asnumpy(), expected_output)

k = -2
in_top_k_net = InTopKNet(k)
targets = Tensor(np.array([0, 1, 0]).astype(np.int32))
output = in_top_k_net(predictions, targets)
expected_output = np.array([False, False, False])
np.testing.assert_array_equal(output.asnumpy(), expected_output)

k = 1
in_top_k_net = InTopKNet(k)
@@ -104,9 +117,6 @@ def test_in_top_k_float32():
@pytest.mark.env_onecard
def test_in_top_k_invalid_input():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# k must be > 0
with pytest.raises(ValueError):
in_top_k_net = InTopKNet(0)

# predictions must be 2d
with pytest.raises(ValueError):


Loading…
Cancel
Save