Browse Source

rename UniformSampler to UniformCandidateSampler

tags/v1.1.0
TFbunny 5 years ago
parent
commit
ee4e2db77e
8 changed files with 85 additions and 64 deletions
  1. +6
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu
  2. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh
  3. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc
  4. +12
    -12
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h
  5. +1
    -1
      mindspore/nn/loss/loss.py
  6. +2
    -2
      mindspore/ops/operations/__init__.py
  7. +3
    -3
      mindspore/ops/operations/nn_ops.py
  8. +53
    -32
      tests/st/ops/gpu/test_uniform_candidate_sampler_op.py

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu → mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu View File

@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */


#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh"


template <typename S> template <typename S>
__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) {
@@ -24,13 +24,13 @@ __global__ void AssignToOutput(const int size, const S prob_val, S *output_array
} }


template <typename S> template <typename S>
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream) {
void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream) {
AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count); AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count);
AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val, AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val,
sampled_expected_count); sampled_expected_count);
} }


template void CalUniformSampler<float>(const int true_size, const int num_sampled, const float prob_val,
float *true_expected_count, float *sampled_expected_count,
cudaStream_t cuda_stream);
template void CalUniformCandidateSampler<float>(const int true_size, const int num_sampled, const float prob_val,
float *true_expected_count, float *sampled_expected_count,
cudaStream_t cuda_stream);

mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh → mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh View File

@@ -14,13 +14,13 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"


template <typename S> template <typename S>
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream);
void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream);


#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_

mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc → mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc View File

@@ -14,16 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */


#include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h"
#include "backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h"


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_TWO(UniformSampler,
MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler,
KernelAttr() KernelAttr()
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
UniformSamplerGpuKernel, int, float)
UniformCandidateSamplerGpuKernel, int, float)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h → mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_


#include <cmath> #include <cmath>
#include <set> #include <set>
@@ -23,16 +23,16 @@
#include <random> #include <random>
#include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh"


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T, typename S> template <typename T, typename S>
class UniformSamplerGpuKernel : public GpuKernel {
class UniformCandidateSamplerGpuKernel : public GpuKernel {
public: public:
UniformSamplerGpuKernel()
UniformCandidateSamplerGpuKernel()
: num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {} : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {}
~UniformSamplerGpuKernel() override = default;
~UniformCandidateSamplerGpuKernel() override = default;


const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -61,20 +61,20 @@ class UniformSamplerGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync sampled_candidates failed"); "cudaMemcpyAsync sampled_candidates failed");
CalUniformSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalUniformCandidateSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count,
sampled_expected_count, reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }


bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input.";
MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformCandidateSampler needs 1 input.";
return false; return false;
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 3) { if (output_num != 3) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs.";
MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformCandidateSampler has 3 outputs.";
return false; return false;
} }
// getting attrs // getting attrs
@@ -88,7 +88,7 @@ class UniformSamplerGpuKernel : public GpuKernel {
generator_.seed(seed); generator_.seed(seed);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape.size() != 2) { if (input_shape.size() != 2) {
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs.";
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformCandidateSampler supports only 2-D inputs.";
return false; return false;
} }
input_size_ = input_shape[0] * input_shape[1]; input_size_ = input_shape[0] * input_shape[1];
@@ -160,4 +160,4 @@ class UniformSamplerGpuKernel : public GpuKernel {
}; };
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_

+ 1
- 1
mindspore/nn/loss/loss.py View File

@@ -303,7 +303,7 @@ class SampledSoftmaxLoss(_Loss):
self.sampled_values = sampled_values self.sampled_values = sampled_values
self.remove_accidental_hits = remove_accidental_hits self.remove_accidental_hits = remove_accidental_hits
self.seed = seed self.seed = seed
self.sampler = P.UniformSampler(
self.sampler = P.UniformCandidateSampler(
num_true, num_true,
num_sampled, num_sampled,
True, True,


+ 2
- 2
mindspore/ops/operations/__init__.py View File

@@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
FusedSparseFtrl, FusedSparseProximalAdagrad, FusedSparseFtrl, FusedSparseProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler)
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler)
from . import _quant_ops from . import _quant_ops
from ._quant_ops import * from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
@@ -375,7 +375,7 @@ __all__ = [
"ApproximateEqual", "ApproximateEqual",
"InplaceUpdate", "InplaceUpdate",
"InTopK", "InTopK",
"UniformSampler",
"UniformCandidateSampler",
"LRN", "LRN",
"Mod", "Mod",
"PopulationCount", "PopulationCount",


+ 3
- 3
mindspore/ops/operations/nn_ops.py View File

@@ -5820,7 +5820,7 @@ class LRN(PrimitiveWithInfer):
return x_shape return x_shape




class UniformSampler(PrimitiveWithInfer):
class UniformCandidateSampler(PrimitiveWithInfer):
r""" r"""
Uniform candidate sampler. Uniform candidate sampler.


@@ -5848,14 +5848,14 @@ class UniformSampler(PrimitiveWithInfer):
sampled_candidates. Shape: (num_sampled, ). sampled_candidates. Shape: (num_sampled, ).


Examples: Examples:
>>> sampler = P.UniformSampler(1, 3, False, 4)
>>> sampler = P.UniformCandidateSampler(1, 3, False, 4)
>>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6],
[3]], dtype=np.int32))) [3]], dtype=np.int32)))
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
""" """
@prim_attr_register @prim_attr_register
def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
"""Initialize UniformSampler"""
"""Initialize UniformCandidateSampler"""
validator.check_value_type("num_true", num_true, [int], self.name) validator.check_value_type("num_true", num_true, [int], self.name)
validator.check_value_type("num_sampled", num_sampled, [int], self.name) validator.check_value_type("num_sampled", num_sampled, [int], self.name)
validator.check_value_type("unique", unique, [bool], self.name) validator.check_value_type("unique", unique, [bool], self.name)


