| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" | |||||
| template <typename S> | |||||
| __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||||
| output_array[pos] = prob_val; | |||||
| } | |||||
| } | |||||
| template <typename S> | |||||
| void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream) { | |||||
| AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count); | |||||
| AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val, | |||||
| sampled_expected_count); | |||||
| } | |||||
| template void CalUniformSampler<float>(const int true_size, const int num_sampled, const float prob_val, | |||||
| float *true_expected_count, float *sampled_expected_count, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| #include <cuda_runtime.h> | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename S> | |||||
| void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO(UniformSampler, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| UniformSamplerGpuKernel, int, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,144 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| #include <cmath> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <random> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class UniformSamplerGpuKernel : public GpuKernel { | |||||
| public: | |||||
| UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {} | |||||
| ~UniformSamplerGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| VARIABLE_NOT_USED(workspaces); | |||||
| T *sampled_candidates = GetDeviceAddress<T>(outputs, 0); | |||||
| S *true_expected_count = GetDeviceAddress<S>(outputs, 1); | |||||
| S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2); | |||||
| int counter = Sampling(); | |||||
| float prob = Probability(); | |||||
| size_t sampled_candidates_size = num_sampled_ * sizeof(T); | |||||
| S value = ApproximateExpectedCount(prob, num_sampled_, counter); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cudaMemcpyAsync sampled_candidates failed"); | |||||
| CalUniformSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 1) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 3) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs."; | |||||
| return false; | |||||
| } | |||||
| // getting attrs | |||||
| num_true_ = GetAttr<int>(kernel_node, "num_true"); | |||||
| num_sampled_ = GetAttr<int>(kernel_node, "num_sampled"); | |||||
| unique_ = GetAttr<bool>(kernel_node, "unique"); | |||||
| range_max_ = GetAttr<int>(kernel_node, "range_max"); | |||||
| int seed = GetAttr<int>(kernel_node, "seed"); | |||||
| if (seed == 0) seed = time(NULL); | |||||
| generator_.seed(seed); | |||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||||
| if (input_shape.size() != 2) { | |||||
| MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs."; | |||||
| return false; | |||||
| } | |||||
| input_size_ = input_shape[0] * input_shape[1]; | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||||
| output_size_list_.push_back(num_sampled_ * sizeof(T)); | |||||
| output_size_list_.push_back(input_size_ * sizeof(S)); | |||||
| output_size_list_.push_back(num_sampled_ * sizeof(S)); | |||||
| } | |||||
| int Sampling() { | |||||
| int counter = 0; | |||||
| int tmp; | |||||
| int picked; | |||||
| std::set<int> set_container; | |||||
| // pick between [0, range_max_-1] | |||||
| std::uniform_int_distribution<int> distribution(0, range_max_ - 1); | |||||
| sampled_candidates_.clear(); | |||||
| if (unique_) { | |||||
| picked = 0; | |||||
| while (picked < num_sampled_) { | |||||
| tmp = distribution(generator_); | |||||
| counter++; | |||||
| if (set_container.find(tmp) == set_container.end()) { | |||||
| set_container.insert(tmp); | |||||
| sampled_candidates_.push_back(tmp); | |||||
| picked++; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < num_sampled_; i++) { | |||||
| sampled_candidates_.push_back(distribution(generator_)); | |||||
| } | |||||
| counter = num_sampled_; | |||||
| } | |||||
| return counter; | |||||
| } | |||||
| S Probability() { return static_cast<S>(1.0f / range_max_); } | |||||
| S ApproximateExpectedCount(S p, int sampled_size, int counter) { | |||||
| if (sampled_size == counter) return p * sampled_size; | |||||
| return -std::expm1(counter * std::log1p(-p)); | |||||
| } | |||||
| private: | |||||
| int num_true_; | |||||
| int num_sampled_; | |||||
| bool unique_; | |||||
| int range_max_; | |||||
| size_t input_size_; | |||||
| std::default_random_engine generator_; | |||||
| std::vector<int> sampled_candidates_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| @@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl | |||||
| FusedSparseFtrl, FusedSparseProximalAdagrad, | FusedSparseFtrl, FusedSparseProximalAdagrad, | ||||
| ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ||||
| ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) | |||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler) | |||||
| from . import _quant_ops | from . import _quant_ops | ||||
| from ._quant_ops import * | from ._quant_ops import * | ||||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, | from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, | ||||
| @@ -373,6 +373,7 @@ __all__ = [ | |||||
| "ApproximateEqual", | "ApproximateEqual", | ||||
| "InplaceUpdate", | "InplaceUpdate", | ||||
| "InTopK", | "InTopK", | ||||
| "UniformSampler", | |||||
| "LRN", | "LRN", | ||||
| "Mod", | "Mod", | ||||
| "PopulationCount", | "PopulationCount", | ||||
| @@ -5730,3 +5730,56 @@ class LRN(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) | validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) | ||||
| return x_shape | return x_shape | ||||
| class UniformSampler(PrimitiveWithInfer): | |||||
| r""" | |||||
| Uniform candidate sampler. | |||||
| This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution. | |||||
| If unique=True, candidates are drawn without replacement, else unique=False with replacement. | |||||
| Args: | |||||
| num_true (int): The number of target classes in each training example. | |||||
| num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape | |||||
| of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. | |||||
| unique (bool): Whether all sampled classes in a batch are unique. | |||||
| range_max (int): The number of possible classes. | |||||
| seed (int): Random seed, must be non-negative. Default: 0. | |||||
| Inputs: | |||||
| true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true). | |||||
| Outputs: | |||||
| A tuple of 3 tensors. | |||||
| sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ). | |||||
| true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes. | |||||
| Shape: (batch_size, num_true). | |||||
| sampled_expected_count: (float): The expected counts under the sampling distribution of each of | |||||
| sampled_candidates. Shape: (num_sampled, ). | |||||
| Examples: | |||||
| >>> sampler = P.UniformSampler(1, 3, False, 4) | |||||
| >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], | |||||
| [3]], dtype=np.int32))) | |||||
| [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, num_true, num_sampled, unique, range_max, seed=0): | |||||
| """Initialize UniformSampler""" | |||||
| validator.check_value_type("num_true", num_true, [int], self.name) | |||||
| validator.check_value_type("num_sampled", num_sampled, [int], self.name) | |||||
| validator.check_value_type("unique", unique, [bool], self.name) | |||||
| validator.check_value_type("range_max", range_max, [int], self.name) | |||||
| validator.check_value_type("seed", seed, [int], self.name) | |||||
| validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name) | |||||
| if unique: | |||||
| validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name) | |||||
| validator.check("value of seed", seed, '', 0, Rel.GE, self.name) | |||||
| self.num_sampled = num_sampled | |||||
| def infer_dtype(self, true_classes_type): | |||||
| return (true_classes_type, mstype.float32, mstype.float32) | |||||
| def infer_shape(self, true_classes_shape): | |||||
| return ([self.num_sampled], true_classes_shape, [self.num_sampled]) | |||||
| @@ -0,0 +1,116 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| class UniformSamplerNet(nn.Cell): | |||||
| def __init__(self, num_true, num_sampled, unique, range_max): | |||||
| super(UniformSamplerNet, self).__init__() | |||||
| self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max) | |||||
| def construct(self, x): | |||||
| return self.sampler(x) | |||||
| def uniform_sampler(x, num_true, num_sampled, unique, range_max): | |||||
| uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max) | |||||
| out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) | |||||
| return out1.shape, out2.shape, out3.shape | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_unique_1_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4) | |||||
| expected_1 = (3,) | |||||
| expected_2 = (5, 1) | |||||
| expected_3 = (3,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_not_unique_1_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4) | |||||
| expected_1 = (3,) | |||||
| expected_2 = (5, 1) | |||||
| expected_3 = (3,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_unique_2_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4) | |||||
| expected_1 = (3,) | |||||
| expected_2 = (5, 2) | |||||
| expected_3 = (3,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_not_unique_2_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4) | |||||
| expected_1 = (3,) | |||||
| expected_2 = (5, 2) | |||||
| expected_3 = (3,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_large(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252], | |||||
| [65125, 225125], [35125, 5125122]]), 2, 5, False, 100) | |||||
| expected_1 = (5,) | |||||
| expected_2 = (5, 2) | |||||
| expected_3 = (5,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_uniform_sampler_large_random(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12) | |||||
| expected_1 = (10,) | |||||
| expected_2 = (34, 63) | |||||
| expected_3 = (10,) | |||||
| np.testing.assert_array_equal(ms1, expected_1) | |||||
| np.testing.assert_array_equal(ms2, expected_2) | |||||
| np.testing.assert_array_equal(ms3, expected_3) | |||||