From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chentags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -44,22 +44,58 @@ class DynamicRangeGpuKernel : public GpuKernel { | |||
| T *range_delta = GetDeviceAddress<T>(inputs, 2); | |||
| T *output_device_address = GetDeviceAddress<T>(outputs, 0); | |||
| int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0); | |||
| DynamicRangeErrorCode *error_code_device_address = GetDeviceAddress<DynamicRangeErrorCode>(workspace, 1); | |||
| stream_ptr_ = stream_ptr; | |||
| CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, | |||
| max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CudaValidateInputAndInferShape(range_start, range_end, range_delta, output_shape_device_address, | |||
| error_code_device_address, max_output_length_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| DynamicRangeErrorCode error_code = DynamicRangeErrorCode::kOk; | |||
| CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, | |||
| cudaMemcpyAsync(&error_code, error_code_device_address, sizeof(DynamicRangeErrorCode), | |||
| cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "Failed to copy error code to host."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); | |||
| // use workspace[0] for actual output shape, we know it must be 1d | |||
| CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_, | |||
| cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t), | |||
| cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "Failed to copy gpu memory."); | |||
| "Failed to copy output_shape to host."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed"); | |||
| LogExceptionIfNotOk(error_code); | |||
| CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address, | |||
| error_code_device_address, max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| void LogExceptionIfNotOk(DynamicRangeErrorCode error_code) { | |||
| switch (error_code) { | |||
| case DynamicRangeErrorCode::kOk: | |||
| return; | |||
| case DynamicRangeErrorCode::kDeltaIsZero: | |||
| MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be equal to zero"; | |||
| break; | |||
| case DynamicRangeErrorCode::kInvalidPositiveDelta: | |||
| MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be positive when limit < start"; | |||
| break; | |||
| case DynamicRangeErrorCode::kInvalidNegativeDelta: | |||
| MS_LOG(EXCEPTION) << "gpu RangeOp input error: delta cannot be negative when limit > start"; | |||
| break; | |||
| case DynamicRangeErrorCode::kMaxSizeExceeded: | |||
| MS_LOG(EXCEPTION) << "gpu RangeOp memory error: the number of elements in the output exceeds maxlen"; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "gpu RangeOp unknown error"; | |||
| } | |||
| } | |||
| void PostExecute() override { | |||
| // required synchronize for PostExecute | |||
| CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)), | |||
| @@ -103,6 +139,7 @@ class DynamicRangeGpuKernel : public GpuKernel { | |||
| // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape. | |||
| workspace_size_list_.push_back(sizeof(int64_t)); | |||
| workspace_size_list_.push_back(sizeof(DynamicRangeErrorCode)); | |||
| return; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -20,57 +20,90 @@ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __device__ void CheckInputs(const T &start, const T &end, const T &delta) { | |||
| __global__ void ValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size) { | |||
| T start = range_start[0]; | |||
| T end = range_end[0]; | |||
| T delta = range_delta[0]; | |||
| *error_code = DynamicRangeErrorCode::kOk; | |||
| if (delta == 0) { | |||
| asm("trap;"); | |||
| *error_code = DynamicRangeErrorCode::kDeltaIsZero; | |||
| return; | |||
| } | |||
| if (start < end && delta < 0) { | |||
| asm("trap;"); | |||
| *error_code = DynamicRangeErrorCode::kInvalidNegativeDelta; | |||
| return; | |||
| } | |||
| if (start > end && delta > 0) { | |||
| asm("trap;"); | |||
| *error_code = DynamicRangeErrorCode::kInvalidPositiveDelta; | |||
| return; | |||
| } | |||
| if (*error_code == DynamicRangeErrorCode::kOk) { | |||
| int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta)); | |||
| if (real_output_shape > max_output_size) { | |||
| *error_code = DynamicRangeErrorCode::kMaxSizeExceeded; | |||
| } | |||
| *output_shape = real_output_shape; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, | |||
| int64_t *output_shape, const int64_t max_output_size) { | |||
| __global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, | |||
| const int64_t max_output_size) { | |||
| T start = range_start[0]; | |||
| T end = range_end[0]; | |||
| T delta = range_delta[0]; | |||
| CheckInputs(start, end, delta); | |||
| int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta)); | |||
| if (real_output_shape > max_output_size) { | |||
| asm("trap;"); | |||
| } | |||
| *output_shape = real_output_shape; | |||
| size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; | |||
| for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) { | |||
| for (; gt_id < *output_shape; gt_id += blockDim.x * gridDim.x) { | |||
| output[gt_id] = gt_id * delta + start; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream) { | |||
| ValidateInputAndInferShape<<<1, 1, 0, cuda_stream>>>(range_start, range_end, range_delta, output_shape, error_code, | |||
| max_output_size); | |||
| } | |||
| template <typename T> | |||
| void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream) { | |||
| DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream) { | |||
| Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta, | |||
| output, output_shape, max_output_size); | |||
| } | |||
| template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output, | |||
| int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| template void CudaValidateInputAndInferShape<int>(const int *range_start, const int *range_end, const int *range_delta, | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| template void CudaValidateInputAndInferShape<int64_t>(const int64_t *range_start, const int64_t *range_end, | |||
| const int64_t *range_delta, int64_t *output_shape, | |||
| DynamicRangeErrorCode *error_code, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CudaValidateInputAndInferShape<float>(const float *range_start, const float *range_end, | |||
| const float *range_delta, int64_t *output_shape, | |||
| DynamicRangeErrorCode *error_code, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CudaValidateInputAndInferShape<double>(const double *range_start, const double *range_end, | |||
| const double *range_delta, int64_t *output_shape, | |||
| DynamicRangeErrorCode *error_code, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output, | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta, | |||
| int64_t *output, int64_t *output_shape, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| int64_t *output, int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output, | |||
| int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta, | |||
| double *output, int64_t *output_shape, const int64_t max_output_size, | |||
| cudaStream_t cuda_stream); | |||
| double *output, int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -19,8 +19,21 @@ | |||
| #include <cuda_runtime.h> | |||
| enum class DynamicRangeErrorCode { | |||
| kOk = 0, | |||
| kDeltaIsZero, | |||
| kInvalidPositiveDelta, | |||
| kInvalidNegativeDelta, | |||
| kMaxSizeExceeded | |||
| }; | |||
| template <typename T> | |||
| void CudaValidateInputAndInferShape(const T *range_start, const T *range_end, const T *range_delta, | |||
| int64_t *output_shape, DynamicRangeErrorCode *error_code, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape, | |||
| const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| DynamicRangeErrorCode *error_code, const int64_t max_output_size, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_ | |||
| @@ -22,12 +22,12 @@ from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class RangeNet(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, maxlen=10000): | |||
| super(RangeNet, self).__init__() | |||
| self.range = P.Range() | |||
| self.range = P.Range(maxlen) | |||
| def construct(self, s, e, d): | |||
| return self.range(s, e, d) | |||
| def construct(self, start, limit, delta): | |||
| return self.range(start, limit, delta) | |||
| @pytest.mark.level0 | |||
| @@ -91,3 +91,27 @@ def test_range_invalid_max_output_length(): | |||
| _ = P.Range(-1) | |||
| _ = P.Range(None) | |||
| _ = P.Range('5') | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_range_invalid_input(): | |||
| with pytest.raises(RuntimeError) as info: | |||
| range_net = RangeNet(3500) | |||
| _ = range_net(Tensor(0, mstype.int32), Tensor(5, mstype.int32), Tensor(0, mstype.int32)).asnumpy() | |||
| assert "delta cannot be equal to zero" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| range_net = RangeNet(2) | |||
| _ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy() | |||
| assert "number of elements in the output exceeds maxlen" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| range_net = RangeNet(3500) | |||
| _ = range_net(Tensor(20, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy() | |||
| assert "delta cannot be positive when limit < start" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| range_net = RangeNet(3500) | |||
| _ = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(-4, mstype.int32)).asnumpy() | |||
| assert "delta cannot be negative when limit > start" in str(info.value) | |||