Merge pull request !3924 from mamba_ni/mastertags/v0.7.0-beta
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "identity_impl.cuh" | |||
| #include <iostream> | |||
| template <typename T> | |||
| __global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) { | |||
| for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { | |||
| size_t batchIdx = pointIdx / (dim * dim); | |||
| size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim; | |||
| size_t dst_y = (pointIdx - batchIdx * dim * dim) % dim; | |||
| if (dst_x == dst_y) { | |||
| output_addr[pointIdx] = 1; | |||
| } else { | |||
| output_addr[pointIdx] = 0; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) { | |||
| IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr); | |||
| return; | |||
| } | |||
| template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "matrix_combine_impl.cuh" | |||
| #include <iostream> | |||
| template <typename T> | |||
| __global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width, | |||
| const size_t dst_width, T *input_addr, T *output_addr) { | |||
| for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { | |||
| size_t batchIdx = pointIdx / (src_height * src_width); | |||
| size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width; | |||
| size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width; | |||
| size_t dst_h = src_height * batchIdx + src_h; | |||
| size_t dst_w = src_width * batchIdx + src_w; | |||
| output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width, | |||
| const size_t dst_width, const size_t res_width, const size_t batch, T *input_addr, | |||
| T *output_addr) { | |||
| for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { | |||
| size_t batchIdx = pointIdx / (src_height * src_width); | |||
| if (batchIdx != (batch - 1)) { | |||
| size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width; | |||
| size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width; | |||
| size_t dst_h = src_height * batchIdx + src_h; | |||
| size_t dst_w = src_width * batchIdx + src_w; | |||
| output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx]; | |||
| } else { | |||
| size_t src_h = (pointIdx - (batch - 1) * src_height * src_width) / res_width; | |||
| size_t src_w = (pointIdx - (batch - 1) * src_height * src_width) % res_width; | |||
| size_t src_coordinate = (batch - 1) * src_height * src_width + src_h * src_width + src_w; | |||
| size_t dst_h = src_height * (batch - 1) + src_h; | |||
| size_t dst_w = src_width * (batch - 1) + src_w; | |||
| output_addr[dst_h * dst_width + dst_w] = input_addr[src_coordinate]; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width, | |||
| const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr, | |||
| cudaStream_t cuda_stream) { | |||
| if (residual == 0) { | |||
| MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width, | |||
| input_addr, output_addr); | |||
| } else { | |||
| MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width, | |||
| res_width, batch, input_addr, output_addr); | |||
| } | |||
| return; | |||
| } | |||
| template void MatrixCombine<float>(const size_t size, const size_t src_height, const size_t src_width, | |||
| const size_t dst_width, const size_t residual, const size_t res_width, | |||
| const size_t batch, float *input_addr, float *output_addr, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width, | |||
| const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "matrix_split_impl.cuh" | |||
| #include <iostream> | |||
| template <typename T> | |||
| __global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, | |||
| T *output_addr) { | |||
| for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { | |||
| size_t batchIdx = pointIdx / (split_dim * split_dim); | |||
| size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim; | |||
| size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim; | |||
| size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y; | |||
| output_addr[pointIdx] = input_addr[src_coordinate]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, const size_t res_dim, | |||
| T *input_addr, T *output_addr) { | |||
| for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) { | |||
| size_t batchIdx = pointIdx / (split_dim * split_dim); | |||
| size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim; | |||
| size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim; | |||
| size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y; | |||
| size_t batch_lower = dim / split_dim; | |||
| if (batchIdx < batch_lower) { | |||
| output_addr[pointIdx] = input_addr[src_coordinate]; | |||
| } else { | |||
| if (dst_x < res_dim && dst_y < res_dim) { | |||
| output_addr[pointIdx] = input_addr[src_coordinate]; | |||
| } else if (dst_x == dst_y) { | |||
| output_addr[pointIdx] = 1; | |||
| } else { | |||
| output_addr[pointIdx] = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr, | |||
| cudaStream_t cuda_stream) { | |||
| size_t batch = dim / split_dim; | |||
| size_t res_dim = dim - batch * split_dim; | |||
| if (res_dim == 0) { | |||
| MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, input_addr, output_addr); | |||
| } else { | |||
| MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, res_dim, input_addr, | |||
| output_addr); | |||
| } | |||
| return; | |||
| } | |||
| template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr, | |||
| float *output_addr, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| Im2ColGpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| Im2ColGpuFwdKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,269 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class Im2ColGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| Im2ColGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), | |||
| input_desc_(nullptr), | |||
| output_desc_(nullptr), | |||
| filter_desc_(nullptr), | |||
| conv_desc_(nullptr), | |||
| padded_desc_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| old_height_(0), | |||
| old_width_(0), | |||
| pad_height_(0), | |||
| pad_width_(0), | |||
| pad_top_(0), | |||
| pad_left_(0), | |||
| n_(0), | |||
| c_(0), | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| padded_size_(0), | |||
| workspace_size_(0), | |||
| use_pad_(true) {} | |||
| ~Im2ColGpuFwdKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { | |||
| T *padded_addr = GetDeviceAddress<T>(workspace, 0); | |||
| CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, | |||
| old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnIm2Col(cudnn_handle_, padded_desc_, padded_addr, filter_desc_, conv_desc_, output_addr), | |||
| "cudnnIm2ColForward failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnIm2Col(cudnn_handle_, input_desc_, input_addr, filter_desc_, conv_desc_, output_addr), | |||
| "cudnnIm2ColForward failed"); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| InitResource(); | |||
| if (!CheckParam(kernel_node)) { | |||
| return false; | |||
| } | |||
| auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto filter_shape = GetAttr<std::vector<int>>(kernel_node, "kernel_size"); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(in_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "cudnnIm2ColForward input is null."; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| Set4DDesc(in_shape, filter_shape, output_shape); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, 1), "cudnnSetConvGroupCount failed"); | |||
| pad_height_ = GetAttr<int>(kernel_node, "pad"); | |||
| pad_width_ = pad_height_; | |||
| pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode"); | |||
| SetStrideAndDilation(kernel_node); | |||
| if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { | |||
| SetPad(in_shape, kernel_node); | |||
| } else { | |||
| if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { | |||
| pad_height_ = 0; | |||
| pad_width_ = 0; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], | |||
| dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| } | |||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||
| "cudnnSetConvolutionMathType failed.") | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), | |||
| "cudnnCreateConvolutionDescriptor failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast<size_t *>(&input_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast<size_t *>(&output_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast<size_t *>(&padded_size_)), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| } | |||
| input_size_list_.push_back(input_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { | |||
| workspace_size_list_.push_back(padded_size_); | |||
| } | |||
| return; | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), | |||
| "cudnnDestroyConvolutionDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); | |||
| } | |||
| bool CheckParam(const CNodePtr &kernel_node) { | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but Im2Col needs 1 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but Im2Col needs 1 output."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) { | |||
| auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list"); | |||
| n_ = SizeToInt(in_shape[0]); | |||
| c_ = SizeToInt(in_shape[1]); | |||
| old_height_ = SizeToInt(in_shape[2]); | |||
| old_width_ = SizeToInt(in_shape[3]); | |||
| pad_height_ = pad_list[0] + pad_list[1]; | |||
| pad_width_ = pad_list[2] + pad_list[3]; | |||
| pad_top_ = pad_list[0]; | |||
| pad_left_ = pad_list[2]; | |||
| // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. | |||
| if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { | |||
| use_pad_ = false; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, | |||
| old_height_ + pad_height_, old_width_ + pad_width_), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( | |||
| conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], | |||
| dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||
| "cudnnSetConvolution2dDescriptor failed"); | |||
| } | |||
| void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<int> &filter_shape, | |||
| const std::vector<size_t> &output_shape) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), | |||
| SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 1, | |||
| SizeToInt(in_shape[1]), filter_shape[0], filter_shape[1]), | |||
| "cudnnSetFilter4dDescriptor failed"); | |||
| auto out_H = output_shape[0] * output_shape[1] * output_shape[2]; | |||
| auto out_W = output_shape[3] * output_shape[4] * output_shape[5]; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | |||
| SizeToInt(out_H), SizeToInt(out_W), 1, 1), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } | |||
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride"); | |||
| dilation_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation"); | |||
| if (stride_.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Im2Col's stride must be 4d!"; | |||
| } | |||
| if (stride_[0] != 1 || stride_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Im2Col's stride only support 1 in N axis and C axis!"; | |||
| } | |||
| if (dilation_.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Im2Col's dilation must be 4d!"; | |||
| } | |||
| if (dilation_[0] != 1 || dilation_[1] != 1) { | |||
| MS_LOG(EXCEPTION) << "Im2Col's dilation only support 1 in N axis and C axis!"; | |||
| } | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| cudnnTensorDescriptor_t input_desc_; | |||
| cudnnTensorDescriptor_t output_desc_; | |||
| cudnnFilterDescriptor_t filter_desc_; | |||
| cudnnConvolutionFwdAlgo_t conv_algorithm_; | |||
| cudnnConvolutionDescriptor_t conv_desc_; | |||
| cudnnTensorDescriptor_t padded_desc_; | |||
| std::string pad_mode_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| const float pad_value_ = 0.0; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| int old_height_; | |||
| int old_width_; | |||
| int pad_height_; | |||
| int pad_width_; | |||
| int pad_top_; | |||
| int pad_left_; | |||
| int n_; | |||
| int c_; | |||
| std::vector<int> stride_; | |||
| std::vector<int> dilation_; | |||
| bool is_null_input_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t padded_size_; | |||
| size_t workspace_size_; | |||
| bool use_pad_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_IM2COLGPUKERNEL_H_ | |||
| @@ -83,7 +83,10 @@ from . import _quant_ops | |||
| from ._quant_ops import * | |||
| from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, | |||
| CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull) | |||
| from .thor_ops import * | |||
| from .thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, | |||
| CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, | |||
| CusMatMulCubeDenseRight, | |||
| CusMatMulCubeFraczLeftCast, Im2Col) | |||
| from .sparse_ops import SparseToDense | |||
| __all__ = [ | |||
| @@ -13,9 +13,12 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """thor_ops""" | |||
| import math | |||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | |||
| from ...common import dtype as mstype | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| __all__ = ["CusBatchMatMul", | |||
| "CusCholeskyTrsm", | |||
| @@ -31,6 +34,37 @@ __all__ = ["CusBatchMatMul", | |||
| ] | |||
| def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): | |||
| """ | |||
| Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. | |||
| """ | |||
| def _raise_message(): | |||
| raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two " | |||
| f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") | |||
| def _get_return_value(): | |||
| if isinstance(arg_value, int): | |||
| ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) | |||
| elif len(arg_value) == 2: | |||
| ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value | |||
| elif len(arg_value) == 4: | |||
| if not allow_four: | |||
| _raise_message() | |||
| ret = arg_value if ret_four else (arg_value[2], arg_value[3]) | |||
| else: | |||
| _raise_message() | |||
| return ret | |||
| validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) | |||
| ret_value = _get_return_value() | |||
| for item in ret_value: | |||
| if isinstance(item, int) and item > 0: | |||
| continue | |||
| _raise_message() | |||
| return ret_value | |||
| class CusBatchMatMul(PrimitiveWithInfer): | |||
| """ | |||
| Multiplies matrix `a` by matrix `b` in batch. | |||
| @@ -360,6 +394,7 @@ class CusTranspose02314(PrimitiveWithInfer): | |||
| """init CusTranspose02314""" | |||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||
| from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314 | |||
| def get_bprop(self): | |||
| def bprop(x, out, dout): | |||
| return (C.zeros_like(x),) | |||
| @@ -446,3 +481,84 @@ class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer): | |||
| def infer_dtype(self, data1_dtype, data2_dtype): | |||
| return mstype.float16 | |||
| class Im2Col(PrimitiveWithInfer): | |||
| """ | |||
| extract image pathes from image. | |||
| The rank of input_x1 must be `4`, data_format is "NCHW". | |||
| Inputs: | |||
| - **input_x1** (Tensor) - The feature map. | |||
| The shape of the tensor is :math:`(N, C, H, W)`. | |||
| Outputs: | |||
| Tensor. | |||
| Examples: | |||
| >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16)) | |||
| >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) | |||
| >>> output = img2col(input_x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| kernel_size, | |||
| pad_mode="valid", | |||
| pad=0, | |||
| stride=1, | |||
| dilation=1): | |||
| """init Im2Col""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.add_prim_attr('kernel_size', self.kernel_size) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('stride', self.stride) | |||
| self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('dilation', self.dilation) | |||
| validator.check_value_type('pad', pad, (int,), self.name) | |||
| self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) | |||
| self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) | |||
| if self.pad_mode == 'pad': | |||
| validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) | |||
| kernel_size_h = self.kernel_size[0] | |||
| kernel_size_w = self.kernel_size[1] | |||
| stride_h = self.stride[2] | |||
| stride_w = self.stride[3] | |||
| dilation_h = self.dilation[2] | |||
| dilation_w = self.dilation[3] | |||
| if self.pad_mode == "valid": | |||
| h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h) | |||
| w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w) | |||
| pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 | |||
| elif self.pad_mode == "same": | |||
| h_out = math.ceil(x_shape[2] / stride_h) | |||
| w_out = math.ceil(x_shape[3] / stride_w) | |||
| pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]) | |||
| pad_top = math.floor(pad_needed_h / 2) | |||
| pad_bottom = pad_needed_h - pad_top | |||
| pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]) | |||
| pad_left = math.floor(pad_needed_w / 2) | |||
| pad_right = pad_needed_w - pad_left | |||
| elif self.pad_mode == 'pad': | |||
| pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad | |||
| h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h | |||
| w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w | |||
| h_out = math.floor(h_out) | |||
| w_out = math.floor(w_out) | |||
| self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] | |||
| self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) | |||
| batch_size = x_shape[0] | |||
| channel = x_shape[1] | |||
| k_h = kernel_size_h | |||
| k_w = kernel_size_w | |||
| out_shape = [channel, k_h, k_w, batch_size, h_out, w_out] | |||
| return out_shape | |||
| def infer_dtype(self, x_dtype): | |||
| args = {'x': x_dtype} | |||
| valid_types = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| return x_dtype | |||