Browse Source

!7106 GPU change kernel shape to size_t

Merge pull request !7106 from VectorSL/gpu-size_t
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
483f1aca9d
23 changed files with 325 additions and 287 deletions
  1. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
  2. +10
    -8
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h
  4. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h
  5. +8
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h
  6. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
  7. +8
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h
  8. +7
    -8
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
  9. +18
    -14
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
  10. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh
  11. +72
    -63
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
  12. +9
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
  13. +100
    -93
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
  14. +8
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh
  15. +19
    -16
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu
  16. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh
  17. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  18. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  19. +6
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h
  20. +9
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
  21. +9
    -7
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
  22. +11
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
  23. +11
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h

+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h View File

@@ -86,9 +86,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int bound_;
int outerSize_;
int innerSize_;
size_t bound_;
size_t outerSize_;
size_t innerSize_;
};
} // namespace kernel
} // namespace mindspore


+ 10
- 8
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h View File

@@ -180,15 +180,16 @@ class ArrayReduceGpuKernel : public GpuKernel {
return;
}
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
std::vector<int> inputA;
std::vector<size_t> inputA;
std::vector<size_t> outputC_shape = output_shape;
const int split_dim = 4;

if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
inputA[0], inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(inputA[0]),
SizeToInt(inputA[1]), SizeToInt(inputA[2]), SizeToInt(inputA[3])),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
for (auto dim : input_shape) {
@@ -216,7 +217,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
return;
}

