| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" | |||||
| template <typename S> | template <typename S> | ||||
| __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { | __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { | ||||
| @@ -24,13 +24,13 @@ __global__ void AssignToOutput(const int size, const S prob_val, S *output_array | |||||
| } | } | ||||
| template <typename S> | template <typename S> | ||||
| void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream) { | |||||
| void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream) { | |||||
| AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count); | AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count); | ||||
| AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val, | AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val, | ||||
| sampled_expected_count); | sampled_expected_count); | ||||
| } | } | ||||
| template void CalUniformSampler<float>(const int true_size, const int num_sampled, const float prob_val, | |||||
| float *true_expected_count, float *sampled_expected_count, | |||||
| cudaStream_t cuda_stream); | |||||
| template void CalUniformCandidateSampler<float>(const int true_size, const int num_sampled, const float prob_val, | |||||
| float *true_expected_count, float *sampled_expected_count, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -14,13 +14,13 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ | |||||
| #include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename S> | template <typename S> | ||||
| void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream); | |||||
| void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, | |||||
| S *sampled_expected_count, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ | |||||
| @@ -14,16 +14,16 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| MS_REG_GPU_KERNEL_TWO(UniformSampler, | |||||
| MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, | |||||
| KernelAttr() | KernelAttr() | ||||
| .AddInputAttr(kNumberTypeInt32) | .AddInputAttr(kNumberTypeInt32) | ||||
| .AddOutputAttr(kNumberTypeInt32) | .AddOutputAttr(kNumberTypeInt32) | ||||
| .AddOutputAttr(kNumberTypeFloat32) | .AddOutputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| UniformSamplerGpuKernel, int, float) | |||||
| UniformCandidateSamplerGpuKernel, int, float) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <set> | #include <set> | ||||
| @@ -23,16 +23,16 @@ | |||||
| #include <random> | #include <random> | ||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | #include "backend/kernel_compiler/gpu/gpu_kernel.h" | ||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| class UniformSamplerGpuKernel : public GpuKernel { | |||||
| class UniformCandidateSamplerGpuKernel : public GpuKernel { | |||||
| public: | public: | ||||
| UniformSamplerGpuKernel() | |||||
| UniformCandidateSamplerGpuKernel() | |||||
| : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {} | : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {} | ||||
| ~UniformSamplerGpuKernel() override = default; | |||||
| ~UniformCandidateSamplerGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | ||||
| @@ -61,20 +61,20 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, | CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, | ||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | ||||
| "cudaMemcpyAsync sampled_candidates failed"); | "cudaMemcpyAsync sampled_candidates failed"); | ||||
| CalUniformSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| CalUniformCandidateSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, | |||||
| sampled_expected_count, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input."; | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformCandidateSampler needs 1 input."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| if (output_num != 3) { | if (output_num != 3) { | ||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs."; | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformCandidateSampler has 3 outputs."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| // getting attrs | // getting attrs | ||||
| @@ -88,7 +88,7 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||||
| generator_.seed(seed); | generator_.seed(seed); | ||||
| auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | ||||
| if (input_shape.size() != 2) { | if (input_shape.size() != 2) { | ||||
| MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs."; | |||||
| MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformCandidateSampler supports only 2-D inputs."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| input_size_ = input_shape[0] * input_shape[1]; | input_size_ = input_shape[0] * input_shape[1]; | ||||
| @@ -160,4 +160,4 @@ class UniformSamplerGpuKernel : public GpuKernel { | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ | |||||
| @@ -303,7 +303,7 @@ class SampledSoftmaxLoss(_Loss): | |||||
| self.sampled_values = sampled_values | self.sampled_values = sampled_values | ||||
| self.remove_accidental_hits = remove_accidental_hits | self.remove_accidental_hits = remove_accidental_hits | ||||
| self.seed = seed | self.seed = seed | ||||
| self.sampler = P.UniformSampler( | |||||
| self.sampler = P.UniformCandidateSampler( | |||||
| num_true, | num_true, | ||||
| num_sampled, | num_sampled, | ||||
| True, | True, | ||||
| @@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl | |||||
| FusedSparseFtrl, FusedSparseProximalAdagrad, | FusedSparseFtrl, FusedSparseProximalAdagrad, | ||||
| ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, | ||||
| ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, | ||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler) | |||||
| ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler) | |||||
| from . import _quant_ops | from . import _quant_ops | ||||
| from ._quant_ops import * | from ._quant_ops import * | ||||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, | from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, | ||||
| @@ -375,7 +375,7 @@ __all__ = [ | |||||
| "ApproximateEqual", | "ApproximateEqual", | ||||
| "InplaceUpdate", | "InplaceUpdate", | ||||
| "InTopK", | "InTopK", | ||||
| "UniformSampler", | |||||
| "UniformCandidateSampler", | |||||
| "LRN", | "LRN", | ||||
| "Mod", | "Mod", | ||||
| "PopulationCount", | "PopulationCount", | ||||
| @@ -5820,7 +5820,7 @@ class LRN(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| class UniformSampler(PrimitiveWithInfer): | |||||
| class UniformCandidateSampler(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Uniform candidate sampler. | Uniform candidate sampler. | ||||
| @@ -5848,14 +5848,14 @@ class UniformSampler(PrimitiveWithInfer): | |||||
| sampled_candidates. Shape: (num_sampled, ). | sampled_candidates. Shape: (num_sampled, ). | ||||
| Examples: | Examples: | ||||
| >>> sampler = P.UniformSampler(1, 3, False, 4) | |||||
| >>> sampler = P.UniformCandidateSampler(1, 3, False, 4) | |||||
| >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], | >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], | ||||
| [3]], dtype=np.int32))) | [3]], dtype=np.int32))) | ||||
| [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] | [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): | def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): | ||||
| """Initialize UniformSampler""" | |||||
| """Initialize UniformCandidateSampler""" | |||||
| validator.check_value_type("num_true", num_true, [int], self.name) | validator.check_value_type("num_true", num_true, [int], self.name) | ||||
| validator.check_value_type("num_sampled", num_sampled, [int], self.name) | validator.check_value_type("num_sampled", num_sampled, [int], self.name) | ||||
| validator.check_value_type("unique", unique, [bool], self.name) | validator.check_value_type("unique", unique, [bool], self.name) | ||||
| @@ -21,45 +21,55 @@ from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| class UniformSamplerNet(nn.Cell): | |||||
| class UniformCandidateSamplerNet(nn.Cell): | |||||
| def __init__(self, num_true, num_sampled, unique, range_max): | def __init__(self, num_true, num_sampled, unique, range_max): | ||||
| super(UniformSamplerNet, self).__init__() | |||||
| self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max) | |||||
| super(UniformCandidateSamplerNet, self).__init__() | |||||
| self.sampler = P.UniformCandidateSampler(num_true, num_sampled, | |||||
| unique, range_max) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.sampler(x) | return self.sampler(x) | ||||
| def uniform_sampler(x, num_true, num_sampled, unique, range_max): | |||||
| uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max) | |||||
| out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) | |||||
| def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max): | |||||
| uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true, | |||||
| num_sampled, | |||||
| unique, | |||||
| range_max) | |||||
| out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) | |||||
| return out1.shape, out2.shape, out3.shape | return out1.shape, out2.shape, out3.shape | ||||
| class UniformSamplerHitNet(nn.Cell): | |||||
| class UniformCandidateSamplerHitNet(nn.Cell): | |||||
| def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): | def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): | ||||
| super(UniformSamplerHitNet, self).__init__() | |||||
| self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed, | |||||
| remove_accidental_hits=remove_accidental_hits) | |||||
| super(UniformCandidateSamplerHitNet, self).__init__() | |||||
| self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, | |||||
| range_max, seed=seed, | |||||
| remove_accidental_hits=remove_accidental_hits) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.sampler(x) | return self.sampler(x) | ||||
| def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed, | |||||
| remove_accidental_hits): | |||||
| uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max, | |||||
| seed, remove_accidental_hits) | |||||
| out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) | |||||
| def uniform_candidate_sampler_hit(x, num_true, num_sampled, unique, range_max, seed, | |||||
| remove_accidental_hits): | |||||
| uniform_candidate_sampler_net = UniformCandidateSamplerHitNet(num_true, | |||||
| num_sampled, | |||||
| unique, | |||||
| range_max, | |||||
| seed, | |||||
| remove_accidental_hits) | |||||
| out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) | |||||
| return out1, out2, out3 | return out1, out2, out3 | ||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_unique_1_true(): | |||||
| def test_uniform_candidate_sampler_unique_1_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]), | |||||
| 1, 3, True, 4) | |||||
| expected_1 = (3,) | expected_1 = (3,) | ||||
| expected_2 = (5, 1) | expected_2 = (5, 1) | ||||
| expected_3 = (3,) | expected_3 = (3,) | ||||
| @@ -70,9 +80,10 @@ def test_uniform_sampler_unique_1_true(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_not_unique_1_true(): | |||||
| def test_uniform_candidate_sampler_not_unique_1_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]), | |||||
| 1, 3, False, 4) | |||||
| expected_1 = (3,) | expected_1 = (3,) | ||||
| expected_2 = (5, 1) | expected_2 = (5, 1) | ||||
| expected_3 = (3,) | expected_3 = (3,) | ||||
| @@ -83,9 +94,11 @@ def test_uniform_sampler_not_unique_1_true(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_unique_2_true(): | |||||
| def test_uniform_candidate_sampler_unique_2_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2], [4, 2], | |||||
| [6, 2], [3, 2]]), | |||||
| 2, 3, True, 4) | |||||
| expected_1 = (3,) | expected_1 = (3,) | ||||
| expected_2 = (5, 2) | expected_2 = (5, 2) | ||||
| expected_3 = (3,) | expected_3 = (3,) | ||||
| @@ -96,9 +109,12 @@ def test_uniform_sampler_unique_2_true(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_not_unique_2_true(): | |||||
| def test_uniform_candidate_sampler_not_unique_2_true(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2], | |||||
| [4, 2], [6, 2], | |||||
| [3, 2]]), | |||||
| 2, 3, False, 4) | |||||
| expected_1 = (3,) | expected_1 = (3,) | ||||
| expected_2 = (5, 2) | expected_2 = (5, 2) | ||||
| expected_3 = (3,) | expected_3 = (3,) | ||||
| @@ -109,10 +125,14 @@ def test_uniform_sampler_not_unique_2_true(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_large(): | |||||
| def test_uniform_candidate_sampler_large(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252], | |||||
| [65125, 225125], [35125, 5125122]]), 2, 5, False, 100) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[12221, 41414], | |||||
| [3312, 5125152], | |||||
| [3312454, 51252], | |||||
| [65125, 225125], | |||||
| [35125, 5125122]]), | |||||
| 2, 5, False, 100) | |||||
| expected_1 = (5,) | expected_1 = (5,) | ||||
| expected_2 = (5, 2) | expected_2 = (5, 2) | ||||
| expected_3 = (5,) | expected_3 = (5,) | ||||
| @@ -124,9 +144,10 @@ def test_uniform_sampler_large(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_large_random(): | |||||
| def test_uniform_candidate_sampler_large_random(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12) | |||||
| ms1, ms2, ms3 = uniform_candidate_sampler(np.arange(2142).reshape(34, 63), | |||||
| 63, 10, False, 12) | |||||
| expected_1 = (10,) | expected_1 = (10,) | ||||
| expected_2 = (34, 63) | expected_2 = (34, 63) | ||||
| expected_3 = (10,) | expected_3 = (10,) | ||||
| @@ -138,9 +159,9 @@ def test_uniform_sampler_large_random(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_unique_1_true_hit(): | |||||
| def test_uniform_candidate_sampler_unique_1_true_hit(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False) | |||||
| ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False) | |||||
| expected_1 = np.array([0, 3, 1]) | expected_1 = np.array([0, 3, 1]) | ||||
| np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | ||||
| @@ -148,8 +169,8 @@ def test_uniform_sampler_unique_1_true_hit(): | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_uniform_sampler_unique_1_true_no_hit(): | |||||
| def test_uniform_candidate_sampler_unique_1_true_no_hit(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True) | |||||
| ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True) | |||||
| expected_1 = np.array([0, 3, 2]) | expected_1 = np.array([0, 3, 2]) | ||||
| np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | np.testing.assert_array_equal(ms1.asnumpy(), expected_1) | ||||