diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cu new file mode 100644 index 0000000000..5d77b453a9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cu @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh" + +template +__global__ void RandomCategorical(int num_samples, double** dev_rand, double** dev_cdf, + int batch_size, int num_classes, S *output_addr) { + int size = num_samples * batch_size; + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) { + int cur_row = pos / num_samples; + int cur_col = pos % num_samples; + const double to_find = dev_cdf[cur_row][num_classes-1] * dev_rand[cur_row][cur_col]; + + int idx = 0; + while (dev_cdf[cur_row][idx] < to_find) { + idx++; + } + output_addr[pos] = static_cast(idx); + } +} + +template +__global__ void GetCdf(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes) { + int size = num_classes * batch_size; + for (int pos= blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) { + int cur_row = pos / num_classes; + int cur_col = pos % num_classes; + if (cur_col != 0) { + return; + } + T max_of_row = logits_addr[pos]; + for (int i = 1; i < num_classes; i++) { + if (logits_addr[pos + i] > max_of_row) { + max_of_row = logits_addr[pos + i]; + } + } + dev_cdf[cur_row][0] = exp(static_cast(logits_addr[pos] - max_of_row)); + for (int i = 1; i < num_classes; i++) { + double tmp = exp(static_cast(logits_addr[pos + i] - max_of_row)); + dev_cdf[cur_row][i] = dev_cdf[cur_row][i - 1] + tmp; + } + } +} + +template +void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, int batch_size, + int num_classes, S *output_addr, cudaStream_t cuda_stream) { + int size_out = num_samples * batch_size; + RandomCategorical<<>>(num_samples, dev_rand, + dev_cdf, batch_size, + num_classes, output_addr); +} + +template +void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes, + cudaStream_t cuda_stream) { + int size_cdf = num_classes * batch_size; + GetCdf<<>>(logits_addr, dev_cdf, batch_size, num_classes); +} + +template void GetCdfKernel(const half *logits_addr, double** dev_cdf, const int batch_size, + const int num_classes, cudaStream_t cuda_stream); +template void GetCdfKernel(const float *logits_addr, double** dev_cdf, const int batch_size, + const int num_classes, cudaStream_t cuda_stream); +template void GetCdfKernel(const double *logits_addr, double** dev_cdf, const int batch_size, + const int num_classes, cudaStream_t cuda_stream); + +template void RandomCategoricalKernel(int num_samples, + double** dev_rand, double** dev_cdf, int batch_size, int num_classes, + int16_t *output_addr, cudaStream_t cuda_stream); +template void RandomCategoricalKernel(int num_samples, + double** dev_rand, double** dev_cdf, int batch_size, int num_classes, + int *output_addr, cudaStream_t cuda_stream); +template void RandomCategoricalKernel(int num_samples, + double** dev_rand, double** dev_cdf, int batch_size, int num_classes, + int64_t *output_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh new file mode 100644 index 0000000000..2cc5b7f952 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes, + cudaStream_t cuda_stream); +template +void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, + int batch_size, int num_classes, S *output_addr, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.cc new file mode 100644 index 0000000000..267e750547 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + RandomCategoricalGpuKernel, half, int16_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + RandomCategoricalGpuKernel, half, int32_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + RandomCategoricalGpuKernel, half, int64_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + RandomCategoricalGpuKernel, float, int16_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + RandomCategoricalGpuKernel, float, int32_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + RandomCategoricalGpuKernel, float, int64_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + RandomCategoricalGpuKernel, double, int16_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + RandomCategoricalGpuKernel, double, int32_t) +MS_REG_GPU_KERNEL_TWO(RandomCategorical, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + RandomCategoricalGpuKernel, double, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h new file mode 100644 index 0000000000..18a7ce16f7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh" + +namespace mindspore { +namespace kernel { +template +class RandomCategoricalGpuKernel : public GpuKernel { + public: + RandomCategoricalGpuKernel() : batch_size_(0), num_classes_(0), num_samples_(0), seed_(0) {} + ~RandomCategoricalGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + T *logits_addr = GetDeviceAddress(inputs, 0); + S *output_addr = GetDeviceAddress(outputs, 0); + + std::unique_ptr host_cdf; + host_cdf = std::make_unique(batch_size_); + for (int i = 0; i < batch_size_; i++) { + host_cdf[i] = GetDeviceAddress(workspaces, i); + } + double **dev_cdf = GetDeviceAddress(workspaces, batch_size_); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_cdf, host_cdf.get(), sizeof(double *) * batch_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Random_categorica cudaMemcpyAsync dev_cdf failed"); + + std::unique_ptr host_rand; + host_rand = std::make_unique(batch_size_); + for (int i = 0; i < batch_size_; i++) { + host_rand[i] = GetDeviceAddress(workspaces, batch_size_ + 1 + i); + } + + double **dev_rand = GetDeviceAddress(workspaces, batch_size_ * 2 + 1); + for (int i = 0; i < batch_size_; i++) { + double *host_1d_rand = new double[num_samples_]; + std::default_random_engine rng(seed_); + std::uniform_real_distribution<> dist(0, 1); + for (int j = 0; j < num_samples_; j++) { + host_1d_rand[j] = dist(rng); + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(host_rand[i], host_1d_rand, sizeof(double) * num_samples_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Random_categorica cudaMemcpyAsync host_1d_rand failed"); + delete[] host_1d_rand; + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_rand, host_rand.get(), sizeof(double *) * batch_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Random_categorica cudaMemcpyAsync dev_rand failed"); + + GetCdfKernel(logits_addr, dev_cdf, batch_size_, num_classes_, reinterpret_cast(stream_ptr)); + RandomCategoricalKernel(num_samples_, dev_rand, dev_cdf, batch_size_, num_classes_, output_addr, + reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomCategorical needs 3 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomCategorical has 1 output."; + return false; + } + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (logits_shape.size() != 2) { + MS_LOG(ERROR) << "logits's dims is " << logits_shape.size() << ", but it should be only 2-D."; + return false; + } + batch_size_ = SizeToInt(logits_shape[0]); + num_classes_ = SizeToInt(logits_shape[1]); + + num_samples_ = static_cast(GetAttr(kernel_node, "num_samples")); + seed_ = static_cast(GetAttr(kernel_node, "seed")); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override {} + void InitSizeLists() override { + // init memory + input_size_list_.push_back(sizeof(T) * batch_size_ * num_classes_); + input_size_list_.push_back(sizeof(int) * 2); + + output_size_list_.push_back(sizeof(S) * batch_size_ * num_samples_); + + for (int i = 0; i < batch_size_; i++) { + workspace_size_list_.push_back(sizeof(double) * num_classes_); + } + workspace_size_list_.push_back(sizeof(double *) * batch_size_); + + for (int i = 0; i < batch_size_; i++) { + workspace_size_list_.push_back(sizeof(double) * num_samples_); + } + workspace_size_list_.push_back(sizeof(double *) * batch_size_); + } + + private: + int batch_size_; + int num_classes_; + int num_samples_; + int seed_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 4c628446bb..033aae1c4f 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -430,6 +430,8 @@ class RandomCategorical(PrimitiveWithInfer): raise ValueError("RandomCategorical shape should be 2-dimension.") ndim = len(x_shape) - 1 x_shape[ndim] = num_samples_v + self.add_prim_attr('num_samples', num_samples_v) + self.add_prim_attr('seed', seed_v) return {'shape': (x_shape), 'dtype': (self.dtype), 'value': None} diff --git a/tests/st/ops/gpu/test_random_categorical_op.py b/tests/st/ops/gpu/test_random_categorical_op.py new file mode 100644 index 0000000000..58710baa6d --- /dev/null +++ b/tests/st/ops/gpu/test_random_categorical_op.py @@ -0,0 +1,180 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor +from mindspore.ops import operations as P + + +class RCnet(nn.Cell): + def __init__(self, num_sample, seed=0, dtype=ms.int64): + super(RCnet, self).__init__() + self.rc = P.RandomCategorical(dtype) + self.num_sample = num_sample + self.seed = seed + + def construct(self, logits): + return self.rc(logits, self.num_sample, self.seed) + +TARGET = "GPU" + +def test_rc_graph_fp16_int64(): + context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + random_cateogoric = RCnet(num_sample, seed, dtype) + output = random_cateogoric(x) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_graph_fp32_int64(): + context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + random_cateogoric = RCnet(num_sample, seed, dtype) + output = random_cateogoric(x) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_graph_fp64_int64(): + context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + random_cateogoric = RCnet(num_sample, seed, dtype) + output = random_cateogoric(x) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_graph_fp16_int16(): + context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int16 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16) + + random_cateogoric = RCnet(num_sample, seed, dtype) + output = random_cateogoric(x) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_graph_fp16_int32(): + context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int32 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32) + + random_cateogoric = RCnet(num_sample, seed, dtype) + output = random_cateogoric(x) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_pynative_fp16_int64(): + context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + output = P.RandomCategorical(dtype)(x, num_sample, seed) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_pynative_fp32_int64(): + context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + output = P.RandomCategorical(dtype)(x, num_sample, seed) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_pynative_fp64_int64(): + context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64) + num_sample = 10 + seed = 5 + dtype = ms.int64 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) + + output = P.RandomCategorical(dtype)(x, num_sample, seed) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_pynative_fp16_int16(): + context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int16 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16) + + output = P.RandomCategorical(dtype)(x, num_sample, seed) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0) + +def test_rc_pynative_fp16_int32(): + context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) + + x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) + num_sample = 10 + seed = 5 + dtype = ms.int32 + expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32) + + output = P.RandomCategorical(dtype)(x, num_sample, seed) + diff = output.asnumpy() - expect + assert expect.dtype == output.asnumpy().dtype + assert np.all(diff == 0)