std::vector<int> outputC;
std::vector<size_t> outputC;
if (!keep_dims_) {
for (auto i : axis_) {
(void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
@@ -225,9 +226,10 @@ class ArrayReduceGpuKernel : public GpuKernel {

if (outputC_shape.size() <= split_dim) {
ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, SizeToInt(outputC[0]),
SizeToInt(outputC[1]), SizeToInt(outputC[2]), SizeToInt(outputC[3])),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
for (auto dim : outputC_shape) {


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h View File

@@ -71,8 +71,8 @@ class BroadcastToGpuKernel : public GpuKernel {
}

private:
int input_shape_[4] = {1, 1, 1, 1};
int output_shape_[4] = {1, 1, 1, 1};
size_t input_shape_[4] = {1, 1, 1, 1};
size_t output_shape_[4] = {1, 1, 1, 1};

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h View File

@@ -61,7 +61,7 @@ class SliceGpuFwdKernel : public GpuKernel {
(void)size_.insert(size_.begin(), 1);
}

input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);
auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);

output_size_ = sizeof(T);
@@ -118,7 +118,7 @@ class SliceGpuFwdKernel : public GpuKernel {
}
std::vector<int> begin_;
std::vector<int> size_;
std::vector<int> input_shape_;
std::vector<size_t> input_shape_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


+ 8
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h View File

@@ -50,7 +50,10 @@ class SliceGradGpuKernel : public GpuKernel {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "StridedSliceGrad") {
is_strided_slice_ = true;
input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex");
auto shapex = GetAttr<std::vector<int>>(kernel_node, "shapex");
for (auto x : shapex) {
input_shape_.push_back(IntToSize(x));
}
for (auto i = input_shape_.size(); i < 4; i++) {
(void)input_shape_.insert(input_shape_.begin(), 1);
}
@@ -69,11 +72,11 @@ class SliceGradGpuKernel : public GpuKernel {
ShapeNdTo4d(dy_shape, &dy_shape_);
begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
DealParam();
input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T);

output_size_ = sizeof(T);
for (auto x : dy_shape_) {
output_size_ = output_size_ * IntToSize(x);
output_size_ = output_size_ * x;
}
InitSizeLists();
return true;
@@ -125,8 +128,8 @@ class SliceGradGpuKernel : public GpuKernel {
std::vector<int> begin_;
std::vector<int> size_;
std::vector<int> strides_;
std::vector<int> input_shape_;
std::vector<int> dy_shape_;
std::vector<size_t> input_shape_;
std::vector<size_t> dy_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h View File

@@ -72,7 +72,7 @@ class StridedSliceGpuKernel : public GpuKernel {
}
input_size_list_.push_back(size);

int size1 = sizeof(T);
size_t size1 = sizeof(T);
for (size_t i = 0; i < MAX_DIMS; i++) {
size1 *= output_shape_[i];
}
@@ -188,7 +188,7 @@ class StridedSliceGpuKernel : public GpuKernel {
std::vector<int> end_;
std::vector<int> strides_;
std::vector<size_t> input_shape_;
std::vector<int> output_shape_;
std::vector<size_t> output_shape_;
int null_output_;

std::vector<size_t> input_size_list_;


+ 8
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h View File

@@ -50,7 +50,10 @@ class StridedSliceGradGpuKernel : public GpuKernel {
return true;
}
bool Init(const CNodePtr &kernel_node) override {
input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex");
auto shapex = GetAttr<std::vector<int>>(kernel_node, "shapex");
for (auto x : shapex) {
input_shape_.push_back(IntToSize(x));
}
if (input_shape_.size() > MAX_DIMS) {
MS_LOG(ERROR) << "StridedSliceGrad support support dims less than " << input_shape_.size();
return false;
@@ -66,13 +69,13 @@ class StridedSliceGradGpuKernel : public GpuKernel {

protected:
void InitSizeLists() override {
int size = sizeof(T);
size_t size = sizeof(T);
for (size_t i = 0; i < MAX_DIMS; i++) {
size *= output_shape_[i];
}
input_size_list_.push_back(size);

int size1 = sizeof(T);
size_t size1 = sizeof(T);
for (size_t i = 0; i < MAX_DIMS; i++) {
size1 *= input_shape_[i];
}
@@ -187,8 +190,8 @@ class StridedSliceGradGpuKernel : public GpuKernel {
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
int null_output_;

std::vector<size_t> input_size_list_;


+ 7
- 8
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h View File

@@ -37,17 +37,16 @@ class TransposeGpuFwdKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int *input_shape = GetDeviceAddress<int>(workspace, 0);
int *input_axis = GetDeviceAddress<int>(workspace, 1);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
int size = SizeToInt(input_size_ / sizeof(T));
CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output,
reinterpret_cast<cudaStream_t>(stream_ptr));
size_t size = input_size_ / sizeof(T);
CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -88,15 +87,15 @@ class TransposeGpuFwdKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_ = shape_size_ * sizeof(int);
workspace_size_ = shape_size_ * sizeof(size_t);
workspace_size_list_.push_back(workspace_size_);
workspace_size_list_.push_back(workspace_size_);
return;
}
private:
std::vector<int> input_shape_;
std::vector<int> input_axis_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_axis_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;


+ 18
- 14
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu View File

@@ -18,14 +18,16 @@
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T, typename S>
__global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize, int innerSize, S *index, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize; pos += gridDim.x * blockDim.x) {
int x = pos / innerSize % outerSize;
int y = pos % innerSize;
__global__ void ArgmaxWithValue(const T *input, const size_t bound, size_t outerSize,
size_t innerSize, S *index, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
pos += gridDim.x * blockDim.x) {
size_t x = pos / innerSize % outerSize;
size_t y = pos % innerSize;
S idx = 0;
int InputOffset = x * bound * innerSize + 0 * innerSize + y;
size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
T maxData = input[InputOffset];
for (int i = 0; i < bound; i++) {
for (size_t i = 0; i < bound; i++) {
InputOffset = x * bound * innerSize + i * innerSize + y;
auto inputData = input[InputOffset];
idx = inputData > maxData ? i : idx;
@@ -38,14 +40,16 @@ __global__ void ArgmaxWithValue(const T *input, const int bound, int outerSize,
}

template <typename T, typename S>
void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
T *output, cudaStream_t cuda_stream) {
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_, index,
output);
void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_,
S *index, T *output, cudaStream_t cuda_stream) {
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
index, output);
return;
}

template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const int outerSize_,
const int innerSize_, int *index, float *output, cudaStream_t cuda_stream);
template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const int outerSize_,
const int innerSize_, int *index, half *output, cudaStream_t cuda_stream);
template void CalArgmaxWithValue<float, int>(const float *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, int *index, float *output,
cudaStream_t cuda_stream);
template void CalArgmaxWithValue<half, int>(const half *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, int *index, half *output,
cudaStream_t cuda_stream);

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh View File

@@ -17,6 +17,6 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
template <typename T, typename S>
void CalArgmaxWithValue(const T *input, const int bound_, const int outerSize_, const int innerSize_, S *index,
void CalArgmaxWithValue(const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, S *index,
T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_

+ 72
- 63
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu View File

@@ -219,31 +219,33 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *
cudaStream_t stream);

// Broadcast comparation
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }

template <typename T, typename Func>
__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
__global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *x0, const T *x1, bool *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
size_t k = pos / (d3 * d4 * d5 * d6) % d2;
size_t l = pos / (d4 * d5 * d6) % d3;
size_t m = pos / (d5 * d6) % d4;
size_t n = pos / d6 % d5;
size_t o = pos % d6;
size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
@@ -255,9 +257,10 @@ __global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, con
}

template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
int size = 1;
void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
const T *x1, bool *y, cudaStream_t stream) {
size_t size = 1;
for (auto d : y_dims) {
size *= d;
}
@@ -278,40 +281,42 @@ void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_di
}
}

template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
bool *y, cudaStream_t stream);

// Broadcast Arithmetic
template <typename T, typename Func>
__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) {
__global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *x0, const T *x1, T *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
size_t k = pos / (d3 * d4 * d5 * d6) % d2;
size_t l = pos / (d4 * d5 * d6) % d3;
size_t m = pos / (d5 * d6) % d4;
size_t n = pos / d6 % d5;
size_t o = pos % d6;
size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
@@ -323,9 +328,10 @@ __global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, c
}

template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
int size = 1;
void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
const T *x1, T *y, cudaStream_t stream) {
size_t size = 1;
for (auto d : y_dims) {
size *= d;
}
@@ -385,41 +391,44 @@ void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_
}
}

