
!9855 Add int64 support to GPU transpose

From: @TFbunny
Reviewed-by: @tom__chen
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago
commit 568545f2d5
5 changed files with 39 additions and 19 deletions
  1. +3 -1    mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
  2. +1 -4    mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
  3. +15 -12  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu
  4. +1 -1    mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh
  5. +19 -1   tests/st/ops/gpu/test_transpose_op.py

+3 -1   mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,5 +23,7 @@ MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).A
                       TransposeGpuFwdKernel, half)
 MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       TransposeGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      TransposeGpuFwdKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
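
With this registration in place, kernel selection resolves a Transpose on an int64 tensor to the GPU kernel instead of failing the dtype match. A minimal sketch of exercising the new path from Python, assuming a working MindSpore GPU install (the tensor values are illustrative):

import numpy as np
import mindspore.context as context
import mindspore.ops as ops
from mindspore import Tensor

# Target the GPU backend so the newly registered int64 kernel is selected.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.arange(6, dtype=np.int64).reshape(2, 3))
transpose = ops.Transpose()
out = transpose(x, (1, 0))  # the permutation becomes the kernel's "perm" attribute
print(out.shape)  # (3, 2); the dtype stays int64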

+1 -4   mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h

@@ -79,10 +79,7 @@ class TransposeGpuFwdKernel : public GpuKernel {
     }
     input_size_ *= sizeof(T);
     output_size_ = input_size_;
-    std::vector<int> perm;
-    std::vector<int64_t> perm_me = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
-    (void)std::transform(perm_me.begin(), perm_me.end(), std::back_inserter(perm),
-                         [](const int64_t &value) { return static_cast<int>(value); });
+    std::vector<int64_t> perm = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
     for (size_t j = 0; j < perm.size(); j++) {
       input_axis_.push_back(perm[j]);
     }


+15 -12  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu

@@ -20,8 +20,8 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void Transpose(const size_t size, const T* input, const size_t* input_shape,
-                          const size_t* input_axis, const size_t shape_size, T* output) {
+__global__ void Transpose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
+                          const size_t shape_size, T *output) {
   size_t pos_size;
   size_t temp_pos;
   size_t newpos;
   size_t newpos_size;
@@ -36,7 +36,7 @@ __global__ void Transpose(const size_t size, const T* input, const size_t* input
     temp_pos = pos;
     pos_size = size / input_shape[0];
     pos_array[0] = temp_pos / pos_size;
-    for (int i = 1; i < shape_size; i++) {
+    for (size_t i = 1; i < shape_size; i++) {
       temp_pos -= pos_array[i - 1] * pos_size;
       pos_size = pos_size / input_shape[i];
       pos_array[i] = temp_pos / pos_size;
@@ -44,7 +44,7 @@ __global__ void Transpose(const size_t size, const T* input, const size_t* input
 
     newpos = pos_array[input_axis[shape_size - 1]];
     newpos_size = 1;
-    for (int j = shape_size - 2; j >= 0; j--) {
+    for (int64_t j = shape_size - 2; j >= 0; j--) {
       newpos_size *= input_shape[input_axis[j + 1]];
       newpos += pos_array[input_axis[j]] * newpos_size;
     }
@@ -54,19 +54,22 @@ __global__ void Transpose(const size_t size, const T* input, const size_t* input
   }
   return;
 }
 template <typename T>
-void CalTranspose(const size_t size, const T* input, const size_t* input_shape, const size_t* input_axis,
-                  const size_t shape_size, T* output, cudaStream_t cuda_stream) {
+void CalTranspose(const size_t size, const T *input, const size_t *input_shape, const size_t *input_axis,
+                  const size_t shape_size, T *output, cudaStream_t cuda_stream) {
   Transpose<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, input_axis, shape_size,
                                                                output);
   return;
 }
-template void CalTranspose<float>(const size_t size, const float* input, const size_t* input_shape,
-                                  const size_t* input_axis, const size_t shape_size, float* output,
+template void CalTranspose<float>(const size_t size, const float *input, const size_t *input_shape,
+                                  const size_t *input_axis, const size_t shape_size, float *output,
                                   cudaStream_t cuda_stream);
-template void CalTranspose<half>(const size_t size, const half* input, const size_t* input_shape,
-                                 const size_t* input_axis, const size_t shape_size, half* output,
+template void CalTranspose<half>(const size_t size, const half *input, const size_t *input_shape,
+                                 const size_t *input_axis, const size_t shape_size, half *output,
                                  cudaStream_t cuda_stream);
-template void CalTranspose<int>(const size_t size, const int* input, const size_t* input_shape,
-                                const size_t* input_axis, const size_t shape_size, int* output,
+template void CalTranspose<int>(const size_t size, const int *input, const size_t *input_shape,
+                                const size_t *input_axis, const size_t shape_size, int *output,
                                 cudaStream_t cuda_stream);
+template void CalTranspose<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
+                                    const size_t *input_axis, const size_t shape_size, int64_t *output,
+                                    cudaStream_t cuda_stream);
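
Each thread maps one flat input offset to its transposed output offset: it first decodes pos into per-axis coordinates against input_shape, then re-encodes those coordinates in input_axis order as a row-major offset into the output. A host-side Python sketch of the same arithmetic, checked against numpy (the helper name is mine, not part of the kernel):

import numpy as np

def transposed_flat_index(pos, input_shape, input_axis):
    """Mirror of the kernel's per-thread index arithmetic."""
    shape_size = len(input_shape)
    size = int(np.prod(input_shape))
    # Decode the flat input position into per-axis coordinates.
    pos_array = [0] * shape_size
    temp_pos = pos
    pos_size = size // input_shape[0]
    pos_array[0] = temp_pos // pos_size
    for i in range(1, shape_size):
        temp_pos -= pos_array[i - 1] * pos_size
        pos_size //= input_shape[i]
        pos_array[i] = temp_pos // pos_size
    # Re-encode the permuted coordinates as a row-major output offset.
    newpos = pos_array[input_axis[shape_size - 1]]
    newpos_size = 1
    for j in range(shape_size - 2, -1, -1):
        newpos_size *= input_shape[input_axis[j + 1]]
        newpos += pos_array[input_axis[j]] * newpos_size
    return newpos

# Sanity check against numpy for a 2x3x4 int64 tensor and perm = (2, 0, 1).
shape, perm = (2, 3, 4), (2, 0, 1)
x = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)
out = np.empty(x.size, dtype=np.int64)
for pos in range(x.size):
    out[transposed_flat_index(pos, shape, perm)] = x.flat[pos]
assert (out.reshape(tuple(shape[a] for a in perm)) == x.transpose(perm)).all()

Because the arithmetic is dtype-agnostic, int64 support only requires the new template instantiation above; the kernel body itself is unchanged apart from the loop-index type fixes.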

+1 -1   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.


+19 -1   tests/st/ops/gpu/test_transpose_op.py

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -253,6 +253,24 @@ def test_transpose_float16():
 def test_transpose_int32():
     transpose1(np.int32)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_int64():
+    transpose1(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_dynamic_int64():
+    transpose_d(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_dynamic_two_inputs_int64():
+    transpose_d2(np.int64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
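
The new cases reuse the existing transpose1, transpose_d, and transpose_d2 drivers, so they can be selected by name. One way to run just the int64 cases locally, assuming pytest is installed and a GPU card is available:

import pytest

# Select only the int64 transpose tests from this file by keyword.
pytest.main(["-v", "tests/st/ops/gpu/test_transpose_op.py", "-k", "int64"])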

