From: @kanghui0204
Reviewed-by: @liangchenghui, @tom__chen
Signed-off-by: @liangchenghui
pull/15607/MERGE
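Summary: the DepthToSpace and SpaceToDepth GPU kernels used to stage their 4-D input and output shapes in two workspace buffers and copy them host-to-device with cudaMemcpyAsync before every launch, just so the device code could read the extents. Since both shapes are fixed NCHW 4-tuples, this change passes the eight extents (in_, ic_, ih_, iw_, on_, oc_, oh_, ow_) as scalar kernel arguments instead; scalar parameters travel in the kernel's launch parameter buffer, so the two copies and both workspace allocations disappear. The index arithmetic inside the kernels is unrolled and reordered at the same time, and the tests are updated to pin the resulting layout. A minimal sketch of the by-value pattern, with a hypothetical kernel (not this PR's code):

```cpp
// Sketch: extents passed by value need no device-side shape buffer and no
// per-launch memcpy; the driver ships them with the launch itself.
__global__ void AddRowIndex(float *data, const size_t h, const size_t w) {
  const size_t size = h * w;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
       pos += blockDim.x * gridDim.x) {
    data[pos] += static_cast<float>(pos / w);  // extent w read from the argument
  }
}
```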
@@ -41,25 +41,11 @@ class DepthToSpaceFwdKernel : public GpuKernel {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    // get device buffer shape ptr
-    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
-    size_t *output_shape = GetDeviceAddress<size_t>(workspace, 1);
-    // buffer shape memcpy from host to device
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync input_shape failed");
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_,
-                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync input_shape failed");
     // get input size
     size_t size = input_size_ / sizeof(T);
     // call cuda kernel
-    CalDepthToSpace(size, input, input_shape, output_shape, block_size_, output,
+    CalDepthToSpace(size, input, in_, ic_, ih_, iw_, on_, oc_, oh_, ow_, block_size_, output,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -89,14 +75,20 @@ class DepthToSpaceFwdKernel : public GpuKernel {
     input_size_ = 1;
     for (size_t i = 0; i < shape_size_; i++) {
       input_size_ *= input_shape[i];
-      input_shape_.push_back(input_shape[i]);
     }
     input_size_ *= sizeof(T);
     output_size_ = input_size_;
-    output_shape_.push_back(input_shape[0]);
-    output_shape_.push_back(input_shape[1] / block_size_ / block_size_);
-    output_shape_.push_back(input_shape[2] * block_size_);
-    output_shape_.push_back(input_shape[3] * block_size_);
+    in_ = input_shape[0];
+    ic_ = input_shape[1];
+    ih_ = input_shape[2];
+    iw_ = input_shape[3];
+    on_ = in_;
+    oc_ = ic_ / block_size_ / block_size_;
+    oh_ = ih_ * block_size_;
+    ow_ = iw_ * block_size_;
     // Private members Initialize
     InitSizeLists();
     return true;
@@ -107,11 +99,15 @@ class DepthToSpaceFwdKernel : public GpuKernel {
     input_size_ = 0;
     output_size_ = 0;
     block_size_ = 0;
-    workspace_size1_ = 0;
-    workspace_size2_ = 0;
+    in_ = 0;
+    ic_ = 0;
+    ih_ = 0;
+    iw_ = 0;
+    on_ = 0;
+    oc_ = 0;
+    oh_ = 0;
+    ow_ = 0;
-    input_shape_.clear();
-    output_shape_.clear();
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -121,16 +117,10 @@ class DepthToSpaceFwdKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
-    workspace_size1_ = shape_size_ * sizeof(size_t);
-    workspace_size2_ = shape_size_ * sizeof(size_t);
-    workspace_size_list_.push_back(workspace_size1_);
-    workspace_size_list_.push_back(workspace_size2_);
     return;
   }

  private:
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> output_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -138,8 +128,14 @@ class DepthToSpaceFwdKernel : public GpuKernel {
   size_t input_size_;
   size_t output_size_;
   size_t block_size_;
-  size_t workspace_size1_;
-  size_t workspace_size2_;
+  size_t in_;
+  size_t ic_;
+  size_t ih_;
+  size_t iw_;
+  size_t on_;
+  size_t oc_;
+  size_t oh_;
+  size_t ow_;
 };
 } // namespace kernel
 } // namespace mindspore
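For a concrete reading of the Init() arithmetic above: with block_size_ = 2, DepthToSpace maps an NCHW input of (1, 12, 1, 1) to (1, 3, 2, 2); SpaceToDepth (next file) applies the inverse. A worked check of the shape relations (illustrative sketch only):

```cpp
// Shape arithmetic check; assumes ic is divisible by r * r.
#include <cassert>
#include <cstddef>

int main() {
  const size_t r = 2;
  const size_t ic = 12, ih = 1, iw = 1;   // DepthToSpace input (N, 12, 1, 1)
  const size_t oc = ic / r / r;           // 3
  const size_t oh = ih * r, ow = iw * r;  // 2, 2
  assert(oc == 3 && oh == 2 && ow == 2);
  // SpaceToDepth inverts it: channels grow by r * r, spatial dims shrink by r.
  assert(oc * r * r == ic && oh / r == ih && ow / r == iw);
  return 0;
}
```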
@@ -41,25 +41,11 @@ class SpaceToDepthFwdKernel : public GpuKernel {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    // get device buffer shape ptr
-    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
-    size_t *output_shape = GetDeviceAddress<size_t>(workspace, 1);
-    // buffer shape memcpy from host to device
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync input_shape failed");
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_,
-                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync input_shape failed");
     // get input size
     size_t size = input_size_ / sizeof(T);
     // call cuda kernel
-    CalSpaceToDepth(size, input, input_shape, output_shape, block_size_, output,
+    CalSpaceToDepth(size, input, in_, ic_, ih_, iw_, on_, oc_, oh_, ow_, block_size_, output,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -89,14 +75,19 @@ class SpaceToDepthFwdKernel : public GpuKernel {
     input_size_ = 1;
     for (size_t i = 0; i < shape_size_; i++) {
       input_size_ *= input_shape[i];
-      input_shape_.push_back(input_shape[i]);
     }
     input_size_ *= sizeof(T);
     output_size_ = input_size_;
-    output_shape_.push_back(input_shape[0]);
-    output_shape_.push_back(input_shape[1] * block_size_ * block_size_);
-    output_shape_.push_back(input_shape[2] / block_size_);
-    output_shape_.push_back(input_shape[3] / block_size_);
+    in_ = input_shape[0];
+    ic_ = input_shape[1];
+    ih_ = input_shape[2];
+    iw_ = input_shape[3];
+    on_ = in_;
+    oc_ = ic_ * block_size_ * block_size_;
+    oh_ = ih_ / block_size_;
+    ow_ = iw_ / block_size_;
     // Private members Initialize
     InitSizeLists();
     return true;
@@ -107,11 +98,15 @@ class SpaceToDepthFwdKernel : public GpuKernel {
     input_size_ = 0;
     output_size_ = 0;
     block_size_ = 0;
-    workspace_size1_ = 0;
-    workspace_size2_ = 0;
+    in_ = 0;
+    ic_ = 0;
+    ih_ = 0;
+    iw_ = 0;
+    on_ = 0;
+    oc_ = 0;
+    oh_ = 0;
+    ow_ = 0;
-    input_shape_.clear();
-    output_shape_.clear();
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -121,16 +116,10 @@ class SpaceToDepthFwdKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
-    workspace_size1_ = shape_size_ * sizeof(size_t);
-    workspace_size2_ = shape_size_ * sizeof(size_t);
-    workspace_size_list_.push_back(workspace_size1_);
-    workspace_size_list_.push_back(workspace_size2_);
     return;
   }

  private:
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> output_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -138,8 +127,14 @@ class SpaceToDepthFwdKernel : public GpuKernel {
   size_t input_size_;
   size_t output_size_;
   size_t block_size_;
-  size_t workspace_size1_;
-  size_t workspace_size2_;
+  size_t in_;
+  size_t ic_;
+  size_t ih_;
+  size_t iw_;
+  size_t on_;
+  size_t oc_;
+  size_t oh_;
+  size_t ow_;
 };
 } // namespace kernel
 } // namespace mindspore
@@ -13,79 +13,126 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 #include <cuda_runtime.h>
 #include "depthtospace_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void DepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                             const size_t r, T *output) {
+__global__ void DepthToSpace(const size_t size, const T *input, const size_t in,
+                             const size_t ic, const size_t ih, const size_t iw,
+                             const size_t on, const size_t oc, const size_t oh,
+                             const size_t ow, const size_t r, T *output) {
   size_t temp_stride = 0;
   size_t temp_pos = 0;
   size_t input_pos = 0;
-  size_t input_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION];
   size_t output_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION];
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    temp_stride = output_shape[1] * output_shape[2] * output_shape[3];
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
+       pos += blockDim.x * gridDim.x) {
+    temp_stride = oc * oh * ow;
     output_pos_array[0] = pos / temp_stride;
     temp_pos = pos % temp_stride;
-    for (size_t i = 1; i < DEPTHTOSPACE_BUFFER_DIMENSION; i++) {
-      temp_stride /= output_shape[i];
-      output_pos_array[i] = temp_pos / temp_stride;
-      temp_pos %= temp_stride;
-    }
+    temp_stride /= oc;
+    output_pos_array[1] = temp_pos / temp_stride;
+    temp_pos = pos % temp_stride;
+    temp_stride /= oh;
+    output_pos_array[2] = temp_pos / temp_stride;
+    temp_pos = pos % temp_stride;
-    input_pos_array[0] = output_pos_array[0];
-    input_pos_array[1] = output_pos_array[1] * r * r + r * (output_pos_array[2] % r) + output_pos_array[3] % r;
-    input_pos_array[2] = output_pos_array[2] / r;
-    input_pos_array[3] = output_pos_array[3] / r;
+    temp_stride /= ow;
+    output_pos_array[3] = temp_pos / temp_stride;
-    for (size_t i = 0; i < 3; ++i) {
-      input_pos += input_pos_array[i];
-      input_pos *= input_shape[i + 1];
-    }
-    input_pos += input_pos_array[3];
+    input_pos += output_pos_array[0];
+    input_pos =
+        (input_pos * ic) +
+        (output_pos_array[1] +
+         (r * (output_pos_array[2] % r) + output_pos_array[3] % r) * oc);
+    input_pos = (input_pos * ih) + (output_pos_array[2] / r);
+    input_pos = (input_pos * iw) + (output_pos_array[3] / r);
     output[pos] = input[input_pos];
+    input_pos = 0;
   }
   return;
 }
 template <typename T>
-void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                     const size_t r, T *output, cudaStream_t cuda_stream) {
-  DepthToSpace<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, output_shape, r, output);
+void CalDepthToSpace(const size_t size, const T *input, const size_t in,
+                     const size_t ic, const size_t ih, const size_t iw,
+                     const size_t on, const size_t oc, const size_t oh,
+                     const size_t ow, const size_t r, T *output,
+                     cudaStream_t cuda_stream) {
+  DepthToSpace<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+      size, input, in, ic, ih, iw, on, oc, oh, ow, r, output);
   return;
 }
-template void CalDepthToSpace<float>(const size_t size, const float *input, const size_t *input_shape,
-                                     const size_t *output_shape, const size_t r, float *output,
+template void CalDepthToSpace<float>(const size_t size, const float *input,
+                                     const size_t in, const size_t ic,
+                                     const size_t ih, const size_t iw,
+                                     const size_t on, const size_t oc,
+                                     const size_t oh, const size_t ow,
+                                     const size_t r, float *output,
                                      cudaStream_t cuda_stream);
-template void CalDepthToSpace<half>(const size_t size, const half *input, const size_t *input_shape,
-                                    const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream);
-template void CalDepthToSpace<int>(const size_t size, const int *input, const size_t *input_shape,
-                                   const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream);
-template void CalDepthToSpace<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, int64_t *output,
+template void CalDepthToSpace<half>(const size_t size, const half *input,
+                                    const size_t in, const size_t ic,
+                                    const size_t ih, const size_t iw,
+                                    const size_t on, const size_t oc,
+                                    const size_t oh, const size_t ow,
+                                    const size_t r, half *output,
+                                    cudaStream_t cuda_stream);
+template void CalDepthToSpace<int>(const size_t size, const int *input,
+                                   const size_t in, const size_t ic,
+                                   const size_t ih, const size_t iw,
+                                   const size_t on, const size_t oc,
+                                   const size_t oh, const size_t ow,
+                                   const size_t r, int *output,
+                                   cudaStream_t cuda_stream);
+template void CalDepthToSpace<int64_t>(const size_t size, const int64_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, int64_t *output,
                                        cudaStream_t cuda_stream);
-template void CalDepthToSpace<int16_t>(const size_t size, const int16_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, int16_t *output,
+template void CalDepthToSpace<int16_t>(const size_t size, const int16_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, int16_t *output,
                                        cudaStream_t cuda_stream);
-template void CalDepthToSpace<int8_t>(const size_t size, const int8_t *input, const size_t *input_shape,
-                                      const size_t *output_shape, const size_t r, int8_t *output,
+template void CalDepthToSpace<int8_t>(const size_t size, const int8_t *input,
+                                      const size_t in, const size_t ic,
+                                      const size_t ih, const size_t iw,
+                                      const size_t on, const size_t oc,
+                                      const size_t oh, const size_t ow,
+                                      const size_t r, int8_t *output,
                                       cudaStream_t cuda_stream);
-template void CalDepthToSpace<uint8_t>(const size_t size, const uint8_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, uint8_t *output,
+template void CalDepthToSpace<uint8_t>(const size_t size, const uint8_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, uint8_t *output,
                                        cudaStream_t cuda_stream);
-template void CalDepthToSpace<uint16_t>(const size_t size, const uint16_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint16_t *output,
-                                        cudaStream_t cuda_stream);
-template void CalDepthToSpace<uint32_t>(const size_t size, const uint32_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint32_t *output,
-                                        cudaStream_t cuda_stream);
-template void CalDepthToSpace<uint64_t>(const size_t size, const uint64_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint64_t *output,
-                                        cudaStream_t cuda_stream);
+template void
+CalDepthToSpace<uint16_t>(const size_t size, const uint16_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint16_t *output, cudaStream_t cuda_stream);
+template void
+CalDepthToSpace<uint32_t>(const size_t size, const uint32_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint32_t *output, cudaStream_t cuda_stream);
+template void
+CalDepthToSpace<uint64_t>(const size_t size, const uint64_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint64_t *output, cudaStream_t cuda_stream);
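Two things change in the rewritten kernel above: the generic stride loop over output_shape[] is unrolled against the scalar extents, and the gather now uses the channel mapping c_in = c_out + (r * (h_out % r) + w_out % r) * oc rather than the old c_out * r * r + r * (h_out % r) + w_out % r, i.e. the r * r depth sub-blocks are strided across the output channels instead of packed per channel. A host-side reference of the new mapping (a sketch for illustration, not the PR's code):

```cpp
// For each flat NCHW output position, recover (n, c, h, w) by peeling strides,
// then gather from input channel c + (r * (h % r) + w % r) * oc at (h / r, w / r).
#include <cstddef>
#include <vector>

template <typename T>
void DepthToSpaceRef(const std::vector<T> &input, std::vector<T> *output,
                     size_t n, size_t ic, size_t ih, size_t iw, size_t r) {
  const size_t oc = ic / (r * r), oh = ih * r, ow = iw * r;
  for (size_t pos = 0; pos < n * oc * oh * ow; ++pos) {
    size_t stride = oc * oh * ow;
    const size_t n_i = pos / stride;
    size_t rem = pos % stride;
    stride /= oc;
    const size_t c_i = rem / stride;
    rem %= stride;
    stride /= oh;
    const size_t h_i = rem / stride;
    const size_t w_i = rem % stride;  // stride is now ow
    const size_t in_c = c_i + (r * (h_i % r) + w_i % r) * oc;
    (*output)[pos] = input[((n_i * ic + in_c) * ih + h_i / r) * iw + w_i / r];
  }
}
```

With n = 1, ic = 12, ih = iw = 1 and r = 2 this reproduces the hardcoded expect arrays in the updated tests below.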
@@ -19,7 +19,10 @@
 #define DEPTHTOSPACE_BUFFER_DIMENSION 4
 template <typename T>
-void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                     const size_t r, T *output, cudaStream_t cuda_stream);
+void CalDepthToSpace(const size_t size, const T *input, const size_t in,
+                     const size_t ic, const size_t ih, const size_t iw,
+                     const size_t on, const size_t oc, const size_t oh,
+                     const size_t ow, const size_t r, T *output,
+                     cudaStream_t cuda_stream);
 #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_
@@ -19,73 +19,120 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void SpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                             const size_t r, T *output) {
+__global__ void SpaceToDepth(const size_t size, const T *input, const size_t in,
+                             const size_t ic, const size_t ih, const size_t iw,
+                             const size_t on, const size_t oc, const size_t oh,
+                             const size_t ow, const size_t r, T *output) {
   size_t temp_stride = 0;
   size_t temp_pos = 0;
   size_t output_pos = 0;
   size_t input_pos_array[SPACETODEPTH_BUFFER_DIMENSION];
-  size_t output_pos_array[SPACETODEPTH_BUFFER_DIMENSION];
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    temp_stride = input_shape[1] * input_shape[2] * input_shape[3];
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
+       pos += blockDim.x * gridDim.x) {
+    temp_stride = ic * ih * iw;
     input_pos_array[0] = pos / temp_stride;
     temp_pos = pos % temp_stride;
-    for (size_t i = 1; i < SPACETODEPTH_BUFFER_DIMENSION; i++) {
-      temp_stride /= input_shape[i];
-      input_pos_array[i] = temp_pos / temp_stride;
-      temp_pos %= temp_stride;
-    }
+    temp_stride /= ic;
+    input_pos_array[1] = temp_pos / temp_stride;
+    temp_pos = pos % temp_stride;
+    temp_stride /= ih;
+    input_pos_array[2] = temp_pos / temp_stride;
+    temp_pos = pos % temp_stride;
-    output_pos_array[0] = input_pos_array[0];
-    output_pos_array[1] = input_pos_array[1] * r * r + r * (input_pos_array[2] % r) + input_pos_array[3] % r;
-    output_pos_array[2] = input_pos_array[2] / r;
-    output_pos_array[3] = input_pos_array[3] / r;
+    temp_stride /= iw;
+    input_pos_array[3] = temp_pos / temp_stride;
-    for (size_t i = 0; i < 3; ++i) {
-      output_pos += output_pos_array[i];
-      output_pos *= output_shape[i + 1];
-    }
-    output_pos += output_pos_array[3];
+    output_pos += input_pos_array[0];
+    output_pos = (output_pos * oc) +
+                 (input_pos_array[1] +
+                  (r * (input_pos_array[2] % r) + input_pos_array[3] % r) * ic);
+    output_pos = (output_pos * oh) + (input_pos_array[2] / r);
+    output_pos = (output_pos * ow) + (input_pos_array[3] / r);
     output[output_pos] = input[pos];
+    output_pos = 0;
   }
   return;
 }
 template <typename T>
-void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                     const size_t r, T *output, cudaStream_t cuda_stream) {
-  SpaceToDepth<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, output_shape, r, output);
+void CalSpaceToDepth(const size_t size, const T *input, const size_t in,
+                     const size_t ic, const size_t ih, const size_t iw,
+                     const size_t on, const size_t oc, const size_t oh,
+                     const size_t ow, const size_t r, T *output,
+                     cudaStream_t cuda_stream) {
+  SpaceToDepth<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+      size, input, in, ic, ih, iw, on, oc, oh, ow, r, output);
   return;
 }
-template void CalSpaceToDepth<float>(const size_t size, const float *input, const size_t *input_shape,
-                                     const size_t *output_shape, const size_t r, float *output,
+template void CalSpaceToDepth<float>(const size_t size, const float *input,
+                                     const size_t in, const size_t ic,
+                                     const size_t ih, const size_t iw,
+                                     const size_t on, const size_t oc,
+                                     const size_t oh, const size_t ow,
+                                     const size_t r, float *output,
                                      cudaStream_t cuda_stream);
-template void CalSpaceToDepth<half>(const size_t size, const half *input, const size_t *input_shape,
-                                    const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream);
-template void CalSpaceToDepth<int>(const size_t size, const int *input, const size_t *input_shape,
-                                   const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream);
-template void CalSpaceToDepth<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, int64_t *output,
+template void CalSpaceToDepth<half>(const size_t size, const half *input,
+                                    const size_t in, const size_t ic,
+                                    const size_t ih, const size_t iw,
+                                    const size_t on, const size_t oc,
+                                    const size_t oh, const size_t ow,
+                                    const size_t r, half *output,
+                                    cudaStream_t cuda_stream);
+template void CalSpaceToDepth<int>(const size_t size, const int *input,
+                                   const size_t in, const size_t ic,
+                                   const size_t ih, const size_t iw,
+                                   const size_t on, const size_t oc,
+                                   const size_t oh, const size_t ow,
+                                   const size_t r, int *output,
+                                   cudaStream_t cuda_stream);
+template void CalSpaceToDepth<int64_t>(const size_t size, const int64_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, int64_t *output,
                                        cudaStream_t cuda_stream);
-template void CalSpaceToDepth<int16_t>(const size_t size, const int16_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, int16_t *output,
+template void CalSpaceToDepth<int16_t>(const size_t size, const int16_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, int16_t *output,
                                        cudaStream_t cuda_stream);
-template void CalSpaceToDepth<int8_t>(const size_t size, const int8_t *input, const size_t *input_shape,
-                                      const size_t *output_shape, const size_t r, int8_t *output,
+template void CalSpaceToDepth<int8_t>(const size_t size, const int8_t *input,
+                                      const size_t in, const size_t ic,
+                                      const size_t ih, const size_t iw,
+                                      const size_t on, const size_t oc,
+                                      const size_t oh, const size_t ow,
+                                      const size_t r, int8_t *output,
                                       cudaStream_t cuda_stream);
-template void CalSpaceToDepth<uint8_t>(const size_t size, const uint8_t *input, const size_t *input_shape,
-                                       const size_t *output_shape, const size_t r, uint8_t *output,
+template void CalSpaceToDepth<uint8_t>(const size_t size, const uint8_t *input,
+                                       const size_t in, const size_t ic,
+                                       const size_t ih, const size_t iw,
+                                       const size_t on, const size_t oc,
+                                       const size_t oh, const size_t ow,
+                                       const size_t r, uint8_t *output,
                                        cudaStream_t cuda_stream);
-template void CalSpaceToDepth<uint16_t>(const size_t size, const uint16_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint16_t *output,
-                                        cudaStream_t cuda_stream);
-template void CalSpaceToDepth<uint32_t>(const size_t size, const uint32_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint32_t *output,
-                                        cudaStream_t cuda_stream);
-template void CalSpaceToDepth<uint64_t>(const size_t size, const uint64_t *input, const size_t *input_shape,
-                                        const size_t *output_shape, const size_t r, uint64_t *output,
-                                        cudaStream_t cuda_stream);
+template void
+CalSpaceToDepth<uint16_t>(const size_t size, const uint16_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint16_t *output, cudaStream_t cuda_stream);
+template void
+CalSpaceToDepth<uint32_t>(const size_t size, const uint32_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint32_t *output, cudaStream_t cuda_stream);
+template void
+CalSpaceToDepth<uint64_t>(const size_t size, const uint64_t *input,
+                          const size_t in, const size_t ic, const size_t ih,
+                          const size_t iw, const size_t on, const size_t oc,
+                          const size_t oh, const size_t ow, const size_t r,
+                          uint64_t *output, cudaStream_t cuda_stream);
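SpaceToDepth is the exact inverse, written as a scatter: each thread walks flat input positions, decomposes them against (ic, ih, iw), and writes input channel c at spatial (h, w) to output channel c + (r * (h % r) + w % r) * ic at (h / r, w / r). The same reference in scatter form (again a sketch, not the PR's code):

```cpp
// Inverse of DepthToSpaceRef: peel NCHW strides of the *input*, then scatter.
#include <cstddef>
#include <vector>

template <typename T>
void SpaceToDepthRef(const std::vector<T> &input, std::vector<T> *output,
                     size_t n, size_t ic, size_t ih, size_t iw, size_t r) {
  const size_t oc = ic * r * r, oh = ih / r, ow = iw / r;
  for (size_t pos = 0; pos < n * ic * ih * iw; ++pos) {
    size_t stride = ic * ih * iw;
    const size_t n_i = pos / stride;
    size_t rem = pos % stride;
    stride /= ic;
    const size_t c_i = rem / stride;
    rem %= stride;
    stride /= ih;
    const size_t h_i = rem / stride;
    const size_t w_i = rem % stride;  // stride is now iw
    const size_t out_c = c_i + (r * (h_i % r) + w_i % r) * ic;
    (*output)[((n_i * oc + out_c) * oh + h_i / r) * ow + w_i / r] = input[pos];
  }
}
```

Composing DepthToSpaceRef with SpaceToDepthRef (in either order, on matching shapes) is the identity, which is the round-trip relationship the two test files below rely on.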
@@ -19,7 +19,10 @@
 #define SPACETODEPTH_BUFFER_DIMENSION 4
 template <typename T>
-void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
-                     const size_t r, T *output, cudaStream_t cuda_stream);
+void CalSpaceToDepth(const size_t size, const T *input, const size_t in,
+                     const size_t ic, const size_t ih, const size_t iw,
+                     const size_t on, const size_t oc, const size_t oh,
+                     const size_t ow, const size_t r, T *output,
+                     cudaStream_t cuda_stream);
 #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_
@@ -22,55 +22,51 @@ from mindspore.common.api import ms_function
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter

-def DepthToSpaceNumpy(arr, block_size):
-    '''
-    DepthToSpace ops use numpy
-    '''
-    tmpshape = arr.shape
-    newshape = []
-    newshape.append(tmpshape[0])
-    newshape.append(tmpshape[1]//block_size//block_size)
-    newshape.append(tmpshape[2]*block_size)
-    newshape.append(tmpshape[3]*block_size)
-    output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3])
-    output = np.transpose(output, (0, 1, 4, 2, 5, 3))
-    output = output.reshape(newshape)
-    return output
-
 class DepthToSpaceNet(nn.Cell):
-    def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+    def __init__(self, nptype, block_size=2, input_shape=(1, 12, 1, 1)):
         super(DepthToSpaceNet, self).__init__()
         self.DepthToSpace = P.DepthToSpace(2)
         input_size = 1
         for i in input_shape:
             input_size = input_size*i
-        self.data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
-        self.x = Parameter(initializer(Tensor(self.data_np), input_shape), name='x')
+        data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
+        self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1')

     @ms_function
     def construct(self):
-        return self.DepthToSpace(self.x)
+        y1 = self.DepthToSpace(self.x1)
+        return y1

-def DepthToSpace(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+def DepthToSpace(nptype, block_size=2, input_shape=(1, 12, 1, 1)):
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     input_size = 1
     for i in input_shape:
         input_size = input_size*i
-    expect = np.arange(input_size).reshape(input_shape).astype(nptype)
-    expect = DepthToSpaceNumpy(expect, block_size)
+    expect = np.array([[[[0, 3],
+                         [6, 9]],
+                        [[1, 4],
+                         [7, 10]],
+                        [[2, 5],
+                         [8, 11]]]]).astype(nptype)

     dts = DepthToSpaceNet(nptype, block_size, input_shape)
     output = dts()
-    print(output)
     assert (output.asnumpy() == expect).all()

-def DepthToSpace_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+def DepthToSpace_pynative(nptype, block_size=2, input_shape=(1, 12, 1, 1)):
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
     input_size = 1
     for i in input_shape:
         input_size = input_size*i
-    expect = np.arange(input_size).reshape(input_shape).astype(nptype)
-    expect = DepthToSpaceNumpy(expect, block_size)
+    expect = np.array([[[[0, 3],
+                         [6, 9]],
+                        [[1, 4],
+                         [7, 10]],
+                        [[2, 5],
+                         [8, 11]]]]).astype(nptype)

     dts = P.DepthToSpace(2)
     arr_input = Tensor(np.arange(input_size).reshape(input_shape).astype(nptype))
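The deleted DepthToSpaceNumpy helper modeled the previous c_out * r * r + offset channel ordering (reshape to (N, oc, r, r, H, W), then transpose (0, 1, 4, 2, 5, 3)), so it no longer predicts the kernel's output; the tests now pin the expected tensor directly. A standalone check that the hardcoded array follows from the new mapping (illustrative only):

```cpp
// Input is arange(12) in shape (1, 12, 1, 1) with r = 2, so each output value
// equals its source input channel index. Prints the expect array in row-major
// order: 0 3 6 9  1 4 7 10  2 5 8 11.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t oc = 3, oh = 2, ow = 2, r = 2;  // output shape (1, 3, 2, 2)
  for (size_t c = 0; c < oc; ++c)
    for (size_t h = 0; h < oh; ++h)
      for (size_t w = 0; w < ow; ++w)
        printf("%zu ", c + (r * (h % r) + w % r) * oc);
  printf("\n");
  return 0;
}
```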
@@ -22,70 +22,45 @@ from mindspore.common.api import ms_function
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter

-def DepthToSpaceNumpy(arr, block_size):
-    '''
-    DepthToSpace ops use numpy
-    DepthToSpace ops is reverse ops to SpaceToDepth ops
-    therefore DepthToSpace's output can be SpaceToDepth's input
-    '''
-    tmpshape = arr.shape
-    newshape = []
-    newshape.append(tmpshape[0])
-    newshape.append(tmpshape[1]//block_size//block_size)
-    newshape.append(tmpshape[2]*block_size)
-    newshape.append(tmpshape[3]*block_size)
-    output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3])
-    output = np.transpose(output, (0, 1, 4, 2, 5, 3))
-    output = output.reshape(newshape)
-    return output
-
 class SpaceToDepthNet(nn.Cell):
-    def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+    def __init__(self, nptype):
         super(SpaceToDepthNet, self).__init__()
-        self.SpaceToDepth = P.SpaceToDepth(block_size)
-        input_size = 1
-        for i in input_shape:
-            input_size = input_size*i
-        data_np = np.arange(input_size).reshape(input_shape).astype(nptype)# data_np shape is (N,C,H,W)
-        data_np = DepthToSpaceNumpy(data_np, block_size)#now data_np shape is (N,C/(block_size*block_size),H*block_size,W*block_size)
+        self.SpaceToDepth = P.SpaceToDepth(2)
+        data_np = np.array([[[[0, 3],
+                              [6, 9]],
+                             [[1, 4],
+                              [7, 10]],
+                             [[2, 5],
+                              [8, 11]]]]).astype(nptype)
         self.data_np = data_np
-        new_shape = []
-        new_shape.append(input_shape[0])
-        new_shape.append(input_shape[1]//(block_size*block_size))
-        new_shape.append(input_shape[2]*block_size)
-        new_shape.append(input_shape[3]*block_size)
-        self.x = Parameter(initializer(Tensor(self.data_np), new_shape), name='x')
+        self.x = Parameter(initializer(Tensor(self.data_np), (1, 3, 2, 2)), name='x')

     @ms_function
     def construct(self):
         return self.SpaceToDepth(self.x)

-def SpaceToDepth(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+def SpaceToDepth(nptype):
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    input_size = 1
-    for i in input_shape:
-        input_size = input_size*i
-    expect = np.arange(input_size).reshape(input_shape).astype(nptype)
-    std = SpaceToDepthNet(nptype, block_size, input_shape)
+    expect = np.arange(12).reshape((1, 12, 1, 1)).astype(nptype)
+    std = SpaceToDepthNet(nptype)
     output = std()
     assert (output.asnumpy() == expect).all()

-def SpaceToDepth_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
+def SpaceToDepth_pynative(nptype):
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    input_size = 1
-    for i in input_shape:
-        input_size = input_size*i
-    expect = np.arange(input_size).reshape(input_shape).astype(nptype)
-    arrinput = DepthToSpaceNumpy(expect, block_size)
-    std = P.SpaceToDepth(block_size)
-    arrinput = Tensor(arrinput)
-    output = std(arrinput)
+    expect = np.arange(12).reshape((1, 12, 1, 1)).astype(nptype)
+    std = P.SpaceToDepth(2)
+    data_np = np.array([[[[0, 3],
+                          [6, 9]],
+                         [[1, 4],
+                          [7, 10]],
+                         [[2, 5],
+                          [8, 11]]]]).astype(nptype)
+    tensor_input = Tensor(data_np)
+    output = std(tensor_input)
     assert (output.asnumpy() == expect).all()
@@ -208,3 +183,6 @@ def test_spacetodepth_pynative_uint32():
 @pytest.mark.env_onecard
 def test_spacetodepth_pynative_uint64():
     SpaceToDepth_pynative(np.uint64)
+
+test_spacetodepth_graph_float32()
+test_spacetodepth_pynative_int32()