template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
float *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0,
const float *x1, float *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
half *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
int *y, cudaStream_t stream);

// BroadcastTo
template <typename T>
__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
const int o2, const int o3, const T *input_addr, T *output_addr) {
__global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t i2, const size_t i3, const size_t o0,
const size_t o1, const size_t o2, const size_t o3, const T *input_addr,
T *output_addr) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) {
int i = pos / (o1 * o2 * o3) % o0;
int j = pos / (o2 * o3) % o1;
int k = pos / o3 % o2;
int l = pos % o3;
size_t i = pos / (o1 * o2 * o3) % o0;
size_t j = pos / (o2 * o3) % o1;
size_t k = pos / o3 % o2;
size_t l = pos % o3;

int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3);
size_t input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3);
output_addr[pos] = input_addr[input_idx];
}
}

template <typename T>
void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) {
int nums = o0 * o1 * o2 * o3;
void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr,
T *output_addr, cudaStream_t stream) {
size_t nums = o0 * o1 * o2 * o3;
BroadcastToKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr,
output_addr);
}

template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const float *input_addr, float *output_addr,
cudaStream_t stream);
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream);
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const float *input_addr,
float *output_addr, cudaStream_t stream);
template void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const half *input_addr,
half *output_addr, cudaStream_t stream);

+ 9
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh View File

@@ -43,14 +43,17 @@ template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);

template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y,
cudaStream_t stream);

template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
cudaStream_t stream);

template <typename T>
void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);
void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_

+ 100
- 93
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu View File

@@ -21,16 +21,17 @@
#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"

template <typename T>
__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
__global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *input, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
int i = pos / (l2 * l3 * l4) % l1;
int j = pos / (l3 * l4) % l2;
int k = pos / l4 % l3;
int o = pos % l4;
size_t i = pos / (l2 * l3 * l4) % l1;
size_t j = pos / (l3 * l4) % l2;
size_t k = pos / l4 % l3;
size_t o = pos % l4;