tests/st/ops/gpu/test_uniform_sampler_op.py → tests/st/ops/gpu/test_uniform_candidate_sampler_op.py View File

@@ -21,45 +21,55 @@ from mindspore.ops import operations as P
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context


class UniformSamplerNet(nn.Cell):
class UniformCandidateSamplerNet(nn.Cell):
def __init__(self, num_true, num_sampled, unique, range_max): def __init__(self, num_true, num_sampled, unique, range_max):
super(UniformSamplerNet, self).__init__()
self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max)
super(UniformCandidateSamplerNet, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled,
unique, range_max)


def construct(self, x): def construct(self, x):
return self.sampler(x) return self.sampler(x)




def uniform_sampler(x, num_true, num_sampled, unique, range_max):
uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max)
out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max):
uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true,
num_sampled,
unique,
range_max)
out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32)))
return out1.shape, out2.shape, out3.shape return out1.shape, out2.shape, out3.shape




class UniformSamplerHitNet(nn.Cell):
class UniformCandidateSamplerHitNet(nn.Cell):
def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits):
super(UniformSamplerHitNet, self).__init__()
self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed,
remove_accidental_hits=remove_accidental_hits)
super(UniformCandidateSamplerHitNet, self).__init__()
self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique,
range_max, seed=seed,
remove_accidental_hits=remove_accidental_hits)


def construct(self, x): def construct(self, x):
return self.sampler(x) return self.sampler(x)




def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed,
remove_accidental_hits):
uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max,
seed, remove_accidental_hits)
out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
def uniform_candidate_sampler_hit(x, num_true, num_sampled, unique, range_max, seed,
remove_accidental_hits):
uniform_candidate_sampler_net = UniformCandidateSamplerHitNet(num_true,
num_sampled,
unique,
range_max,
seed,
remove_accidental_hits)
out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32)))
return out1, out2, out3 return out1, out2, out3




@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_unique_1_true():
def test_uniform_candidate_sampler_unique_1_true():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4)
ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]),
1, 3, True, 4)
expected_1 = (3,) expected_1 = (3,)
expected_2 = (5, 1) expected_2 = (5, 1)
expected_3 = (3,) expected_3 = (3,)
@@ -70,9 +80,10 @@ def test_uniform_sampler_unique_1_true():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_not_unique_1_true():
def test_uniform_candidate_sampler_not_unique_1_true():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4)
ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]),
1, 3, False, 4)
expected_1 = (3,) expected_1 = (3,)
expected_2 = (5, 1) expected_2 = (5, 1)
expected_3 = (3,) expected_3 = (3,)
@@ -83,9 +94,11 @@ def test_uniform_sampler_not_unique_1_true():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_unique_2_true():
def test_uniform_candidate_sampler_unique_2_true():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4)
ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2], [4, 2],
[6, 2], [3, 2]]),
2, 3, True, 4)
expected_1 = (3,) expected_1 = (3,)
expected_2 = (5, 2) expected_2 = (5, 2)
expected_3 = (3,) expected_3 = (3,)
@@ -96,9 +109,12 @@ def test_uniform_sampler_unique_2_true():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_not_unique_2_true():
def test_uniform_candidate_sampler_not_unique_2_true():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4)
ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2],
[4, 2], [6, 2],
[3, 2]]),
2, 3, False, 4)
expected_1 = (3,) expected_1 = (3,)
expected_2 = (5, 2) expected_2 = (5, 2)
expected_3 = (3,) expected_3 = (3,)
@@ -109,10 +125,14 @@ def test_uniform_sampler_not_unique_2_true():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_large():
def test_uniform_candidate_sampler_large():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252],
[65125, 225125], [35125, 5125122]]), 2, 5, False, 100)
ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[12221, 41414],
[3312, 5125152],
[3312454, 51252],
[65125, 225125],
[35125, 5125122]]),
2, 5, False, 100)
expected_1 = (5,) expected_1 = (5,)
expected_2 = (5, 2) expected_2 = (5, 2)
expected_3 = (5,) expected_3 = (5,)
@@ -124,9 +144,10 @@ def test_uniform_sampler_large():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_large_random():
def test_uniform_candidate_sampler_large_random():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12)
ms1, ms2, ms3 = uniform_candidate_sampler(np.arange(2142).reshape(34, 63),
63, 10, False, 12)
expected_1 = (10,) expected_1 = (10,)
expected_2 = (34, 63) expected_2 = (34, 63)
expected_3 = (10,) expected_3 = (10,)
@@ -138,9 +159,9 @@ def test_uniform_sampler_large_random():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_unique_1_true_hit():
def test_uniform_candidate_sampler_unique_1_true_hit():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False)
ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False)
expected_1 = np.array([0, 3, 1]) expected_1 = np.array([0, 3, 1])
np.testing.assert_array_equal(ms1.asnumpy(), expected_1) np.testing.assert_array_equal(ms1.asnumpy(), expected_1)


@@ -148,8 +169,8 @@ def test_uniform_sampler_unique_1_true_hit():
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_uniform_sampler_unique_1_true_no_hit():
def test_uniform_candidate_sampler_unique_1_true_no_hit():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True)
ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True)
expected_1 = np.array([0, 3, 2]) expected_1 = np.array([0, 3, 2])
np.testing.assert_array_equal(ms1.asnumpy(), expected_1) np.testing.assert_array_equal(ms1.asnumpy(), expected_1)

Loading…
Cancel
Save