From: @TFbunny Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chentags/v1.1.0
| @@ -17,20 +17,20 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" | |||
| template <typename S> | |||
| __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { | |||
| __global__ void AssignToOutput(const int64_t size, const S prob_val, S *output_array) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||
| output_array[pos] = prob_val; | |||
| } | |||
| } | |||
| template <typename S> | |||
| void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||
| S *sampled_expected_count, cudaStream_t cuda_stream) { | |||
| void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val, | |||
| S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream) { | |||
| AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count); | |||
| AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val, | |||
| sampled_expected_count); | |||
| } | |||
| template void CalUniformCandidateSampler<float>(const int true_size, const int num_sampled, const float prob_val, | |||
| float *true_expected_count, float *sampled_expected_count, | |||
| cudaStream_t cuda_stream); | |||
| template void CalUniformCandidateSampler<float>(const int64_t true_size, const int64_t num_sampled, | |||
| const float prob_val, float *true_expected_count, | |||
| float *sampled_expected_count, cudaStream_t cuda_stream); | |||
| @@ -20,7 +20,7 @@ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename S> | |||
| void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||
| S *sampled_expected_count, cudaStream_t cuda_stream); | |||
| void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val, | |||
| S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ | |||
| @@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UniformCandidateSamplerGpuKernel, int, float) | |||
| MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| UniformCandidateSamplerGpuKernel, int64_t, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include <set> | |||
| #include <vector> | |||
| #include <random> | |||
| #include <limits> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" | |||
| @@ -55,15 +56,15 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| set_input_.insert(item); | |||
| } | |||
| } | |||
| int counter = Sampling(); | |||
| float prob = Probability(); | |||
| int64_t counter = Sampling(); | |||
| S prob = Probability(); | |||
| size_t sampled_candidates_size = num_sampled_ * sizeof(T); | |||
| S value = ApproximateExpectedCount(prob, num_sampled_, counter); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync sampled_candidates failed"); | |||
| CalUniformCandidateSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, | |||
| CalUniformCandidateSampler(static_cast<int64_t>(input_size_), num_sampled_, value, true_expected_count, | |||
| sampled_expected_count, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -81,11 +82,11 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| // getting attrs | |||
| num_true_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_true")); | |||
| num_sampled_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_sampled")); | |||
| num_true_ = GetAttr<int64_t>(kernel_node, "num_true"); | |||
| num_sampled_ = GetAttr<int64_t>(kernel_node, "num_sampled"); | |||
| unique_ = GetAttr<bool>(kernel_node, "unique"); | |||
| range_max_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "range_max")); | |||
| int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||
| range_max_ = GetAttr<int64_t>(kernel_node, "range_max"); | |||
| int64_t seed = GetAttr<int64_t>(kernel_node, "seed"); | |||
| remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits"); | |||
| if (seed == 0) seed = time(NULL); | |||
| generator_.seed(seed); | |||
| @@ -95,7 +96,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| input_size_ = input_shape[0] * input_shape[1]; | |||
| if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) { | |||
| if (num_sampled_ * num_true_ + static_cast<int64_t>(input_size_) > range_max_ * num_true_) { | |||
| remove_accidental_hits_ = false; | |||
| } | |||
| InitSizeLists(); | |||
| @@ -110,13 +111,18 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| output_size_list_.push_back(num_sampled_ * sizeof(S)); | |||
| } | |||
| int Sampling() { | |||
| int counter = 0; | |||
| int tmp; | |||
| int picked; | |||
| std::set<int> set_container; | |||
| int64_t Sampling() { | |||
| int64_t counter = 0; | |||
| T tmp; | |||
| int64_t picked; | |||
| std::set<T> set_container; | |||
| // pick between [0, range_max_-1] | |||
| std::uniform_int_distribution<int> distribution(0, range_max_ - 1); | |||
| T range; | |||
| if (range_max_ > static_cast<int64_t>(std::numeric_limits<T>::max())) { | |||
| MS_LOG(EXCEPTION) << "range_max_ failed to cast"; | |||
| } | |||
| range = static_cast<T>(range_max_); | |||
| std::uniform_int_distribution<T> distribution(0, range - 1); | |||
| sampled_candidates_.clear(); | |||
| if (unique_) { | |||
| picked = 0; | |||
| @@ -131,7 +137,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| } | |||
| } | |||
| } else { | |||
| for (int i = 0; i < num_sampled_; i++) { | |||
| for (int64_t i = 0; i < num_sampled_; i++) { | |||
| sampled_candidates_.push_back(distribution(generator_)); | |||
| } | |||
| counter = num_sampled_; | |||
| @@ -139,24 +145,31 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||
| return counter; | |||
| } | |||
| S Probability() { return static_cast<S>(1.0f / range_max_); } | |||
| S Probability() { | |||
| S range; | |||
| if (range_max_ > static_cast<int64_t>(std::numeric_limits<S>::max())) { | |||
| MS_LOG(EXCEPTION) << "range_max_ failed to cast"; | |||
| } | |||
| range = static_cast<S>(range_max_); | |||
| return static_cast<S>(1.0f / range); | |||
| } | |||
| S ApproximateExpectedCount(S p, int sampled_size, int counter) { | |||
| S ApproximateExpectedCount(S p, int64_t sampled_size, int64_t counter) { | |||
| if (sampled_size == counter) return p * sampled_size; | |||
| return -std::expm1(counter * std::log1p(-p)); | |||
| } | |||
| private: | |||
| int num_true_; | |||
| int num_sampled_; | |||
| int64_t num_true_; | |||
| int64_t num_sampled_; | |||
| bool unique_; | |||
| int range_max_; | |||
| int64_t range_max_; | |||
| size_t input_size_; | |||
| bool remove_accidental_hits_; | |||
| std::vector<T> array_input_; | |||
| std::set<int> set_input_; | |||
| std::set<T> set_input_; | |||
| std::default_random_engine generator_; | |||
| std::vector<int> sampled_candidates_; | |||
| std::vector<T> sampled_candidates_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| @@ -583,7 +583,9 @@ class UniformCandidateSampler(PrimitiveWithInfer): | |||
| self.num_sampled = num_sampled | |||
| def infer_dtype(self, true_classes_type): | |||
| Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name) | |||
| Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) | |||
| Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, | |||
| (mstype.int32, mstype.int64), self.name) | |||
| return (true_classes_type, mstype.float32, mstype.float32) | |||
| def infer_shape(self, true_classes_shape): | |||
| @@ -39,6 +39,14 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max): | |||
| out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) | |||
| return out1.shape, out2.shape, out3.shape | |||
| def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max): | |||
| uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true, | |||
| num_sampled, | |||
| unique, | |||
| range_max) | |||
| out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int64))) | |||
| return out1.shape, out2.shape, out3.shape | |||
| class UniformCandidateSamplerHitNet(nn.Cell): | |||
| def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): | |||
| @@ -155,6 +163,19 @@ def test_uniform_candidate_sampler_large_random(): | |||
| np.testing.assert_array_equal(ms2, expected_2) | |||
| np.testing.assert_array_equal(ms3, expected_3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_uniform_candidate_sampler_large_random_int64_input(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63), | |||
| 63, 10, False, 12) | |||
| expected_1 = (10,) | |||
| expected_2 = (34, 63) | |||
| expected_3 = (10,) | |||
| np.testing.assert_array_equal(ms1, expected_1) | |||
| np.testing.assert_array_equal(ms2, expected_2) | |||
| np.testing.assert_array_equal(ms3, expected_3) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||