@@ -1,33 +1,39 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ScatterNdGpuFwdKernel, int, int)
} // namespace kernel
} // namespace mindspore

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ScatterNdGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
  ScatterNdGpuFwdKernel, short, int) // NOLINT
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  ScatterNdGpuFwdKernel, uchar, int)
} // namespace kernel
} // namespace mindspore
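Each MS_REG_GPU_KERNEL_TWO entry above binds ScatterNd to one (value type, index type) pair, and the two trailing template arguments must match the dtypes declared in the KernelAttr. Purely as an illustration of that pattern (not part of this change, and assuming kNumberTypeInt8 plus the char MsAtomicAdd overload added to util.cuh below), an int8 registration would look like:

// Hypothetical example only; not included in this diff.
MS_REG_GPU_KERNEL_TWO(
  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  ScatterNdGpuFwdKernel, char, int)

A matching explicit instantiation of ScatterNd<char, int> in scatter_nd.cu would also be required; this diff does not add one.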
@@ -23,9 +23,9 @@ struct MinimumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                             const T &dy, T *dx1, T *dx2) {
    if (grad_x1 && x1 < x2) {
      ms_atomic_add(dx1, dy);
      MsAtomicAdd(dx1, dy);
    } else if (grad_x2 && x1 >= x2) {
      ms_atomic_add(dx2, dy);
      MsAtomicAdd(dx2, dy);
    }
  }
};
@@ -35,9 +35,9 @@ struct MaximumGradFunc {
  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                             const T &dy, T *dx1, T *dx2) {
    if (grad_x1 && x1 > x2) {
      ms_atomic_add(dx1, dy);
      MsAtomicAdd(dx1, dy);
    } else if (grad_x2 && x1 <= x2) {
      ms_atomic_add(dx2, dy);
      MsAtomicAdd(dx2, dy);
    }
  }
};
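The rename from ms_atomic_add to MsAtomicAdd does not change behavior: when x1 or x2 is broadcast, several dy elements can map to the same dx element, so the accumulation has to be atomic. A minimal self-contained sketch of that pattern (illustrative only; the kernel name and index mapping are assumptions, not code from this change):

// Many threads may hit the same dx slot, so a plain "dx[j] += dy[i]" would race.
__global__ void AccumulateGradSketch(const float *dy, float *dx, int dy_size, int dx_size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < dy_size; i += blockDim.x * gridDim.x) {
    int j = i % dx_size;       // broadcast-style mapping: many i values share one j
    atomicAdd(&dx[j], dy[i]);  // the functors above call MsAtomicAdd, which also covers half/short/uchar
  }
}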
@@ -15,6 +15,7 @@
 */
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
@@ -61,7 +61,7 @@ __global__ void ResizeNearestNeighborGrad(const int input_size, const T *input,
                        out_width - 1);
    // pos_array[0] N, pos_array[1] C, out_y H, out_x W
    output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x;
    ms_atomic_add(&output[output_pos], input[pos]);
    MsAtomicAdd(&output[output_pos], input[pos]);
  }
}
@@ -218,10 +218,10 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
    T *dx_3 = dx + offset + y_high * width + x_low;
    T *dx_4 = dx + offset + y_high * width + x_high;
    if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
      ms_atomic_add(dx_1, g1);
      ms_atomic_add(dx_2, g2);
      ms_atomic_add(dx_3, g3);
      ms_atomic_add(dx_4, g4);
      MsAtomicAdd(dx_1, g1);
      MsAtomicAdd(dx_2, g2);
      MsAtomicAdd(dx_3, g3);
      MsAtomicAdd(dx_4, g4);
    }
  }
}
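For context, g1 through g4 in ROIAlignGradKernel are the bilinear-interpolation shares of the incoming gradient for the four neighboring pixels; their computation sits above this hunk and is not shown. A hedged sketch of the standard weighting, with illustrative names only:

// Standard bilinear gradient weights, assuming y_low = floor(y) and x_low = floor(x).
__device__ void BilinearGradWeightsSketch(float y, float x, int y_low, int x_low, float top_diff,
                                          float *g1, float *g2, float *g3, float *g4) {
  float ly = y - y_low, lx = x - x_low;  // fractional offsets inside the cell
  float hy = 1.0f - ly, hx = 1.0f - lx;
  *g1 = top_diff * hy * hx;  // share for (y_low,  x_low)
  *g2 = top_diff * hy * lx;  // share for (y_low,  x_high)
  *g3 = top_diff * ly * hx;  // share for (y_high, x_low)
  *g4 = top_diff * ly * lx;  // share for (y_high, x_high)
}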
@@ -1,70 +1,80 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
                                S *indices_stride, S *work_shape) {
  int i, j;
  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
       read_index += blockDim.x * gridDim.x) {
    int write_index = 0;
    bool out_bound = false;
    i = read_index / block_size;
    j = read_index % block_size;
    for (size_t k = 0; k < indices_dim_1; k++) {
      S indices_i = indices[i * indices_dim_1 + k];
      out_bound |= indices_i >= work_shape[k];
      write_index += indices_i * indices_stride[k];
    }
    write_index += j;
    out_bound |= write_index >= output_size;
    if (!out_bound) {
      ms_atomic_add(&output[write_index], update[read_index]);
    }
  }
}
template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream) {
  ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
                                                                       output_size, indices_dim_0, indices_dim_1,
                                                                       indices_stride, work_shape);
  return;
}
template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                   cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                  cudaStream_t stream);

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
                                S *indices_stride, S *work_shape) {
  int i, j;
  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
       read_index += blockDim.x * gridDim.x) {
    int write_index = 0;
    bool out_bound = false;
    i = read_index / block_size;
    j = read_index % block_size;
    for (size_t k = 0; k < indices_dim_1; k++) {
      S indices_i = indices[i * indices_dim_1 + k];
      out_bound |= indices_i >= work_shape[k];
      write_index += indices_i * indices_stride[k];
    }
    write_index += j;
    out_bound |= write_index >= output_size;
    if (!out_bound) {
      MsAtomicAdd(&output[write_index], update[read_index]);
    }
  }
}
template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream) {
  ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
                                                                       output_size, indices_dim_0, indices_dim_1,
                                                                       indices_stride, work_shape);
  return;
}
template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                   cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                  cudaStream_t stream);
// NOLINTNEXTLINE
template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
                                            const size_t &block_size, const size_t &input_size,
                                            const size_t &output_size, const size_t &indices_dim_0,
                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                            cudaStream_t stream);
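The write index computed in ScatterNdKernel is the sum over k of indices[i][k] * indices_stride[k], plus the offset j inside one update block. A small host-side reference of the same arithmetic (an illustrative sketch; the function name and example shapes are assumptions, not part of this diff):

// Serial reference of the ScatterNd index math. For a 2-D output of shape {4, 3},
// indices_stride would be {3, 1}; with scalar updates, block_size is 1 and
// indices_dim_1 is 2 (one full coordinate per row of `indices`).
#include <cstddef>
void ScatterNdReference(const int *indices, const float *update, float *output, size_t input_size,
                        size_t output_size, size_t block_size, size_t indices_dim_1,
                        const int *indices_stride, const int *work_shape) {
  for (size_t read_index = 0; read_index < input_size; ++read_index) {
    size_t i = read_index / block_size;
    size_t j = read_index % block_size;
    size_t write_index = 0;
    bool out_bound = false;
    for (size_t k = 0; k < indices_dim_1; ++k) {
      int idx = indices[i * indices_dim_1 + k];
      out_bound |= idx >= work_shape[k];
      write_index += static_cast<size_t>(idx) * indices_stride[k];
    }
    write_index += j;
    out_bound |= write_index >= output_size;
    if (!out_bound) {
      output[write_index] += update[read_index];  // serial, so no atomic needed here
    }
  }
}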
@@ -19,11 +19,41 @@
#include <cuda_fp16.h>
inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); }
__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }
inline __device__ half ms_atomic_add(half *address, half val) {
__device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
  return atomicAdd(address, val);
}
__device__ static inline short MsAtomicAdd(short *address, short val) {  // NOLINT
  // Emulate a 16-bit atomic add with a 32-bit atomicCAS on the enclosing
  // 4-byte-aligned word; the target value is either the low or the high
  // halfword of that word, depending on the address alignment.
  bool is_4_byte_aligned = ((size_t) address & 2) == 0;
  unsigned int *aligned = (unsigned int *) ((size_t) address & ~2);
  unsigned int old = *aligned;
  unsigned int assumed;
  do {
    assumed = old;
    unsigned int replacement;
    if (is_4_byte_aligned) {
      replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
    } else {
      replacement = old + ((unsigned int) val << 16);
    }
    old = atomicCAS(aligned, assumed, replacement);
  } while (assumed != old);
  // Like atomicAdd, return the value stored at address before the add.
  if (is_4_byte_aligned) {
    return (short) (old & 0xffff);  // NOLINT
  } else {
    return (short) (old >> 16);  // NOLINT
  }
}
__device__ static inline half MsAtomicAdd(half *address, half val) {
  unsigned int *aligned =
    reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *aligned;
@@ -42,4 +72,66 @@ inline __device__ half ms_atomic_add(half *address, half val) {
  return half(raw);
}
__device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
  // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
  // implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
  // unsigned int* must be 4 byte aligned. This variable contains the offset,
  // in bytes, of the beginning of address, within the 4 byte aligned space that
  // contains it.
  size_t address_offset = (size_t) address & 3;
  // Address of the 4 byte aligned space that contains address.
  unsigned int *aligned = (unsigned int *) ((unsigned char *) address - address_offset);
  // Constants which will be used later with __byte_perm. __byte_perm is a cuda
  // function which takes 3 unsigned int's (x, y, selector) as parameters and
  // returns an int. __byte_perm returns an integer by selecting bytes from x
  // and y based on the given selector. The selector 0x3210 will select all
  // four bytes from x, preserving their original order. The position of the
  // "4" in the selector indicates the position in the output where the first
  // byte of y will end up.
  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
  // Gets the selector that will select the bytes at address from aligned
  unsigned int selector = selectors[address_offset];
  unsigned int old = *aligned;
  unsigned int assumed = 0;
  do {
    assumed = old;
    // Selects the byte associated with address and puts it as the first byte of
    // this variable, so that we can add val to the value at address.
    unsigned int sum = val + __byte_perm(old, 0, address_offset);
    // Takes old and replaces the byte corresponding to address with the sum.
    unsigned int replacement = __byte_perm(old, sum, selector);
    // Try to replace the old value with the new value.
    old = atomicCAS(aligned, assumed, replacement);
  } while (old != assumed);
  // Select the single byte corresponding to address and return it.
  return __byte_perm(old, 0, address_offset);
}
__device__ static inline char MsAtomicAdd(char *address, char val) {
  // Same byte-splice approach as the unsigned char overload above.
  size_t address_offset = (size_t) address & 3;
  unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
  unsigned int selector = selectors[address_offset];
  unsigned int old = *aligned;
  unsigned int assumed = 0;
  do {
    assumed = old;
    unsigned int sum = val + __byte_perm(old, 0, address_offset);
    unsigned int replacement = __byte_perm(old, sum, selector);
    old = atomicCAS(aligned, assumed, replacement);
  } while (old != assumed);
  return __byte_perm(old, 0, address_offset);
}

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
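The 16-bit and 8-bit overloads above emulate atomic addition by issuing atomicCAS on the enclosing 4-byte word and splicing the updated halfword or byte back in (with masks, or with __byte_perm for single bytes). A short usage sketch, assuming only that this header is included; the kernel name and bucket count are illustrative:

// assumes: #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
// Several threads may increment the same byte-sized bucket, exercising the
// CAS-based unsigned char MsAtomicAdd; counts wrap modulo 256.
__global__ void ByteHistogramSketch(const unsigned char *data, unsigned char *hist, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    MsAtomicAdd(&hist[data[i] & 0x0F], static_cast<unsigned char>(1));  // 16 buckets
  }
}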