@@ -0,0 +1,265 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
#include <algorithm>
#include <cmath>  // std::ceil, used by the reduction launch loop below
int RcwmRoundUpPower2(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
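
// A worked example of the bit trick above (a sketch, not part of the build):
// for v = 5 (0b101) the cascading ORs smear the top set bit downward to give
// 0b111 = 7, and the final increment yields 8. The initial v-- keeps exact
// powers of two fixed (8 -> 8). Assumes v > 0 and a 32-bit int.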

template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}

// Pads the mask to the next power of two: padded and false entries get a
// sentinel rank of (ceil_power2 + 1) so the bitonic sort pushes them behind
// every valid index.
template <typename T, typename S>
__global__ void InitArray(const int input_size, const int ceil_power2, const T *input, S *mask_buff, S *rank_buff) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < ceil_power2; pos += blockDim.x * gridDim.x) {
    mask_buff[pos] = (pos < input_size) ? static_cast<S>(input[pos]) : 0;
    rank_buff[pos] = (pos < input_size && input[pos] != false) ? pos : (ceil_power2 + 1);
  }
}

template <size_t blockSize, typename T>
__device__ void WarpReduce(volatile T *sdata, size_t tid) {
  if (blockSize >= 64) sdata[tid] += sdata[tid + 32];
  if (blockSize >= 32) sdata[tid] += sdata[tid + 16];
  if (blockSize >= 16) sdata[tid] += sdata[tid + 8];
  if (blockSize >= 8) sdata[tid] += sdata[tid + 4];
  if (blockSize >= 4) sdata[tid] += sdata[tid + 2];
  if (blockSize >= 2) sdata[tid] += sdata[tid + 1];
}
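
// Note: this unrolled tail relies on `volatile` plus implicit warp-lockstep
// execution. On Volta and newer GPUs with independent thread scheduling, an
// explicitly synchronized variant (__syncwarp() between steps, or a
// __shfl_down_sync-based reduction) is the safer pattern.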

// Grid-stride accumulation into shared memory followed by a block-level tree
// reduction; each block emits one partial sum.
template <size_t blockSize, typename T>
__global__ void ReductionSum(T *g_idata, T *g_odata, size_t n) {
  __shared__ T sdata[blockSize];
  size_t tid = threadIdx.x;
  size_t i = blockIdx.x * (blockSize) + tid;
  size_t gridSize = blockSize * gridDim.x;
  sdata[tid] = 0;
  while (i < n) {
    sdata[tid] += g_idata[i];
    i += gridSize;
  }
  __syncthreads();
  if (blockSize >= 1024) {
    if (tid < 512) {
      sdata[tid] += sdata[tid + 512];
    }
    __syncthreads();
  }
  if (blockSize >= 512) {
    if (tid < 256) {
      sdata[tid] += sdata[tid + 256];
    }
    __syncthreads();
  }
  if (blockSize >= 256) {
    if (tid < 128) {
      sdata[tid] += sdata[tid + 128];
    }
    __syncthreads();
  }
  if (blockSize >= 128) {
    if (tid < 64) {
      sdata[tid] += sdata[tid + 64];
    }
    __syncthreads();
  }
  if (tid < 32) WarpReduce<blockSize>(sdata, tid);
  if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}
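
// Each launch leaves one partial per block in g_odata; the host-side loop in
// CalRandomChoiceWithMask below re-launches this kernel until a single total
// (the count of true elements) sits in Tnum_buff[0].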

template <typename T, typename S>
__global__ void Reshape2Index(const int input_size, const int input_shape_size, const int d1, const int d2,
                              const int d3, const int d4, const int d5, const T *input, S *output_index) {
  int pos_array[MAX_DIMENSION];
  int index_pos;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
    pos_array[0] = pos / (d2 * d3 * d4 * d5) % d1;
    pos_array[1] = pos / (d3 * d4 * d5) % d2;
    pos_array[2] = pos / (d4 * d5) % d3;
    pos_array[3] = pos / (d5) % d4;
    pos_array[4] = pos % d5;
    index_pos = pos * input_shape_size;
    if (input[pos] == false) {
      for (int i = 0; i < input_shape_size; i++) {
        output_index[index_pos++] = 0;
      }
    } else {
      for (int i = MAX_DIMENSION - input_shape_size; i < MAX_DIMENSION; i++) {
        output_index[index_pos++] = pos_array[i];
      }
    }
  }
}
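
// Example (a sketch): with a 3-D input padded to (d1..d5) = (1, 1, 3, 4, 5),
// flat pos = 37 decomposes to pos_array[2..4] = (1, 3, 2) because
// 37 = 1 * (4 * 5) + 3 * 5 + 2. Only the trailing input_shape_size
// coordinates are emitted, and false elements get all-zero coordinates.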

template <typename T>
__global__ void Copy(const T *src, T *dst, const int n) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < n; pos += blockDim.x * gridDim.x) {
    dst[pos] = src[pos];
  }
}

template <typename T>
__global__ void Sort(const int ceil_power2, T *rank_buff) {
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (rank_buff[tid] > rank_buff[tid_comp]) {
              Swap(&rank_buff[tid], &rank_buff[tid_comp]);
            }
          } else {
            if (rank_buff[tid] < rank_buff[tid_comp]) {
              Swap(&rank_buff[tid], &rank_buff[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
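
// Sort is a bitonic sorting network: outer stage i grows sorted runs of
// length i, and each inner step j compare-exchanges partners tid and tid^j.
// Caveat: __syncthreads() only orders threads within one block, so the
// network is strictly safe only when every compared pair falls inside the
// same block (e.g. a single-block or cooperative launch).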

__global__ void SrandInit(const int ceil_power2, curandState *globalState, const int seedc) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < ceil_power2; i += blockDim.x * gridDim.x) {
    curand_init(seedc, i, 0, &globalState[i]);
  }
}
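
// Every slot gets its own curandState: the shared seed keeps runs
// reproducible, while distinct subsequence ids decorrelate the per-thread
// random streams.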

template <typename T>
__global__ void Shuffle(const int ceil_power2, curandState *globalState, T *rank_buff) {
  int limit = ceil_power2 + 1;
  int value;
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < ceil_power2; tid += blockDim.x * gridDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          value = static_cast<int>(curand(&globalState[tid]));
          if (value & 1) {
            if (rank_buff[tid] != limit && rank_buff[tid_comp] != limit) {
              Swap(&rank_buff[tid], &rank_buff[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
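
// Shuffle reuses the bitonic network's pairing but swaps each pair with
// probability 1/2 (one random bit per comparison). Pairs touching the
// (ceil_power2 + 1) sentinel are skipped, so only the real indices packed
// into the sorted prefix get permuted.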

template <typename T, typename S>
__global__ void MoveToOutput(const int input_shape_size, const int count, const T *input, S *output_index,
                             T *output_mask, S *index_buff, S *rank_buff, S *Tnum_buff) {
  int Tnum = static_cast<int>(Tnum_buff[0]);
  int idx = 0;
  int pos;
  if (count <= Tnum) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
      idx = rank_buff[i];
      pos = i;
      output_mask[pos] = input[idx];
      pos *= input_shape_size;
      idx *= input_shape_size;
      for (size_t j = 0; j < input_shape_size; j++) {
        output_index[pos] = index_buff[idx];
        pos++;
        idx++;
      }
    }
  } else {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
      if (i < Tnum) {
        idx = rank_buff[i];
        pos = i;
        output_mask[pos] = input[idx];
        pos *= input_shape_size;
        idx *= input_shape_size;
        for (size_t j = 0; j < input_shape_size; j++) {
          output_index[pos] = index_buff[idx];
          pos++;
          idx++;
        }
      } else {
        pos = i;
        output_mask[pos] = static_cast<T>(0);
        pos *= input_shape_size;
        for (size_t j = 0; j < input_shape_size; j++) {
          output_index[pos] = static_cast<S>(0);
          pos++;
        }
      }
    }
  }
}
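
// Tnum_buff[0] holds the number of true elements. When count <= Tnum the
// first `count` shuffled ranks are gathered directly; otherwise entries past
// Tnum are zero-filled with a false mask, so the op always produces its
// fixed-shape (count, rank) and (count,) outputs.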

template <typename T, typename S>
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
                             const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
                             const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff,
                             S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream) {
  int ceil_power2 = RcwmRoundUpPower2(input_size);
  InitArray<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, ceil_power2, input, mask_buff, rank_buff);
  size_t BLOCKNUM;
  size_t n = ceil_power2;
  Copy<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(mask_buff, tmp_buff, ceil_power2);
  do {
    BLOCKNUM = std::ceil(static_cast<float>(n) / BLOCKSIZE);
    ReductionSum<BLOCKSIZE, S><<<BLOCKNUM, BLOCKSIZE, 0, stream>>>(tmp_buff, Tnum_buff, n);
    Copy<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(Tnum_buff, tmp_buff, BLOCKNUM);
    n = BLOCKNUM;
  } while (n > BLOCKSIZE);
  if (n > 1) ReductionSum<BLOCKSIZE, S><<<1, BLOCKSIZE, 0, stream>>>(Tnum_buff, Tnum_buff, n);
  Reshape2Index<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_shape_size, d1, d2, d3, d4, d5,
                                                                    input, index_buff);
  Sort<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, rank_buff);
  SrandInit<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, seedc);
  Shuffle<<<GET_BLOCKS(ceil_power2), GET_THREADS, 0, stream>>>(ceil_power2, globalState, rank_buff);
  MoveToOutput<<<GET_BLOCKS(count), GET_THREADS, 0, stream>>>(input_shape_size, count, input, output_index, output_mask,
                                                              index_buff, rank_buff, Tnum_buff);
}
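
// Pipeline summary: pad and flag (InitArray) -> count true elements
// (ReductionSum passes) -> per-element coordinates (Reshape2Index) -> pack
// valid ranks to the front (Sort) -> seed RNG states (SrandInit) -> permute
// the valid prefix (Shuffle) -> emit the first `count` picks (MoveToOutput).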

template void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
                                      const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
                                      const bool *input, int *output_index, bool *output_mask, int *index_buff,
                                      int *mask_buff, int *rank_buff, int *Tnum_buff, int *tmp_buff,
                                      curandState *globalState, cudaStream_t stream);
@@ -0,0 +1,34 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_

#include <cuda_runtime.h>
#include <curand_kernel.h>
#include "runtime/device/gpu/cuda_common.h"

#define BLOCKSIZE 256
#define MAX_DIMENSION 5

template <typename T, typename S>
void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2,
                             const int &d3, const int &d4, const int &d5, const int &seedc, const int &count,
                             const T *input, S *output_index, T *output_mask, S *index_buff, S *mask_buff, S *rank_buff,
                             S *Tnum_buff, S *tmp_buff, curandState *globalState, cudaStream_t stream);

int RcwmRoundUpPower2(int v);
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RANDOM_CHOICE_WITH_MASK_IMPL_CUH_
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  RandomChoiceWithMask,
  KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
  RandomChoiceWithMaskGpuKernel, bool, int)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,129 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_

#include <cmath>  // std::ceil in InitSizeLists
#include <ctime>  // time(NULL) fallback seed in Init
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class RandomChoiceWithMaskGpuKernel : public GpuKernel {
 public:
  RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {}
  ~RandomChoiceWithMaskGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    S *output_index = GetDeviceAddress<S>(outputs, 0);
    T *output_mask = GetDeviceAddress<T>(outputs, 1);
    S *index_buff = GetDeviceAddress<S>(workspaces, 0);
    S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
    S *rank_buff = GetDeviceAddress<S>(workspaces, 2);
    S *Tnum_buff = GetDeviceAddress<S>(workspaces, 3);
    S *tmp_buff = GetDeviceAddress<S>(workspaces, 4);
    void *States = GetDeviceAddress<void *>(workspaces, 5);
    curandState *devStates = reinterpret_cast<curandState *>(States);
    CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2],
                            input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask,
                            index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 2) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but RandomChoiceWithMask has 2 outputs.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    input_shape_size_ = input_shape.size();
    if (input_shape_size_ < 1 || input_shape_size_ > MAX_DIMENSION) {
      MS_LOG(ERROR) << "Input is " << input_shape_size_
                    << "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
      return false;
    }
    // convert size_t dims to int
    for (auto i = 0; i < input_shape_size_; i++) {
      input_shape_5D_.push_back(input_shape[i]);
    }
    // left-pad the shape with 1s until it is 5-D
    while (input_shape_5D_.size() != MAX_DIMENSION) {
      input_shape_5D_.insert(input_shape_5D_.begin(), 1);
    }
    // init seedc_: seed2 takes precedence over seed; if both are 0, fall back to wall-clock time
    int seed = GetAttr<int>(kernel_node, "seed");
    int seed2 = GetAttr<int>(kernel_node, "seed2");
    if (seed2 != 0)
      seedc_ = seed2;
    else if (seed != 0)
      seedc_ = seed;
    else
      seedc_ = time(NULL);
    // init memory
    for (size_t i = 0; i < input_shape.size(); i++) {
      input_size_ *= input_shape[i];
    }
    count_ = GetAttr<int>(kernel_node, "count");
    // round the input size up to the next power of two for the bitonic passes
    ceil_power2_ = RcwmRoundUpPower2(input_size_);
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S));
    output_size_list_.push_back(count_ * sizeof(T));
    workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S));
    workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
    workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
    int blocknum = std::ceil(static_cast<float>(ceil_power2_) / BLOCKSIZE);
    workspace_size_list_.push_back(blocknum * sizeof(S));
    workspace_size_list_.push_back(ceil_power2_ * sizeof(S));
    workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState));
  }
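
  // Workspace layout, in the order consumed by Launch: 0 index_buff,
  // 1 mask_buff, 2 rank_buff, 3 Tnum_buff (one partial per reduction block),
  // 4 tmp_buff, 5 the curandState array. The buffers padded to the next
  // power of two feed the bitonic Sort/Shuffle passes.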

 private:
  int input_shape_size_;
  int seedc_;
  int input_size_;
  int count_;
  int ceil_power2_;
  std::vector<int> input_shape_5D_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
@@ -348,13 +348,13 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
         seed2 (int): Random seed2. Default: 0.
 
     Inputs:
-        - **input_x** (Tensor[bool]) - The input tensor.
+        - **input_x** (Tensor[bool]) - The input tensor. The input tensor rank should be >= 1 and <= 5.
 
     Outputs:
         Two tensors, the first one is the index tensor and the other one is the mask tensor.
 
-        - **index** (Tensor) - The output has shape between 2-D and 5-D.
-        - **mask** (Tensor) - The output has shape 1-D.
+        - **index** (Tensor) - The output shape is 2-D.
+        - **mask** (Tensor) - The output shape is 1-D.
 
     Examples:
         >>> rnd_choice_mask = P.RandomChoiceWithMask()
@@ -372,6 +372,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
     def infer_shape(self, x_shape):
         validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check_integer("input_x rank", len(x_shape), 5, Rel.LE, self.name)
         return ([self.count, len(x_shape)], [self.count])
 
     def infer_dtype(self, x_dtype):
| @@ -0,0 +1,86 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class RCWM_count_in(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_count_in, self).__init__() | |||
| self.RCWM_count_in = P.RandomChoiceWithMask(count=4, seed=1) | |||
| def construct(self, x): | |||
| return self.RCWM_count_in(x) | |||
| class RCWM_count_out(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_count_out, self).__init__() | |||
| self.RCWM_count_out = P.RandomChoiceWithMask(count=10, seed=1) | |||
| def construct(self, x): | |||
| return self.RCWM_count_out(x) | |||
| class RCWM_3D(nn.Cell): | |||
| def __init__(self): | |||
| super(RCWM_3D, self).__init__() | |||
| self.RCWM_3D = P.RandomChoiceWithMask(count=10, seed=1) | |||
| def construct(self, x): | |||
| return self.RCWM_3D(x) | |||
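

# NOTE: the expected values in the tests below appear to be golden outputs for
# seed=1 and the GPU kernel's fixed shuffle order; with both seed and seed2
# left at 0 the kernel seeds from time(NULL) and results would not be
# reproducible.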
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_3D():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    input_tensor = Tensor(np.ones([3, 4, 5]).astype(np.bool))
    expect1 = [[0, 1, 1], [0, 2, 1], [0, 2, 2], [1, 0, 1], [0, 1, 3], [0, 3, 0], [1, 3, 2],
               [0, 0, 0], [1, 1, 2], [1, 3, 4]]
    expect2 = [True, True, True, True, True, True, True, True, True, True]
    rcwm = RCWM_3D()
    output1, output2 = rcwm(input_tensor)
    assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
    assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_out():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
    expect1 = [[0, 2], [2, 2], [2, 1], [2, 0], [0, 0], [3, 3], [2, 3], [1, 3], [0, 0], [0, 0]]
    expect2 = [True, True, True, True, True, True, True, True, False, False]
    rcwm = RCWM_count_out()
    output1, output2 = rcwm(input_tensor)
    assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
    assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_RCWM_count_in():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool))
    expect1 = [[0, 2], [2, 2], [2, 1], [2, 0]]
    expect2 = [True, True, True, True]
    rcwm = RCWM_count_in()
    output1, output2 = rcwm(input_tensor)
    assert np.all(output1.asnumpy() == np.array(expect1)), "output: {}, expect: {}".format(output1, expect1)
    assert np.all(output2.asnumpy() == np.array(expect2)), "output: {}, expect: {}".format(output2, expect2)