int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
size_t offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4);
output[pos] = input[offset];
}
}
@@ -56,18 +57,19 @@ void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaSt
return;
}
template <typename T>
void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3,
const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output,
cudaStream_t stream) {
void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const T *input, T *output, cudaStream_t stream) {
Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4,
input, output);
}
template <typename T>
void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> in_shape, const std::vector<int> begin,
const std::vector<int> size, T *output, cudaStream_t cuda_stream) {
int block = in_shape[1] * in_shape[2] * in_shape[3];
int map = in_shape[2] * in_shape[3];
int w = in_shape[3];
void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, T *output,
cudaStream_t cuda_stream) {
size_t block = in_shape[1] * in_shape[2] * in_shape[3];
size_t map = in_shape[2] * in_shape[3];
size_t w = in_shape[3];
int length = size[3];
int p = 0;
for (int i = begin[0]; i < size[0] + begin[0]; i++) {
@@ -82,23 +84,24 @@ void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> i
}

template <typename T>
__global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
const int b5, const int b6, const int s0, const int s1, const int s2,
const int s3, const int s4, const int s5, const int s6, const int i0,
const int i1, const int i2, const int i3, const int i4, const int i5,
const int i6, const int o0, const int o1, const int o2, const int o3,
const int o4, const int o5, const int o6, const T *input_addr, T *output_addr) {
__global__ void StridedSliceKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3, const size_t b4,
const size_t b5, const size_t b6, const size_t s0, const size_t s1, const size_t s2,
const size_t s3, const size_t s4, const size_t s5, const size_t s6, const size_t i0,
const size_t i1, const size_t i2, const size_t i3, const size_t i4, const size_t i5,
const size_t i6, const size_t o0, const size_t o1, const size_t o2, const size_t o3,
const size_t o4, const size_t o5, const size_t o6,
const T *input_addr, T *output_addr) {
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
int k = pos / (o3 * o4 * o5 * o6) % o2;
int l = pos / (o4 * o5 * o6) % o3;
int m = pos / (o5 * o6) % o4;
int n = pos / (o6) % o5;
int o = pos % o6;
int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
size_t k = pos / (o3 * o4 * o5 * o6) % o2;
size_t l = pos / (o4 * o5 * o6) % o3;
size_t m = pos / (o5 * o6) % o4;
size_t n = pos / (o6) % o5;
size_t o = pos % o6;
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
+ (n * s5 + b5) * i6 + (o * s6 + b6);
output_addr[pos] = input_addr[input_idx];
@@ -107,10 +110,10 @@ __global__ void StridedSliceKernel(const int b0, const int b1, const int b2, con

template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
cudaStream_t cuda_stream) {
int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
* output_shape[4] * output_shape[5] * output_shape[6];
size_t size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] \
* output_shape[4] * output_shape[5] * output_shape[6];
StridedSliceKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
@@ -120,23 +123,25 @@ void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int>
}

template <typename T>
__global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int b4,
const int b5, const int b6, const int s0, const int s1, const int s2,
const int s3, const int s4, const int s5, const int s6, const int i0,
const int i1, const int i2, const int i3, const int i4, const int i5,
const int i6, const int o0, const int o1, const int o2, const int o3,
const int o4, const int o5, const int o6, const T *dy, T *dx) {
__global__ void StridedSliceGradKernel(const size_t b0, const size_t b1, const size_t b2, const size_t b3,
const size_t b4, const size_t b5, const size_t b6, const size_t s0,
const size_t s1, const size_t s2, const size_t s3, const size_t s4,
const size_t s5, const size_t s6, const size_t i0, const size_t i1,
const size_t i2, const size_t i3, const size_t i4, const size_t i5,
const size_t i6, const size_t o0, const size_t o1, const size_t o2,
const size_t o3, const size_t o4, const size_t o5, const size_t o6,
const T *dy, T *dx) {
int output_num = o0 * o1 * o2 * o3 * o4 * o5 * o6;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
int i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
int j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
int k = pos / (o3 * o4 * o5 * o6) % o2;
int l = pos / (o4 * o5 * o6) % o3;
int m = pos / (o5 * o6) % o4;
int n = pos / (o6) % o5;
int o = pos % o6;
int input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6) % o0;
size_t j = pos / (o2 * o3 * o4 * o5 * o6) % o1;
size_t k = pos / (o3 * o4 * o5 * o6) % o2;
size_t l = pos / (o4 * o5 * o6) % o3;
size_t m = pos / (o5 * o6) % o4;
size_t n = pos / (o6) % o5;
size_t o = pos % o6;
size_t input_idx = (i * s0 + b0) * i1 * i2 * i3 * i4 * i5 * i6 + (j * s1 + b1) * i2 * i3 * i4 * i5 * i6 \
+ (k * s2 + b2) * i3 * i4 * i5 * i6 + (l * s3 + b3) * i4 * i5 * i6 + (m * s4 + b4) * i5 * i6 \
+ (n * s5 + b5) * i6 + (o * s6 + b6);
dx[input_idx] = dy[pos];
@@ -145,9 +150,10 @@ __global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2,
}

