From: @yuan_shen_zhou Reviewed-by: @c_34,@liangchenghui Signed-off-by: @c_34,@liangchenghuitags/v1.1.0
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh" | |||||
| template <typename S> | |||||
| __global__ void RandomCategorical(int num_samples, double** dev_rand, double** dev_cdf, | |||||
| int batch_size, int num_classes, S *output_addr) { | |||||
| int size = num_samples * batch_size; | |||||
| for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) { | |||||
| int cur_row = pos / num_samples; | |||||
| int cur_col = pos % num_samples; | |||||
| const double to_find = dev_cdf[cur_row][num_classes-1] * dev_rand[cur_row][cur_col]; | |||||
| int idx = 0; | |||||
| while (dev_cdf[cur_row][idx] < to_find) { | |||||
| idx++; | |||||
| } | |||||
| output_addr[pos] = static_cast<S>(idx); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void GetCdf(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes) { | |||||
| int size = num_classes * batch_size; | |||||
| for (int pos= blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) { | |||||
| int cur_row = pos / num_classes; | |||||
| int cur_col = pos % num_classes; | |||||
| if (cur_col != 0) { | |||||
| return; | |||||
| } | |||||
| T max_of_row = logits_addr[pos]; | |||||
| for (int i = 1; i < num_classes; i++) { | |||||
| if (logits_addr[pos + i] > max_of_row) { | |||||
| max_of_row = logits_addr[pos + i]; | |||||
| } | |||||
| } | |||||
| dev_cdf[cur_row][0] = exp(static_cast<double>(logits_addr[pos] - max_of_row)); | |||||
| for (int i = 1; i < num_classes; i++) { | |||||
| double tmp = exp(static_cast<double>(logits_addr[pos + i] - max_of_row)); | |||||
| dev_cdf[cur_row][i] = dev_cdf[cur_row][i - 1] + tmp; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename S> | |||||
| void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, int batch_size, | |||||
| int num_classes, S *output_addr, cudaStream_t cuda_stream) { | |||||
| int size_out = num_samples * batch_size; | |||||
| RandomCategorical<<<GET_BLOCKS(size_out), GET_THREADS, 0, cuda_stream>>>(num_samples, dev_rand, | |||||
| dev_cdf, batch_size, | |||||
| num_classes, output_addr); | |||||
| } | |||||
| template <typename T> | |||||
| void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes, | |||||
| cudaStream_t cuda_stream) { | |||||
| int size_cdf = num_classes * batch_size; | |||||
| GetCdf<<<GET_BLOCKS(size_cdf), GET_THREADS, 0, cuda_stream>>>(logits_addr, dev_cdf, batch_size, num_classes); | |||||
| } | |||||
| template void GetCdfKernel<half>(const half *logits_addr, double** dev_cdf, const int batch_size, | |||||
| const int num_classes, cudaStream_t cuda_stream); | |||||
| template void GetCdfKernel<float>(const float *logits_addr, double** dev_cdf, const int batch_size, | |||||
| const int num_classes, cudaStream_t cuda_stream); | |||||
| template void GetCdfKernel<double>(const double *logits_addr, double** dev_cdf, const int batch_size, | |||||
| const int num_classes, cudaStream_t cuda_stream); | |||||
| template void RandomCategoricalKernel<int16_t>(int num_samples, | |||||
| double** dev_rand, double** dev_cdf, int batch_size, int num_classes, | |||||
| int16_t *output_addr, cudaStream_t cuda_stream); | |||||
| template void RandomCategoricalKernel<int>(int num_samples, | |||||
| double** dev_rand, double** dev_cdf, int batch_size, int num_classes, | |||||
| int *output_addr, cudaStream_t cuda_stream); | |||||
| template void RandomCategoricalKernel<int64_t>(int num_samples, | |||||
| double** dev_rand, double** dev_cdf, int batch_size, int num_classes, | |||||
| int64_t *output_addr, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes, | |||||
| cudaStream_t cuda_stream); | |||||
| template <typename S> | |||||
| void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, | |||||
| int batch_size, int num_classes, S *output_addr, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_ | |||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/random/random_categorical_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt16), | |||||
| RandomCategoricalGpuKernel, half, int16_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| RandomCategoricalGpuKernel, half, int32_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat16) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt64), | |||||
| RandomCategoricalGpuKernel, half, int64_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt16), | |||||
| RandomCategoricalGpuKernel, float, int16_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| RandomCategoricalGpuKernel, float, int32_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt64), | |||||
| RandomCategoricalGpuKernel, float, int64_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt16), | |||||
| RandomCategoricalGpuKernel, double, int16_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt32), | |||||
| RandomCategoricalGpuKernel, double, int32_t) | |||||
| MS_REG_GPU_KERNEL_TWO(RandomCategorical, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddInputAttr(kNumberTypeInt32) | |||||
| .AddOutputAttr(kNumberTypeInt64), | |||||
| RandomCategoricalGpuKernel, double, int64_t) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,141 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <random> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class RandomCategoricalGpuKernel : public GpuKernel { | |||||
| public: | |||||
| RandomCategoricalGpuKernel() : batch_size_(0), num_classes_(0), num_samples_(0), seed_(0) {} | |||||
| ~RandomCategoricalGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *logits_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| S *output_addr = GetDeviceAddress<S>(outputs, 0); | |||||
| std::unique_ptr<double *[]> host_cdf; | |||||
| host_cdf = std::make_unique<double *[]>(batch_size_); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| host_cdf[i] = GetDeviceAddress<double>(workspaces, i); | |||||
| } | |||||
| double **dev_cdf = GetDeviceAddress<double *>(workspaces, batch_size_); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_cdf, host_cdf.get(), sizeof(double *) * batch_size_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "Random_categorica cudaMemcpyAsync dev_cdf failed"); | |||||
| std::unique_ptr<double *[]> host_rand; | |||||
| host_rand = std::make_unique<double *[]>(batch_size_); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| host_rand[i] = GetDeviceAddress<double>(workspaces, batch_size_ + 1 + i); | |||||
| } | |||||
| double **dev_rand = GetDeviceAddress<double *>(workspaces, batch_size_ * 2 + 1); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| double *host_1d_rand = new double[num_samples_]; | |||||
| std::default_random_engine rng(seed_); | |||||
| std::uniform_real_distribution<> dist(0, 1); | |||||
| for (int j = 0; j < num_samples_; j++) { | |||||
| host_1d_rand[j] = dist(rng); | |||||
| } | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(host_rand[i], host_1d_rand, sizeof(double) * num_samples_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "Random_categorica cudaMemcpyAsync host_1d_rand failed"); | |||||
| delete[] host_1d_rand; | |||||
| } | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_rand, host_rand.get(), sizeof(double *) * batch_size_, | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "Random_categorica cudaMemcpyAsync dev_rand failed"); | |||||
| GetCdfKernel(logits_addr, dev_cdf, batch_size_, num_classes_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| RandomCategoricalKernel(num_samples_, dev_rand, dev_cdf, batch_size_, num_classes_, output_addr, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 3) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomCategorical needs 3 inputs."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomCategorical has 1 output."; | |||||
| return false; | |||||
| } | |||||
| auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (logits_shape.size() != 2) { | |||||
| MS_LOG(ERROR) << "logits's dims is " << logits_shape.size() << ", but it should be only 2-D."; | |||||
| return false; | |||||
| } | |||||
| batch_size_ = SizeToInt(logits_shape[0]); | |||||
| num_classes_ = SizeToInt(logits_shape[1]); | |||||
| num_samples_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_samples")); | |||||
| seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed")); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitResource() override {} | |||||
| void InitSizeLists() override { | |||||
| // init memory | |||||
| input_size_list_.push_back(sizeof(T) * batch_size_ * num_classes_); | |||||
| input_size_list_.push_back(sizeof(int) * 2); | |||||
| output_size_list_.push_back(sizeof(S) * batch_size_ * num_samples_); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| workspace_size_list_.push_back(sizeof(double) * num_classes_); | |||||
| } | |||||
| workspace_size_list_.push_back(sizeof(double *) * batch_size_); | |||||
| for (int i = 0; i < batch_size_; i++) { | |||||
| workspace_size_list_.push_back(sizeof(double) * num_samples_); | |||||
| } | |||||
| workspace_size_list_.push_back(sizeof(double *) * batch_size_); | |||||
| } | |||||
| private: | |||||
| int batch_size_; | |||||
| int num_classes_; | |||||
| int num_samples_; | |||||
| int seed_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CATEGORICAL_GPU_KERNEL_H_ | |||||
| @@ -430,6 +430,8 @@ class RandomCategorical(PrimitiveWithInfer): | |||||
| raise ValueError("RandomCategorical shape should be 2-dimension.") | raise ValueError("RandomCategorical shape should be 2-dimension.") | ||||
| ndim = len(x_shape) - 1 | ndim = len(x_shape) - 1 | ||||
| x_shape[ndim] = num_samples_v | x_shape[ndim] = num_samples_v | ||||
| self.add_prim_attr('num_samples', num_samples_v) | |||||
| self.add_prim_attr('seed', seed_v) | |||||
| return {'shape': (x_shape), | return {'shape': (x_shape), | ||||
| 'dtype': (self.dtype), | 'dtype': (self.dtype), | ||||
| 'value': None} | 'value': None} | ||||
| @@ -0,0 +1,180 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| import mindspore as ms | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class RCnet(nn.Cell): | |||||
| def __init__(self, num_sample, seed=0, dtype=ms.int64): | |||||
| super(RCnet, self).__init__() | |||||
| self.rc = P.RandomCategorical(dtype) | |||||
| self.num_sample = num_sample | |||||
| self.seed = seed | |||||
| def construct(self, logits): | |||||
| return self.rc(logits, self.num_sample, self.seed) | |||||
| TARGET = "GPU" | |||||
| def test_rc_graph_fp16_int64(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| random_cateogoric = RCnet(num_sample, seed, dtype) | |||||
| output = random_cateogoric(x) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_graph_fp32_int64(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| random_cateogoric = RCnet(num_sample, seed, dtype) | |||||
| output = random_cateogoric(x) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_graph_fp64_int64(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| random_cateogoric = RCnet(num_sample, seed, dtype) | |||||
| output = random_cateogoric(x) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_graph_fp16_int16(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int16 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16) | |||||
| random_cateogoric = RCnet(num_sample, seed, dtype) | |||||
| output = random_cateogoric(x) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_graph_fp16_int32(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int32 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32) | |||||
| random_cateogoric = RCnet(num_sample, seed, dtype) | |||||
| output = random_cateogoric(x) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_pynative_fp16_int64(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| output = P.RandomCategorical(dtype)(x, num_sample, seed) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_pynative_fp32_int64(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float32) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| output = P.RandomCategorical(dtype)(x, num_sample, seed) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_pynative_fp64_int64(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float64) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int64 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int64) | |||||
| output = P.RandomCategorical(dtype)(x, num_sample, seed) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_pynative_fp16_int16(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int16 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int16) | |||||
| output = P.RandomCategorical(dtype)(x, num_sample, seed) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||
| def test_rc_pynative_fp16_int32(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=TARGET) | |||||
| x = Tensor(np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), ms.float16) | |||||
| num_sample = 10 | |||||
| seed = 5 | |||||
| dtype = ms.int32 | |||||
| expect = np.array([[4, 3, 2, 4, 4, 4, 3, 4, 1, 3], [4, 3, 2, 4, 4, 4, 3, 4, 1, 3]], dtype=np.int32) | |||||
| output = P.RandomCategorical(dtype)(x, num_sample, seed) | |||||
| diff = output.asnumpy() - expect | |||||
| assert expect.dtype == output.asnumpy().dtype | |||||
| assert np.all(diff == 0) | |||||