From 5263588b2c125a5dceab76a64e4becd7142f50f3 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Thu, 10 Dec 2020 11:08:33 -0500 Subject: [PATCH] add int64 support to UniformCandidateSampler GPU --- .../uniform_candidate_sampler_impl.cu | 12 ++-- .../uniform_candidate_sampler_impl.cuh | 4 +- .../uniform_candidate_sampler_gpu_kernel.cc | 7 +++ .../uniform_candidate_sampler_gpu_kernel.h | 57 ++++++++++++------- mindspore/ops/operations/random_ops.py | 4 +- .../gpu/test_uniform_candidate_sampler_op.py | 21 +++++++ 6 files changed, 74 insertions(+), 31 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu index d57ca7907d..d1f2c0825f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu @@ -17,20 +17,20 @@ #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" template -__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { +__global__ void AssignToOutput(const int64_t size, const S prob_val, S *output_array) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { output_array[pos] = prob_val; } } template -void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, - S *sampled_expected_count, cudaStream_t cuda_stream) { +void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val, + S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream) { AssignToOutput<<>>(true_size, prob_val, true_expected_count); AssignToOutput<<>>(num_sampled, prob_val, sampled_expected_count); } -template void CalUniformCandidateSampler(const int true_size, const int num_sampled, const float prob_val, - float *true_expected_count, float *sampled_expected_count, - cudaStream_t cuda_stream); +template void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, + const float prob_val, float *true_expected_count, + float *sampled_expected_count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh index 314c3e1c86..14bc7bd458 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh @@ -20,7 +20,7 @@ #include "runtime/device/gpu/cuda_common.h" template -void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, - S *sampled_expected_count, cudaStream_t cuda_stream); +void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val, + S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc index 8454a0a298..012791060a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.cc @@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), UniformCandidateSamplerGpuKernel, int, float) +MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + UniformCandidateSamplerGpuKernel, int64_t, float) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h index 30aa5958c8..7ea4ab3ed3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/uniform_candidate_sampler_gpu_kernel.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" @@ -55,15 +56,15 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { set_input_.insert(item); } } - int counter = Sampling(); - float prob = Probability(); + int64_t counter = Sampling(); + S prob = Probability(); size_t sampled_candidates_size = num_sampled_ * sizeof(T); S value = ApproximateExpectedCount(prob, num_sampled_, counter); CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync sampled_candidates failed"); - CalUniformCandidateSampler(static_cast(input_size_), num_sampled_, value, true_expected_count, + CalUniformCandidateSampler(static_cast(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count, reinterpret_cast(stream_ptr)); return true; } @@ -81,11 +82,11 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { return false; } // getting attrs - num_true_ = static_cast(GetAttr(kernel_node, "num_true")); - num_sampled_ = static_cast(GetAttr(kernel_node, "num_sampled")); + num_true_ = GetAttr(kernel_node, "num_true"); + num_sampled_ = GetAttr(kernel_node, "num_sampled"); unique_ = GetAttr(kernel_node, "unique"); - range_max_ = static_cast(GetAttr(kernel_node, "range_max")); - int seed = static_cast(GetAttr(kernel_node, "seed")); + range_max_ = GetAttr(kernel_node, "range_max"); + int64_t seed = GetAttr(kernel_node, "seed"); remove_accidental_hits_ = GetAttr(kernel_node, "remove_accidental_hits"); if (seed == 0) seed = time(NULL); generator_.seed(seed); @@ -95,7 +96,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { return false; } input_size_ = input_shape[0] * input_shape[1]; - if (num_sampled_ * num_true_ + static_cast(input_size_) > range_max_ * num_true_) { + if (num_sampled_ * num_true_ + static_cast(input_size_) > range_max_ * num_true_) { remove_accidental_hits_ = false; } InitSizeLists(); @@ -110,13 +111,18 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { output_size_list_.push_back(num_sampled_ * sizeof(S)); } - int Sampling() { - int counter = 0; - int tmp; - int picked; - std::set set_container; + int64_t Sampling() { + int64_t counter = 0; + T tmp; + int64_t picked; + std::set set_container; // pick between [0, range_max_-1] - std::uniform_int_distribution distribution(0, range_max_ - 1); + T range; + if (range_max_ > static_cast(std::numeric_limits::max())) { + MS_LOG(EXCEPTION) << "range_max_ failed to cast"; + } + range = static_cast(range_max_); + std::uniform_int_distribution distribution(0, range - 1); sampled_candidates_.clear(); if (unique_) { picked = 0; @@ -131,7 +137,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { } } } else { - for (int i = 0; i < num_sampled_; i++) { + for (int64_t i = 0; i < num_sampled_; i++) { sampled_candidates_.push_back(distribution(generator_)); } counter = num_sampled_; @@ -139,24 +145,31 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { return counter; } - S Probability() { return static_cast(1.0f / range_max_); } + S Probability() { + S range; + if (range_max_ > static_cast(std::numeric_limits::max())) { + MS_LOG(EXCEPTION) << "range_max_ failed to cast"; + } + range = static_cast(range_max_); + return static_cast(1.0f / range); + } - S ApproximateExpectedCount(S p, int sampled_size, int counter) { + S ApproximateExpectedCount(S p, int64_t sampled_size, int64_t counter) { if (sampled_size == counter) return p * sampled_size; return -std::expm1(counter * std::log1p(-p)); } private: - int num_true_; - int num_sampled_; + int64_t num_true_; + int64_t num_sampled_; bool unique_; - int range_max_; + int64_t range_max_; size_t input_size_; bool remove_accidental_hits_; std::vector array_input_; - std::set set_input_; + std::set set_input_; std::default_random_engine generator_; - std::vector sampled_candidates_; + std::vector sampled_candidates_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 04db0f4571..7e4a0f9080 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -583,7 +583,9 @@ class UniformCandidateSampler(PrimitiveWithInfer): self.num_sampled = num_sampled def infer_dtype(self, true_classes_type): - Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name) + Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) + Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, + (mstype.int32, mstype.int64), self.name) return (true_classes_type, mstype.float32, mstype.float32) def infer_shape(self, true_classes_shape): diff --git a/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py b/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py index 6e6628f3b4..38258e39e7 100644 --- a/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py +++ b/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py @@ -39,6 +39,14 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max): out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) return out1.shape, out2.shape, out3.shape +def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max): + uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true, + num_sampled, + unique, + range_max) + out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int64))) + return out1.shape, out2.shape, out3.shape + class UniformCandidateSamplerHitNet(nn.Cell): def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): @@ -155,6 +163,19 @@ def test_uniform_candidate_sampler_large_random(): np.testing.assert_array_equal(ms2, expected_2) np.testing.assert_array_equal(ms3, expected_3) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_candidate_sampler_large_random_int64_input(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63), + 63, 10, False, 12) + expected_1 = (10,) + expected_2 = (34, 63) + expected_3 = (10,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training