template <typename T>
void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) {
int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
const T *dy, T *dx, cudaStream_t cuda_stream) {
size_t size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3] * dy_shape[4] * dy_shape[5] * dy_shape[6];
StridedSliceGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6],
strides[0], strides[1], strides[2], strides[3], strides[4], strides[5], strides[6],
@@ -157,88 +163,89 @@ void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &
}

template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<int> in_shape,
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, float *output,
cudaStream_t cuda_stream);

template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<int> in_shape,
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, half *output,
cudaStream_t cuda_stream);

template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<int> in_shape,
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, int *output,
cudaStream_t cuda_stream);

template void FillDeviceArray<short>(const size_t input_size, short *addr, const float value, cudaStream_t cuda_stream); // NOLINT
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const short *input, short *output, cudaStream_t stream); // NOLINT
template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<int> in_shape, // NOLINT
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const short *input, short *output, cudaStream_t stream); // NOLINT
template void CalSliceGrad<short>(const size_t input_size, const short *dy, const std::vector<size_t> in_shape, // NOLINT
const std::vector<int> begin, const std::vector<int> size, short *output, // NOLINT
cudaStream_t cuda_stream);

template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned char *addr, const float value,
cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const unsigned char *input, unsigned char *output, cudaStream_t stream);
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
cudaStream_t stream);
template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
const std::vector<int> in_shape, const std::vector<int> begin,
const std::vector<size_t> in_shape, const std::vector<int> begin,
const std::vector<int> size, unsigned char *output, cudaStream_t cuda_stream);

template void FillDeviceArray<bool>(const size_t input_size, bool *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
const bool *input, bool *output, cudaStream_t stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<int> in_shape,
template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, bool *output,
cudaStream_t cuda_stream);

template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const float *input,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const float *input,
float *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const half *input,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const half *input,
half *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const int *input,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const int *input,
int *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape,
const std::vector<int> &strides, const std::vector<size_t> &output_shape,
const short *input, short *output, cudaStream_t cuda_stream); // NOLINT
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape,
const std::vector<int> &strides, const std::vector<size_t> &output_shape,
const unsigned char *input, unsigned char *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const bool *input,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const bool *input,
bool *output, cudaStream_t cuda_stream);

template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const float *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const float *dy,
float *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const half *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const half *dy,
half *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const int *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const int *dy,
int *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const short *dy, // NOLINT
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const short *dy, // NOLINT
short *dx, cudaStream_t cuda_stream); // NOLINT
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape,
const unsigned char *dy, unsigned char *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &dx_shape, const bool *dy,
template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
bool *dx, cudaStream_t cuda_stream);

+ 8
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh View File

@@ -22,19 +22,20 @@
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3,
const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output,
cudaStream_t stream);
void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2,
const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4,
const T *input, T *output, cudaStream_t stream);
template <typename T>
void CalSliceGrad(const size_t input_size, const T *input, const std::vector<int> in_shape,
void CalSliceGrad(const size_t input_size, const T *input, const std::vector<size_t> in_shape,
const std::vector<int> begin, const std::vector<int> size, T *output, cudaStream_t cuda_stream);
template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
const std::vector<int> &strides, const std::vector<size_t> &output_shape, const T *input, T *output,
cudaStream_t cuda_stream);
template <typename T>
void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream);
void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int> &begin,
const std::vector<int> &strides, const std::vector<size_t> &dx_shape, const T *dy, T *dx,
cudaStream_t cuda_stream);
template <typename T>
void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_

+ 19
- 16
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu View File

