Merge pull request !7106 from VectorSL/gpu-size_t
@@ -86,9 +86,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  int bound_;
-  int outerSize_;
-  int innerSize_;
+  size_t bound_;
+  size_t outerSize_;
+  size_t innerSize_;
 };
 }  // namespace kernel
 }  // namespace mindspore
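Review note: the point of this int → size_t sweep is that shape bookkeeping and index products are computed in 64 bits. A minimal stand-alone sketch (not part of the diff) of the overflow the old int fields invited:

```cpp
// Sketch only: products like bound_ * outerSize_ * innerSize_ can exceed INT_MAX.
#include <cstddef>
#include <cstdio>

int main() {
  size_t outer = 2, bound = 1024, inner = 2048 * 1024;  // a plausible large tensor
  size_t elems = outer * bound * inner;                 // 2^32, exact in 64-bit size_t
  int wrapped = static_cast<int>(elems);                // wraps (to 0 here) when forced into int
  std::printf("size_t: %zu  int: %d\n", elems, wrapped);
  return 0;
}
```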
@@ -180,15 +180,16 @@ class ArrayReduceGpuKernel : public GpuKernel {
     return;
   }
   void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
-    std::vector<int> inputA;
+    std::vector<size_t> inputA;
     std::vector<size_t> outputC_shape = output_shape;
     const int split_dim = 4;
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &inputA);
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
-                                                             inputA[0], inputA[1], inputA[2], inputA[3]),
-                                  "cudnnSetTensor4dDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(inputA[0]),
+                                   SizeToInt(inputA[1]), SizeToInt(inputA[2]), SizeToInt(inputA[3])),
+        "cudnnSetTensor4dDescriptor failed");
     } else {
       CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
       for (auto dim : input_shape) {
@@ -216,7 +217,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
       return;
     }
-    std::vector<int> outputC;
+    std::vector<size_t> outputC;
     if (!keep_dims_) {
       for (auto i : axis_) {
         (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
@@ -225,9 +226,10 @@ class ArrayReduceGpuKernel : public GpuKernel {
     if (outputC_shape.size() <= split_dim) {
       ShapeNdTo4d(outputC_shape, &outputC);
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
-                                                             outputC[0], outputC[1], outputC[2], outputC[3]),
-                                  "cudnnSetTensor4dDescriptor failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(outputC[0]),
+                                   SizeToInt(outputC[1]), SizeToInt(outputC[2]), SizeToInt(outputC[3])),
+        "cudnnSetTensor4dDescriptor failed");
     } else {
       CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
       for (auto dim : outputC_shape) {
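Review note: cuDNN's `cudnnSetTensor4dDescriptor` still takes plain `int` dimensions, which is why the now-`size_t` shape entries are narrowed back with `SizeToInt` at the API boundary only. A hedged sketch of what such a checked narrowing helper looks like (MindSpore's real `SizeToInt` lives in its conversion utils and may differ in detail):

```cpp
// Assumption-level sketch of a checked size_t -> int narrowing, not the real SizeToInt.
#include <climits>
#include <cstddef>
#include <stdexcept>

inline int CheckedSizeToInt(size_t v) {
  if (v > static_cast<size_t>(INT_MAX)) {
    throw std::out_of_range("dimension does not fit in int");  // the real helper raises a MindSpore exception
  }
  return static_cast<int>(v);
}
```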
@@ -71,8 +71,8 @@ class BroadcastToGpuKernel : public GpuKernel {
   }
  private:
-  int input_shape_[4] = {1, 1, 1, 1};
-  int output_shape_[4] = {1, 1, 1, 1};
+  size_t input_shape_[4] = {1, 1, 1, 1};
+  size_t output_shape_[4] = {1, 1, 1, 1};
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -61,7 +61,7 @@ class SliceGpuFwdKernel : public GpuKernel {
       (void)size_.insert(size_.begin(), 1);
     }
-    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
+    input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
     auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
     output_size_ = sizeof(T);
@@ -118,7 +118,7 @@ class SliceGpuFwdKernel : public GpuKernel {
   }
   std::vector<int> begin_;
   std::vector<int> size_;
-  std::vector<int> input_shape_;
+  std::vector<size_t> input_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
@@ -50,7 +50,10 @@ class SliceGradGpuKernel : public GpuKernel {
     auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     if (kernel_name == "StridedSliceGrad") {
       is_strided_slice_ = true;
-      input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex");
+      auto shapex = GetAttr<std::vector<int>>(kernel_node, "shapex");
+      for (auto x : shapex) {
+        input_shape_.push_back(IntToSize(x));
+      }
       for (auto i = input_shape_.size(); i < 4; i++) {
         (void)input_shape_.insert(input_shape_.begin(), 1);
       }
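Review note: the `shapex` attribute still arrives from the IR as `std::vector<int>`, so the kernel now widens it element-wise before any arithmetic. A stand-alone rendering of the same idea (`IntToSize` is assumed to be a checked int → size_t cast; the padding loop mirrors the `Init` above):

```cpp
// Hedged sketch: widen an int shape attribute to size_t and left-pad to 4-D.
#include <cstddef>
#include <vector>

std::vector<size_t> WidenShapeAttr(const std::vector<int> &shapex) {
  std::vector<size_t> input_shape;
  input_shape.reserve(4);
  for (int x : shapex) {
    input_shape.push_back(static_cast<size_t>(x));  // the PR uses IntToSize(x)
  }
  while (input_shape.size() < 4) {
    input_shape.insert(input_shape.begin(), 1);     // pad leading dims with 1
  }
  return input_shape;
}
```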
@@ -69,11 +72,11 @@ class SliceGradGpuKernel : public GpuKernel {
     ShapeNdTo4d(dy_shape, &dy_shape_);
     begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
     DealParam();
-    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
+    input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
     output_size_ = sizeof(T);
     for (auto x : dy_shape_) {
-      output_size_ = output_size_ * IntToSize(x);
+      output_size_ = output_size_ * x;
     }
     InitSizeLists();
     return true;
@@ -125,8 +128,8 @@ class SliceGradGpuKernel : public GpuKernel {
   std::vector<int> begin_;
   std::vector<int> size_;
   std::vector<int> strides_;
-  std::vector<int> input_shape_;
-  std::vector<int> dy_shape_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> dy_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -72,7 +72,7 @@ class StridedSliceGpuKernel : public GpuKernel {
     }
     input_size_list_.push_back(size);
-    int size1 = sizeof(T);
+    size_t size1 = sizeof(T);
     for (size_t i = 0; i < MAX_DIMS; i++) {
       size1 *= output_shape_[i];
     }
@@ -188,7 +188,7 @@ class StridedSliceGpuKernel : public GpuKernel {
   std::vector<int> end_;
   std::vector<int> strides_;
   std::vector<size_t> input_shape_;
-  std::vector<int> output_shape_;
+  std::vector<size_t> output_shape_;
   int null_output_;
   std::vector<size_t> input_size_list_;
@@ -50,7 +50,10 @@ class StridedSliceGradGpuKernel : public GpuKernel {
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
-    input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex");
+    auto shapex = GetAttr<std::vector<int>>(kernel_node, "shapex");
+    for (auto x : shapex) {
+      input_shape_.push_back(IntToSize(x));
+    }
     if (input_shape_.size() > MAX_DIMS) {
       MS_LOG(ERROR) << "StridedSliceGrad support support dims less than " << input_shape_.size();
       return false;
@@ -66,13 +69,13 @@ class StridedSliceGradGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
-    int size = sizeof(T);
+    size_t size = sizeof(T);
     for (size_t i = 0; i < MAX_DIMS; i++) {
       size *= output_shape_[i];
     }
     input_size_list_.push_back(size);
-    int size1 = sizeof(T);
+    size_t size1 = sizeof(T);
     for (size_t i = 0; i < MAX_DIMS; i++) {
       size1 *= input_shape_[i];
     }
@@ -187,8 +190,8 @@ class StridedSliceGradGpuKernel : public GpuKernel {
   std::vector<int> begin_;
   std::vector<int> end_;
   std::vector<int> strides_;
-  std::vector<int> input_shape_;
-  std::vector<int> output_shape_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> output_shape_;
   int null_output_;
   std::vector<size_t> input_size_list_;
@@ -37,17 +37,16 @@ class TransposeGpuFwdKernel : public GpuKernel {
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    int *input_shape = GetDeviceAddress<int>(workspace, 0);
-    int *input_axis = GetDeviceAddress<int>(workspace, 1);
+    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
+    size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
     CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync input_shape failed");
     CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync input_axis failed");
-    int size = SizeToInt(input_size_ / sizeof(T));
-    CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    size_t size = input_size_ / sizeof(T);
+    CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -88,15 +87,15 @@ class TransposeGpuFwdKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
-    workspace_size_ = shape_size_ * sizeof(int);
+    workspace_size_ = shape_size_ * sizeof(size_t);
     workspace_size_list_.push_back(workspace_size_);
     workspace_size_list_.push_back(workspace_size_);
     return;
   }
  private:
-  std::vector<int> input_shape_;
-  std::vector<int> input_axis_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> input_axis_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -18,14 +18,16 @@
 #include "runtime/device/gpu/cuda_common.h"
 #include "include/cuda_fp16.h"
 template <typename T, typename S>
-__global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize, int innerSize, S *index, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize; pos += gridDim.x * blockDim.x) {
-    int x = pos / innerSize % outerSize;
-    int y = pos % innerSize;
+__global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outerSize,
+                                size_t innerSize, S *index, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
+       pos += gridDim.x * blockDim.x) {
+    size_t x = pos / innerSize % outerSize;
+    size_t y = pos % innerSize;
     S idx = 0;
-    int InputOffset = x * bound * innerSize + 0 * innerSize + y;
+    size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
     T maxData = input[InputOffset];
-    for (int i = 0; i < bound; i++) {
+    for (size_t i = 0; i < bound; i++) {
       InputOffset = x * bound * innerSize + i * innerSize + y;
       auto inputData = input[InputOffset];
       idx = inputData > maxData ? i : idx;
@@ -38,14 +40,16 @@ __global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize,
 }
 template <typename T, typename S>
-void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
-                        T *output, cudaStream_t cuda_stream) {
-  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_, index,
-                                                                           output);
+void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_,
+                        S *index, T *output, cudaStream_t cuda_stream) {
+  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
+                                                                           index, output);
   return;
 }
-template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const int outerSize_,
-                                             const int innerSize_, int *index, float *output, cudaStream_t cuda_stream);
-template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const int outerSize_,
-                                            const int innerSize_, int *index, half *output, cudaStream_t cuda_stream);
+template void CalArgmaxWithValue<float, int>(const float *input, const size_t bound_, const size_t outerSize_,
+                                             const size_t innerSize_, int *index, float *output,
+                                             cudaStream_t cuda_stream);
+template void CalArgmaxWithValue<half, int>(const half *input, const size_t bound_, const size_t outerSize_,
+                                            const size_t innerSize_, int *index, half *output,
+                                            cudaStream_t cuda_stream);
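Review note: the kernel treats the input as an `[outerSize, bound, innerSize]` view and reduces over the middle axis; only the index types changed here. A host-side reference of the same indexing, as a reading aid (a sketch, not MindSpore code):

```cpp
// Host reference of ArgmaxWithValue's [outer, bound, inner] indexing scheme.
#include <cstddef>
#include <vector>

void ArgmaxWithValueRef(const std::vector<float> &input, size_t bound, size_t outerSize,
                        size_t innerSize, std::vector<int> *index, std::vector<float> *output) {
  // index and output must each hold outerSize * innerSize elements.
  for (size_t pos = 0; pos < outerSize * innerSize; ++pos) {
    size_t x = pos / innerSize % outerSize;
    size_t y = pos % innerSize;
    int idx = 0;
    float maxData = input[x * bound * innerSize + y];
    for (size_t i = 0; i < bound; ++i) {
      float v = input[x * bound * innerSize + i * innerSize + y];
      if (v > maxData) {
        maxData = v;
        idx = static_cast<int>(i);
      }
    }
    (*index)[pos] = idx;
    (*output)[pos] = maxData;
  }
}
```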
@@ -17,6 +17,6 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
 template <typename T, typename S>
-void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
+void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, S *index,
                         T *output, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
@@ -219,31 +219,33 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *
                            cudaStream_t stream);
 // Broadcast comparation
-__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
+__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
 template <typename T, typename Func>
-__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
-                                   const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
-                                   const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
-                                   const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
+__global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
+                                   const size_t l4, const size_t l5, const size_t l6, const size_t r0,
+                                   const size_t r1, const size_t r2, const size_t r3, const size_t r4,
+                                   const size_t r5, const size_t r6, const size_t d0, const size_t d1,
+                                   const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                                   const size_t d6, const T *x0, const T *x1, bool *y) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
        pos += blockDim.x * gridDim.x) {
-    int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
-    int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
-    int k = pos / (d3 * d4 * d5 * d6) % d2;
-    int l = pos / (d4 * d5 * d6) % d3;
-    int m = pos / (d5 * d6) % d4;
-    int n = pos / d6 % d5;
-    int o = pos % d6;
-    int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    size_t k = pos / (d3 * d4 * d5 * d6) % d2;
+    size_t l = pos / (d4 * d5 * d6) % d3;
+    size_t m = pos / (d5 * d6) % d4;
+    size_t n = pos / d6 % d5;
+    size_t o = pos % d6;
+    size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
     l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
     l_index += Index(k, l2) * l3 * l4 * l5 * l6;
     l_index += Index(l, l3) * l4 * l5 * l6;
     l_index += Index(m, l4) * l5 * l6;
     l_index += Index(n, l5) * l6;
     l_index += Index(o, l6);
-    int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
     r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
     r_index += Index(k, r2) * r3 * r4 * r5 * r6;
     r_index += Index(l, r3) * r4 * r5 * r6;
@@ -255,9 +257,10 @@ __global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, con
 }
 template <typename T>
-void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
-                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
-  int size = 1;
+void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                  const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
+                  const T *x1, bool *y, cudaStream_t stream) {
+  size_t size = 1;
   for (auto d : y_dims) {
     size *= d;
   }
@@ -278,40 +281,42 @@ void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_di
   }
 }
-template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                           const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
                            bool *y, cudaStream_t stream);
-template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                           const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
                            bool *y, cudaStream_t stream);
-template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                           const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
                            bool *y, cudaStream_t stream);
 // Broadcast Arithmetic
 template <typename T, typename Func>
-__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
-                                     const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
-                                     const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
-                                     const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) {
+__global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
+                                     const size_t l4, const size_t l5, const size_t l6, const size_t r0,
+                                     const size_t r1, const size_t r2, const size_t r3, const size_t r4,
+                                     const size_t r5, const size_t r6, const size_t d0, const size_t d1,
+                                     const size_t d2, const size_t d3, const size_t d4, const size_t d5,
+                                     const size_t d6, const T *x0, const T *x1, T *y) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
        pos += blockDim.x * gridDim.x) {
-    int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
-    int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
-    int k = pos / (d3 * d4 * d5 * d6) % d2;
-    int l = pos / (d4 * d5 * d6) % d3;
-    int m = pos / (d5 * d6) % d4;
-    int n = pos / d6 % d5;
-    int o = pos % d6;
-    int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    size_t k = pos / (d3 * d4 * d5 * d6) % d2;
+    size_t l = pos / (d4 * d5 * d6) % d3;
+    size_t m = pos / (d5 * d6) % d4;
+    size_t n = pos / d6 % d5;
+    size_t o = pos % d6;
+    size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
     l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
     l_index += Index(k, l2) * l3 * l4 * l5 * l6;
     l_index += Index(l, l3) * l4 * l5 * l6;
     l_index += Index(m, l4) * l5 * l6;
     l_index += Index(n, l5) * l6;
     l_index += Index(o, l6);
-    int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
     r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
     r_index += Index(k, r2) * r3 * r4 * r5 * r6;
     r_index += Index(l, r3) * r4 * r5 * r6;
@@ -323,9 +328,10 @@ __global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, c
 }
 template <typename T>
-void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
-                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
-  int size = 1;
+void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                    const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
+                    const T *x1, T *y, cudaStream_t stream) {
+  size_t size = 1;
   for (auto d : y_dims) {
     size *= d;
   }
@@ -385,41 +391,44 @@ void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_
   }
 }
-template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                             const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
-                             float *y, cudaStream_t stream);
-template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                             const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0,
+                             const float *x1, float *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
                              half *y, cudaStream_t stream);
-template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
-                             const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
                              int *y, cudaStream_t stream);
 // BroadcastTo
 template <typename T>
-__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
-                                  const int o2, const int o3, const T *input_addr, T *output_addr) {
+__global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0,
+                                  const size_t o1, const size_t o2, const size_t o3, const T *input_addr,
+                                  T *output_addr) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) {
-    int i = pos / (o1 * o2 * o3) % o0;
-    int j = pos / (o2 * o3) % o1;
-    int k = pos / o3 % o2;
-    int l = pos % o3;
+    size_t i = pos / (o1 * o2 * o3) % o0;
+    size_t j = pos / (o2 * o3) % o1;
+    size_t k = pos / o3 % o2;
+    size_t l = pos % o3;
-    int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3);
+    size_t input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3);
     output_addr[pos] = input_addr[input_idx];
   }
 }
 template <typename T>
-void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
-                 const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) {
-  int nums = o0 * o1 * o2 * o3;
+void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                 const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr,
+                 T *output_addr, cudaStream_t stream) {
+  size_t nums = o0 * o1 * o2 * o3;
   BroadcastToKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr,
                                                                   output_addr);
 }
-template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
-                          const int &o2, const int &o3, const float *input_addr, float *output_addr,
-                          cudaStream_t stream);
-template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
-                          const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const float *input_addr,
+                          float *output_addr, cudaStream_t stream);
+template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                          const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
+                          half *output_addr, cudaStream_t stream);
@@ -43,14 +43,17 @@
 void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
 template <typename T>
-void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
-                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
+void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                  const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y,
+                  cudaStream_t stream);
 template <typename T>
-void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
-                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
+void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                    const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
+                    cudaStream_t stream);
 template <typename T>
-void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
-                 const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);
+void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
+                 const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr,
+                 cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
@@ -21,16 +21,17 @@
 #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
 template <typename T>
-__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                        const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
+__global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                        const size_t l1, const size_t l2, const size_t l3, const size_t l4,
+                        const size_t d1, const size_t d2, const size_t d3, const size_t d4,
                         const T *input, T *output) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
-    int i = pos / (l2 * l3 * l4) % l1;
-    int j = pos / (l3 * l4) % l2;
-    int k = pos / l4 % l3;
-    int o = pos % l4;
+    size_t i = pos / (l2 * l3 * l4) % l1;
+    size_t j = pos / (l3 * l4) % l2;
+    size_t k = pos / l4 % l3;
+    size_t o = pos % l4;
-    int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
+    size_t offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
     output[pos] = input[offset];
   }
 }
@@ -56,18 +57,19 @@ void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaSt
   return;
 }
 template <typename T>
-void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3,
-                   const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output,
-                   cudaStream_t stream) {
+void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                   const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                   const size_t d3, const size_t d4, const T *input, T *output, cudaStream_t stream) {
   Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
                                                                      input, output);
 }
 template <typename T>
-void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> in_shape, const std::vector<int> begin,
-                  const std::vector<int> size, T *output, cudaStream_t cuda_stream) {
-  int block = in_shape[1] * in_shape[2] * in_shape[3];
-  int map = in_shape[2] * in_shape[3];
-  int w = in_shape[3];
+void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<size_t> in_shape,
+                  const std::vector<int> begin, const std::vector<int> size, T *output,
+                  cudaStream_t cuda_stream) {
+  size_t block = in_shape[1] * in_shape[2] * in_shape[3];
+  size_t map = in_shape[2] * in_shape[3];
+  size_t w = in_shape[3];
   int length = size[3];
   int p = 0;
   for (int i = begin[0]; i < size[0] + begin[0]; i++) {
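Review note: `begin` and `size` stay `std::vector<int>` here while `block`/`map`/`w` become `size_t`, so expressions like `i * block` mix signed and unsigned. That is safe while the int side is non-negative, as slice begins are; a two-line illustration (sketch, not from the diff):

```cpp
// Sketch: an int loop index promotes to size_t in mixed arithmetic.
#include <cstddef>
#include <cstdio>

int main() {
  size_t block = 512;            // e.g. in_shape[1] * in_shape[2] * in_shape[3]
  int i = 3;                     // from the int-typed begin/size attributes
  size_t offset = i * block;     // i converts to size_t before the multiply
  std::printf("%zu\n", offset);  // prints 1536
  return 0;
}
```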
@@ -82,23 +84,24 @@ void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> i
 }
 template <typename T>
-__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
-                                   const int b5, const int b6, const int s0, const int s1, const int s2,
-                                   const int s3, const int s4, const int s5, const int s6, const int i0,
-                                   const int i1, const int i2, const int i3, const int i4, const int i5,
-                                   const int i6, const int o0, const int o1, const int o2, const int o3,
-                                   const int o4, const int o5, const int o6, const T *input_addr, T *output_addr) {
+__global__ void StridedSliceKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3, const size_t b4,
+                                   const size_t b5, const size_t b6, const size_t s0, const size_t s1, const size_t s2,
+                                   const size_t s3, const size_t s4, const size_t s5, const size_t s6, const size_t i0,
+                                   const size_t i1, const size_t i2, const size_t i3, const size_t i4, const size_t i5,
+                                   const size_t i6, const size_t o0, const size_t o1, const size_t o2, const size_t o3,
+                                   const size_t o4, const size_t o5, const size_t o6,
+                                   const T *input_addr, T *output_addr) {
   int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
-    int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
-    int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
-    int k = pos / (o3 * o4 * o5 * o6) % o2;
-    int l = pos / (o4 * o5 * o6) % o3;
-    int m = pos / (o5 * o6) % o4;
-    int n = pos / (o6) % o5;
-    int o = pos % o6;
-    int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+    size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
+    size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
+    size_t k = pos / (o3 * o4 * o5 * o6) % o2;
+    size_t l = pos / (o4 * o5 * o6) % o3;
+    size_t m = pos / (o5 * o6) % o4;
+    size_t n = pos / (o6) % o5;
+    size_t o = pos % o6;
+    size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
                      + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
                      + (n * s5 + b5) * i6 + (o * s6 + b6);
     output_addr[pos] = input_addr[input_idx];
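Review note: the strided-slice kernels recover 7-D coordinates from a flat position by divide-then-modulo; only the coordinate types changed here (`output_num` is left as `int` in this hunk). The decomposition, written out as a host-side sketch:

```cpp
// Sketch of the 7-D flat-index decomposition used by StridedSliceKernel.
#include <array>
#include <cstddef>

std::array<size_t, 7> DecomposeIndex(size_t pos, const std::array<size_t, 7> &o) {
  return {pos / (o[1] * o[2] * o[3] * o[4] * o[5] * o[6]) % o[0],
          pos / (o[2] * o[3] * o[4] * o[5] * o[6]) % o[1],
          pos / (o[3] * o[4] * o[5] * o[6]) % o[2],
          pos / (o[4] * o[5] * o[6]) % o[3],
          pos / (o[5] * o[6]) % o[4],
          pos / o[6] % o[5],
          pos % o[6]};
}
```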
@@ -107,10 +110,10 @@ __global__ void StridedSliceKernel(const int b0, const int b1, const int b2, con
 template <typename T>
 void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                  const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
+                  const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
                   cudaStream_t cuda_stream) {
-  int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
-           * output_shape[4] * output_shape[5] * output_shape[6];
+  size_t size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
+              * output_shape[4] * output_shape[5] * output_shape[6];
   StridedSliceKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
     begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
     strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
@@ -120,23 +123,25 @@ void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int>
 }
 template <typename T>
-__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
-                                       const int b5, const int b6, const int s0, const int s1, const int s2,
-                                       const int s3, const int s4, const int s5, const int s6, const int i0,
-                                       const int i1, const int i2, const int i3, const int i4, const int i5,
-                                       const int i6, const int o0, const int o1, const int o2, const int o3,
-                                       const int o4, const int o5, const int o6, const T *dy, T *dx) {
+__global__ void StridedSliceGradKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3,
+                                       const size_t b4, const size_t b5, const size_t b6, const size_t s0,
+                                       const size_t s1, const size_t s2, const size_t s3, const size_t s4,
+                                       const size_t s5, const size_t s6, const size_t i0, const size_t i1,
+                                       const size_t i2, const size_t i3, const size_t i4, const size_t i5,
+                                       const size_t i6, const size_t o0, const size_t o1, const size_t o2,
+                                       const size_t o3, const size_t o4, const size_t o5, const size_t o6,
+                                       const T *dy, T *dx) {
   int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
-    int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
-    int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
-    int k = pos / (o3 * o4 * o5 * o6) % o2;
-    int l = pos / (o4 * o5 * o6) % o3;
-    int m = pos / (o5 * o6) % o4;
-    int n = pos / (o6) % o5;
-    int o = pos % o6;
-    int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+    size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
+    size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
+    size_t k = pos / (o3 * o4 * o5 * o6) % o2;
+    size_t l = pos / (o4 * o5 * o6) % o3;
+    size_t m = pos / (o5 * o6) % o4;
+    size_t n = pos / (o6) % o5;
+    size_t o = pos % o6;
+    size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
                      + (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
                      + (n * s5 + b5) * i6 + (o * s6 + b6);
     dx[input_idx] = dy[pos];
@@ -145,9 +150,10 @@ __global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2,
 }
 template <typename T>
-void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
-                      const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) {
-  int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
+void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                      const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
+                      const T *dy, T *dx, cudaStream_t cuda_stream) {
+  size_t size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
   StridedSliceGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
     begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
     strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
@@ -157,88 +163,89 @@ void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &
 }
 template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const float *input, float *output, cudaStream_t stream);
-template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<int> in_shape,
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
+template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
                                   const std::vector<int> begin, const std::vector<int> size, float *output,
                                   cudaStream_t cuda_stream);
 template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const half *input, half *output, cudaStream_t stream);
-template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<int> in_shape,
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
+template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
                                  const std::vector<int> begin, const std::vector<int> size, half *output,
                                  cudaStream_t cuda_stream);
 template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const int *input, int *output, cudaStream_t stream);
-template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<int> in_shape,
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
+template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
                                 const std::vector<int> begin, const std::vector<int> size, int *output,
                                 cudaStream_t cuda_stream);
 template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream);  // NOLINT
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const short *input, short *output, cudaStream_t stream);  // NOLINT
-template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<int> in_shape,  // NOLINT
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const short *input, short *output, cudaStream_t stream);  // NOLINT
+template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<size_t> in_shape,  // NOLINT
                                   const std::vector<int> begin, const std::vector<int> size, short *output,  // NOLINT
                                   cudaStream_t cuda_stream);
 template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned char *addr, const float value,
                                              cudaStream_t cuda_stream);
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const unsigned char *input, unsigned char *output, cudaStream_t stream);
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
+                            cudaStream_t stream);
 template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
-                                          const std::vector<int> in_shape, const std::vector<int> begin,
+                                          const std::vector<size_t> in_shape, const std::vector<int> begin,
                                           const std::vector<int> size, unsigned char *output, cudaStream_t cuda_stream);
 template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
-template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
-                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
-                            const bool *input, bool *output, cudaStream_t stream);
-template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<int> in_shape,
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
+template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
                                  const std::vector<int> begin, const std::vector<int> size, bool *output,
                                  cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape, const float *input,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape, const float *input,
                            float *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape, const half *input,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape, const half *input,
                            half *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape, const int *input,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape, const int *input,
                            int *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape,
                            const short *input, short *output, cudaStream_t cuda_stream);  // NOLINT
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape,
                            const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                           const std::vector<int> &strides, const std::vector<int> &output_shape, const bool *input,
+                           const std::vector<int> &strides, const std::vector<size_t> &output_shape, const bool *input,
                            bool *output, cudaStream_t cuda_stream);
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const float *dy,
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const float *dy,
                                float *dx, cudaStream_t cuda_stream);
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const half *dy,
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const half *dy,
                                half *dx, cudaStream_t cuda_stream);
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const int *dy,
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const int *dy,
                                int *dx, cudaStream_t cuda_stream);
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const short *dy,  // NOLINT
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const short *dy,  // NOLINT
                                short *dx, cudaStream_t cuda_stream);  // NOLINT
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape,
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
                                const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
-template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
-                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const bool *dy,
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                               const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
                                bool *dx, cudaStream_t cuda_stream);
@@ -22,19 +22,20 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3,
-                   const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output,
-                   cudaStream_t stream);
+void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2,
+                   const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
+                   const T *input, T *output, cudaStream_t stream);
 template <typename T>
-void CalSliceGrad(const size_t input_size, const T *input, const std::vector<int> in_shape,
+void CalSliceGrad(const size_t input_size, const T *input, const std::vector<size_t> in_shape,
                   const std::vector<int> begin, const std::vector<int> size, T *output, cudaStream_t cuda_stream);
 template <typename T>
 void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
-                  const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
+                  const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
                   cudaStream_t cuda_stream);
 template <typename T>
-void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
-                      const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream);
+void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
+                      const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const T *dy, T *dx,
+                      cudaStream_t cuda_stream);
 template <typename T>
 void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_
@@ -20,19 +20,19 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis,
-                          const int shape_size, T* output) {
-  int pos_size;
-  int temp_pos;
-  int newpos;
-  int newpos_size;
-  int pos_array[TRANSPOSE_MAX_DIMENSION];
+__global__ void Transpose(const size_t size, const T* input, const size_t* input_shape,
+                          const size_t* input_axis, const size_t shape_size, T* output) {
+  size_t pos_size;
+  size_t temp_pos;
+  size_t newpos;
+  size_t newpos_size;
+  size_t pos_array[TRANSPOSE_MAX_DIMENSION];
   // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] +
   //                        posArray[1] * input_shape[2] * input_shape[3] +
   //                        posArray[2] * input_shape[3] +
   //                        posArray[3]
-  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     temp_pos = pos;
     pos_size = size / input_shape[0];
     pos_array[0] = temp_pos / pos_size;
@@ -54,16 +54,19 @@ __global__ void Transpose(const int size, const T* input, const int* input_shape
   return;
 }
 template <typename T>
-void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size,
-                  T* output, cudaStream_t cuda_stream) {
+void CalTranspose(const size_t size, const T* input, const size_t* input_shape, const size_t* input_axis,
+                  const size_t shape_size, T* output, cudaStream_t cuda_stream) {
   Transpose<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, input_axis, shape_size,
                                                                output);
   return;
 }
-template void CalTranspose<float>(const int size, const float* input, const int* input_shape, const int* input_axis,
-                                  const int shape_size, float* output, cudaStream_t cuda_stream);
-template void CalTranspose<half>(const int size, const half* input, const int* input_shape, const int* input_axis,
-                                 const int shape_size, half* output, cudaStream_t cuda_stream);
-template void CalTranspose<int>(const int size, const int* input, const int* input_shape, const int* input_axis,
-                                const int shape_size, int* output, cudaStream_t cuda_stream);
+template void CalTranspose<float>(const size_t size, const float* input, const size_t* input_shape,
+                                  const size_t* input_axis, const size_t shape_size, float* output,
+                                  cudaStream_t cuda_stream);
+template void CalTranspose<half>(const size_t size, const half* input, const size_t* input_shape,
+                                 const size_t* input_axis, const size_t shape_size, half* output,
+                                 cudaStream_t cuda_stream);
+template void CalTranspose<int>(const size_t size, const int* input, const size_t* input_shape,
+                                const size_t* input_axis, const size_t shape_size, int* output,
+                                cudaStream_t cuda_stream);
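Review note: CalTranspose now carries shape and axis as size_t end to end, matching the resized workspace copies in the kernel class above. For readers, a hedged host-side sketch of the permutation index math a transpose of this kind implements; this is not the kernel's literal code:

```cpp
// Sketch: map a flat input index to its transposed flat index for a given axis order.
#include <cstddef>
#include <vector>

size_t PermuteIndex(size_t pos, const std::vector<size_t> &in_shape, const std::vector<size_t> &axis) {
  size_t n = in_shape.size();
  std::vector<size_t> coord(n);
  for (size_t d = n; d-- > 0;) {  // recover input coordinates, last dim fastest
    coord[d] = pos % in_shape[d];
    pos /= in_shape[d];
  }
  size_t out_pos = 0;
  for (size_t d = 0; d < n; ++d) {
    out_pos = out_pos * in_shape[axis[d]] + coord[axis[d]];  // output dim d = input dim axis[d]
  }
  return out_pos;
}
```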
@@ -19,7 +19,7 @@
 #define TRANSPOSE_MAX_DIMENSION 100
 template <typename T>
-void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size,
-                  T* output, cudaStream_t cuda_stream);
+void CalTranspose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
+                  const size_t shape_size, T *output, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_H_
@@ -87,14 +87,14 @@ class GpuKernel : public KernelMod {
     return GetValue<T>(attr);
   }
   // expand Nd Shape to 4d (N in [0,4])
-  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
+  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<size_t> *dst) {
     if (src.size() > 4) {
       MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!";
     }
-    dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
-    dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
-    dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
-    dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
+    dst->push_back(src.size() < 4 ? 1 : src[src.size() - 4]);
+    dst->push_back(src.size() < 3 ? 1 : src[src.size() - 3]);
+    dst->push_back(src.size() < 2 ? 1 : src[src.size() - 2]);
+    dst->push_back(src.size() == 0 ? 1 : src[src.size() - 1]);
   }
   int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis) {
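Review note: with the narrowing gone, ShapeNdTo4d is now a pure left-pad to 4-D. A stand-alone free-function rendering of the same logic, for reference (the member function itself is in the hunk above):

```cpp
// Free-function rendering of the updated ShapeNdTo4d: pad an N-D shape (N <= 4) to NCHW.
#include <cstddef>
#include <vector>

std::vector<size_t> ShapeNdTo4d(const std::vector<size_t> &src) {
  std::vector<size_t> dst;
  dst.push_back(src.size() < 4 ? 1 : src[src.size() - 4]);
  dst.push_back(src.size() < 3 ? 1 : src[src.size() - 3]);
  dst.push_back(src.size() < 2 ? 1 : src[src.size() - 2]);
  dst.push_back(src.empty() ? 1 : src[src.size() - 1]);
  return dst;  // e.g. {5, 6} -> {1, 1, 5, 6}
}
```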
@@ -162,9 +162,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
   int input1_num_;
   int input2_num_;
   int output_num_;
-  std::vector<int> lhs_shape_;
-  std::vector<int> rhs_shape_;
-  std::vector<int> output_shape_;
+  std::vector<size_t> lhs_shape_;
+  std::vector<size_t> rhs_shape_;
+  std::vector<size_t> output_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
| @@ -140,12 +140,12 @@ class BroadcastOpGradGpuKernel : public GpuKernel { | |||
| BroadcastGradOpType op_type_; | |||
| bool need_broadcast_; | |||
| int input1_num_; | |||
| int input2_num_; | |||
| int output_num_; | |||
| int x1_shape_[4] = {1, 1, 1, 1}; | |||
| int x2_shape_[4] = {1, 1, 1, 1}; | |||
| int dy_shape_[4] = {1, 1, 1, 1}; | |||
| size_t input1_num_; | |||
| size_t input2_num_; | |||
| size_t output_num_; | |||
| size_t x1_shape_[4] = {1, 1, 1, 1}; | |||
| size_t x2_shape_[4] = {1, 1, 1, 1}; | |||
| size_t dy_shape_[4] = {1, 1, 1, 1}; | |||
| bool grad_x_; | |||
| bool grad_y_; | |||
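The element-count members move to size_t along with the shape arrays. A hedged sketch of deriving such a count from a size_t shape — the std::accumulate form is illustrative, not necessarily what the kernel's Init code does:

    #include <functional>
    #include <iterator>
    #include <numeric>

    size_t x1_shape[4] = {1, 2, 3, 4};  // example shape
    // Accumulating in size_t avoids the narrowing an int count would force
    // on tensors with more than INT_MAX elements.
    size_t input1_num = std::accumulate(std::begin(x1_shape), std::end(x1_shape),
                                        static_cast<size_t>(1),
                                        std::multiplies<size_t>());
    // input1_num == 24, computed entirely in unsigned arithmetic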
| @@ -82,7 +82,7 @@ class ActivationGpuFwdKernel : public GpuKernel { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<int> shape; | |||
| std::vector<size_t> shape; | |||
| double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef), | |||
| "cudnnSetActivationDescriptor failed"); | |||
| @@ -91,13 +91,15 @@ class ActivationGpuFwdKernel : public GpuKernel { | |||
| if (input_shape.size() <= split_dim) { | |||
| ShapeNdTo4d(input_shape, &shape); | |||
| if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, | |||
| shape[0], shape[3], shape[1], shape[2]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]), | |||
| SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | |||
| shape[0], shape[1], shape[2], shape[3]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]), | |||
| SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } | |||
| } else { | |||
| CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_); | |||
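cudnnSetTensor4dDescriptor itself still takes int dimensions, so the size_t values are narrowed explicitly at the cuDNN boundary with SizeToInt rather than through implicit conversion; for NHWC input the dimensions are also reordered into cuDNN's N, C, H, W argument order. The pattern in isolation (`desc` assumed to be an already-created cudnnTensorDescriptor_t, error handling elided):

    std::vector<size_t> shape = {32, 64, 7, 7};  // NCHW shape from the framework
    // Explicit narrowing at the int-based cuDNN API boundary.
    cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                               SizeToInt(shape[0]), SizeToInt(shape[1]),
                               SizeToInt(shape[2]), SizeToInt(shape[3]));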
| @@ -89,7 +89,7 @@ class ActivationGradGpuKernel : public GpuKernel { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<int> shape; | |||
| std::vector<size_t> shape; | |||
| double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef), | |||
| "SetActivationDescriptor failed"); | |||
| @@ -98,13 +98,15 @@ class ActivationGradGpuKernel : public GpuKernel { | |||
| if (input_shape.size() <= split_dim) { | |||
| ShapeNdTo4d(input_shape, &shape); | |||
| if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, | |||
| shape[0], shape[3], shape[1], shape[2]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]), | |||
| SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | |||
| shape[0], shape[1], shape[2], shape[3]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]), | |||
| SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| } | |||
| } else { | |||
| CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_); | |||
| @@ -68,9 +68,9 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| } else { | |||
| T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0); | |||
| T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1); | |||
| int *input_shape = GetDeviceAddress<int>(workspace, 2); | |||
| int *transpose_shape = GetDeviceAddress<int>(workspace, 3); | |||
| int *transpose_axis = GetDeviceAddress<int>(workspace, 4); | |||
| size_t *input_shape = GetDeviceAddress<size_t>(workspace, 2); | |||
| size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 3); | |||
| size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 4); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape failed"); | |||
| @@ -80,7 +80,7 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_axis failed"); | |||
| int size = SizeToInt(input_size_ / sizeof(T)); | |||
| size_t size = input_size_ / sizeof(T); | |||
| CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| @@ -113,7 +113,7 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| shape_size_ = SizeToInt(input_shape.size()); | |||
| shape_size_ = input_shape.size(); | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == "LogSoftmax") { | |||
| algo_ = CUDNN_SOFTMAX_LOG; | |||
| @@ -171,7 +171,7 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| void InitSizeByAxis2D(const std::vector<size_t> &input_shape, const int &axis) { | |||
| axis_ = axis; | |||
| if (axis_ < 0) { | |||
| axis_ += shape_size_; | |||
| axis_ += SizeToInt(shape_size_); | |||
| } | |||
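axis_ stays int because negative axes are legal user input, while shape_size_ is now size_t; converting explicitly keeps the normalization in signed arithmetic instead of leaning on implicit signed/unsigned promotion. With assumed values:

    int axis = -1;                  // "last dimension"
    size_t shape_size = 4;          // rank of the input
    axis += SizeToInt(shape_size);  // axis == 3, and -1 is never promoted
                                    // to a huge unsigned value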
| if (axis_ == 1) { | |||
| batch_size_ = input_shape[0]; | |||
| @@ -193,7 +193,7 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| width_ = 1; | |||
| input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; | |||
| output_size_ = input_size_; | |||
| workspace_size_ = IntToSize(shape_size_) * sizeof(int); | |||
| workspace_size_ = shape_size_ * sizeof(size_t); | |||
| } | |||
| void InitSizeByAxisLastDim(const std::vector<size_t> &input_shape, const int &axis) { | |||
| @@ -235,11 +235,11 @@ class SoftmaxGpuKernel : public GpuKernel { | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::vector<int> input_shape_; | |||
| std::vector<int> transpose_shape_; | |||
| std::vector<int> transpose_axis_; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> transpose_shape_; | |||
| std::vector<size_t> transpose_axis_; | |||
| int axis_; | |||
| int shape_size_; | |||
| size_t shape_size_; | |||
| size_t batch_size_; | |||
| size_t channel_size_; | |||
| @@ -62,9 +62,9 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0); | |||
| T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1); | |||
| T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2); | |||
| int *input_shape = GetDeviceAddress<int>(workspace, 3); | |||
| int *transpose_shape = GetDeviceAddress<int>(workspace, 4); | |||
| int *transpose_axis = GetDeviceAddress<int>(workspace, 5); | |||
| size_t *input_shape = GetDeviceAddress<size_t>(workspace, 3); | |||
| size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 4); | |||
| size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 5); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| @@ -82,7 +82,7 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_axis failed"); | |||
| int size = SizeToInt(input_size_ / sizeof(T)); | |||
| size_t size = input_size_ / sizeof(T); | |||
| CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr, | |||
| @@ -116,7 +116,7 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| shape_size_ = SizeToInt(input_shape.size()); | |||
| shape_size_ = input_shape.size(); | |||
| if (shape_size_ != 2) { | |||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs."; | |||
| } | |||
| @@ -164,7 +164,7 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) { | |||
| axis_ = axis; | |||
| if (axis_ < 0) { | |||
| axis_ += shape_size_; | |||
| axis_ += SizeToInt(shape_size_); | |||
| } | |||
| if (axis_ == 1) { | |||
| batch_size_ = input_shape[0]; | |||
| @@ -186,7 +186,7 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| width_ = 1; | |||
| input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; | |||
| output_size_ = input_size_; | |||
| workspace_size_ = IntToSize(shape_size_) * sizeof(int); | |||
| workspace_size_ = shape_size_ * sizeof(size_t); | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| @@ -202,11 +202,11 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::vector<int> input_shape_; | |||
| std::vector<int> transpose_shape_; | |||
| std::vector<int> transpose_axis_; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> transpose_shape_; | |||
| std::vector<size_t> transpose_axis_; | |||
| int axis_; | |||
| int shape_size_; | |||
| size_t shape_size_; | |||
| size_t batch_size_; | |||
| size_t channel_size_; | |||