@@ -20,19 +20,19 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis,
const int shape_size, T* output) {
int pos_size;
int temp_pos;
int newpos;
int newpos_size;
int pos_array[TRANSPOSE_MAX_DIMENSION];
__global__ void Transpose(const size_t size, const T* input, const size_t* input_shape,
const size_t* input_axis, const size_t shape_size, T* output) {
size_t pos_size;
size_t temp_pos;
size_t newpos;
size_t newpos_size;
size_t pos_array[TRANSPOSE_MAX_DIMENSION];
// for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] +
// posArray[1] * input_shape[2] * input_shape[3] +
// posArray[2] * input_shape[3] +
// posArray[3]
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
temp_pos = pos;
pos_size = size / input_shape[0];
pos_array[0] = temp_pos / pos_size;
@@ -54,16 +54,19 @@ __global__ void Transpose(const int size, const T* input, const int* input_shape
return;
}
template <typename T>
void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size,
T* output, cudaStream_t cuda_stream) {
void CalTranspose(const size_t size, const T* input, const size_t* input_shape, const size_t* input_axis,
const size_t shape_size, T* output, cudaStream_t cuda_stream) {
Transpose<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, input_axis, shape_size,
output);
return;
}
template void CalTranspose<float>(const int size, const float* input, const int* input_shape, const int* input_axis,
const int shape_size, float* output, cudaStream_t cuda_stream);
template void CalTranspose<half>(const int size, const half* input, const int* input_shape, const int* input_axis,
const int shape_size, half* output, cudaStream_t cuda_stream);
template void CalTranspose<int>(const int size, const int* input, const int* input_shape, const int* input_axis,
const int shape_size, int* output, cudaStream_t cuda_stream);
template void CalTranspose<float>(const size_t size, const float* input, const size_t* input_shape,
const size_t* input_axis, const size_t shape_size, float* output,
cudaStream_t cuda_stream);
template void CalTranspose<half>(const size_t size, const half* input, const size_t* input_shape,
const size_t* input_axis, const size_t shape_size, half* output,
cudaStream_t cuda_stream);
template void CalTranspose<int>(const size_t size, const int* input, const size_t* input_shape,
const size_t* input_axis, const size_t shape_size, int* output,
cudaStream_t cuda_stream);

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh View File

@@ -19,7 +19,7 @@
#define TRANSPOSE_MAX_DIMENSION 100
template <typename T>
void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size,
T* output, cudaStream_t cuda_stream);
void CalTranspose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
const size_t shape_size, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_H_

+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -87,14 +87,14 @@ class GpuKernel : public KernelMod {
return GetValue<T>(attr);
}
// expand Nd Shape to 4d (N in [0,4])
void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<size_t> *dst) {
if (src.size() > 4) {
MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!";
}
dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
dst->push_back(src.size() < 4 ? 1 : src[src.size() - 4]);
dst->push_back(src.size() < 3 ? 1 : src[src.size() - 3]);
dst->push_back(src.size() < 2 ? 1 : src[src.size() - 2]);
dst->push_back(src.size() == 0 ? 1 : src[src.size() - 1]);
}
int AxisTransform(const std::string &origin_data_format, const std::string &cal_format, int axis) {


+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h View File

@@ -162,9 +162,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
int input1_num_;
int input2_num_;
int output_num_;
std::vector<int> lhs_shape_;
std::vector<int> rhs_shape_;
std::vector<int> output_shape_;
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;


+ 6
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h View File

@@ -140,12 +140,12 @@ class BroadcastOpGradGpuKernel : public GpuKernel {

BroadcastGradOpType op_type_;
bool need_broadcast_;
int input1_num_;
int input2_num_;
int output_num_;
int x1_shape_[4] = {1, 1, 1, 1};
int x2_shape_[4] = {1, 1, 1, 1};
int dy_shape_[4] = {1, 1, 1, 1};
size_t input1_num_;
size_t input2_num_;
size_t output_num_;
size_t x1_shape_[4] = {1, 1, 1, 1};
size_t x2_shape_[4] = {1, 1, 1, 1};
size_t dy_shape_[4] = {1, 1, 1, 1};
bool grad_x_;
bool grad_y_;



+ 9
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h View File

@@ -82,7 +82,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
std::vector<int> shape;
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
"cudnnSetActivationDescriptor failed");
@@ -91,13 +91,15 @@ class ActivationGpuFwdKernel : public GpuKernel {
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
shape[0], shape[3], shape[1], shape[2]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);


+ 9
- 7
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h View File

@@ -89,7 +89,7 @@ class ActivationGradGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
std::vector<int> shape;
std::vector<size_t> shape;
double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, coef),
"SetActivationDescriptor failed");
@@ -98,13 +98,15 @@ class ActivationGradGpuKernel : public GpuKernel {
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
shape[0], shape[3], shape[1], shape[2]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[3]), SizeToInt(shape[1]), SizeToInt(shape[2])),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(shape[0]),
SizeToInt(shape[1]), SizeToInt(shape[2]), SizeToInt(shape[3])),
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);


+ 11
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h View File

@@ -68,9 +68,9 @@ class SoftmaxGpuKernel : public GpuKernel {
} else {
T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0);
T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1);
int *input_shape = GetDeviceAddress<int>(workspace, 2);
int *transpose_shape = GetDeviceAddress<int>(workspace, 3);
int *transpose_axis = GetDeviceAddress<int>(workspace, 4);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 2);
size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 3);
size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 4);
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
@@ -80,7 +80,7 @@ class SoftmaxGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
int size = SizeToInt(input_size_ / sizeof(T));
size_t size = input_size_ / sizeof(T);
CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDNN_RET_WITH_EXCEPT(
@@ -113,7 +113,7 @@ class SoftmaxGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
shape_size_ = SizeToInt(input_shape.size());
shape_size_ = input_shape.size();
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "LogSoftmax") {
algo_ = CUDNN_SOFTMAX_LOG;
@@ -171,7 +171,7 @@ class SoftmaxGpuKernel : public GpuKernel {
void InitSizeByAxis2D(const std::vector<size_t> &input_shape, const int &axis) {
axis_ = axis;
if (axis_ < 0) {
axis_ += shape_size_;
axis_ += SizeToInt(shape_size_);
}
if (axis_ == 1) {
batch_size_ = input_shape[0];
@@ -193,7 +193,7 @@ class SoftmaxGpuKernel : public GpuKernel {
width_ = 1;
input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
output_size_ = input_size_;
workspace_size_ = IntToSize(shape_size_) * sizeof(int);
workspace_size_ = shape_size_ * sizeof(size_t);
}
void InitSizeByAxisLastDim(const std::vector<size_t> &input_shape, const int &axis) {
@@ -235,11 +235,11 @@ class SoftmaxGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<int> input_shape_;
std::vector<int> transpose_shape_;
std::vector<int> transpose_axis_;
std::vector<size_t> input_shape_;
std::vector<size_t> transpose_shape_;
std::vector<size_t> transpose_axis_;
int axis_;
int shape_size_;
size_t shape_size_;
size_t batch_size_;
size_t channel_size_;


+ 11
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h View File

@@ -62,9 +62,9 @@ class SoftmaxGradGpuKernel : public GpuKernel {
T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0);
T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1);
T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2);
int *input_shape = GetDeviceAddress<int>(workspace, 3);
int *transpose_shape = GetDeviceAddress<int>(workspace, 4);
int *transpose_axis = GetDeviceAddress<int>(workspace, 5);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 3);
size_t *transpose_shape = GetDeviceAddress<size_t>(workspace, 4);
size_t *transpose_axis = GetDeviceAddress<size_t>(workspace, 5);
const float alpha = 1;
const float beta = 0;

@@ -82,7 +82,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
int size = SizeToInt(input_size_ / sizeof(T));
size_t size = input_size_ / sizeof(T);
CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
@@ -116,7 +116,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
shape_size_ = SizeToInt(input_shape.size());
shape_size_ = input_shape.size();
if (shape_size_ != 2) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs.";
}
@@ -164,7 +164,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
axis_ = axis;
if (axis_ < 0) {
axis_ += shape_size_;
axis_ += SizeToInt(shape_size_);
}
if (axis_ == 1) {
batch_size_ = input_shape[0];
@@ -186,7 +186,7 @@ class SoftmaxGradGpuKernel : public GpuKernel {
width_ = 1;
input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
output_size_ = input_size_;
workspace_size_ = IntToSize(shape_size_) * sizeof(int);
workspace_size_ = shape_size_ * sizeof(size_t);
}

cudnnHandle_t cudnn_handle_;
@@ -202,11 +202,11 @@ class SoftmaxGradGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

std::vector<int> input_shape_;
std::vector<int> transpose_shape_;
std::vector<int> transpose_axis_;
std::vector<size_t> input_shape_;
std::vector<size_t> transpose_shape_;
std::vector<size_t> transpose_axis_;
int axis_;
int shape_size_;
size_t shape_size_;

size_t batch_size_;
size_t channel_size_;


Loading…
Cancel
Save