@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void AddReluGradV2(const size_t size, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
template <typename T>
__global__ void AddReluV2Kernel(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask) {
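  // Fused elementwise Add + ReLU: the ballot below also records which lanes produced a
  // positive sum, so the backward pass can reuse the packed mask instead of recomputing it.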
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    T sum = x1[i] + x2[i];
    bool p = sum > static_cast<T>(0);
    y[i] = p ? sum : static_cast<T>(0);
    auto warp_predict = BallotSync(p, __activemask());
    if (LaneId() == 0) {
      mask[WarpId(i)] = warp_predict;
    }
  }
}
template <typename T>
void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
  AddReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, y, mask);
}
template <typename T>
__global__ void AddReluGradV2Kernel(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    bool positive = mask[WarpId(i)] & (1 << LaneId());
    dx[i] = positive ? x1[i] + x2[i] : static_cast<T>(0);
  }
}
template <typename T>
void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
  AddReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, mask, dx);
}
template void AddReluV2(const size_t num, const float *x1, const float *x2, float *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const half *x1, const half *x2, half *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int32_t *x1, const int32_t *x2, int32_t *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluV2(const size_t num, const int64_t *x1, const int64_t *x2, int64_t *y, uint32_t *mask,
                        cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const float *x1, const float *x2, const uint32_t *mask, float *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const half *x1, const half *x2, const uint32_t *mask, half *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int32_t *x1, const int32_t *x2, const uint32_t *mask, int32_t *dx,
                            cudaStream_t cuda_stream);
template void AddReluGradV2(const size_t num, const int64_t *x1, const int64_t *x2, const uint32_t *mask, int64_t *dx,
                            cudaStream_t cuda_stream);
@@ -15,6 +15,7 @@
 */
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
@@ -34,3 +35,47 @@ template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
template <typename T>
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
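  // Each group of 32 consecutive elements maps to one warp; __ballot_sync packs the per-lane
  // ReLU predicates into a single uint32_t, so mask holds one bit per element.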
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    T v = x[i];
    bool p = v > static_cast<T>(0);
    y[i] = p ? v : static_cast<T>(0);
    auto warp_predict = BallotSync(p, __activemask());
    if (LaneId() == 0) {
      mask[WarpId(i)] = warp_predict;
    }
  }
}
template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
  ReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x, y, mask);
}
template <typename T>
__global__ void ReluGradV2Kernel(const size_t num, const T *dy, const uint32_t *mask, T *dx) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
    bool p = mask[WarpId(i)] & (1 << LaneId());
    dx[i] = p ? dy[i] : static_cast<T>(0);
  }
}
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
  ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
}
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
                         cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
                         cudaStream_t cuda_stream);
@@ -20,4 +20,9 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
@@ -20,16 +20,18 @@
#include <cuda_fp16.h>
#include "runtime/device/gpu/cuda_common.h"
#define kThreadsPerBlock (256)
#define kBlocksPerGrid(n) ((n + kThreadsPerBlock - 1) / kThreadsPerBlock)
__device__ static inline double MsAtomicAdd(double *address, const double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;  // NOLINT
  unsigned long long int old = *address_as_ull;                                // NOLINT
  unsigned long long int assumed;                                              // NOLINT
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // NOLINT
  return __longlong_as_double(old);
}
__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
@@ -42,7 +44,7 @@ __device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
__device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
  size_t offset = (size_t)address & 3;
  uint32_t *address_as_ui = (uint32_t *)((char *)address - offset);  // NOLINT
  uint32_t old = *address_as_ui;
  uint32_t shift = offset * 8;
  uint32_t old_byte;
@@ -60,27 +62,27 @@ __device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
}
__device__ static inline int64_t MsAtomicAdd(int64_t *address, int64_t val) {
  unsigned long long *address_as_ui = (unsigned long long *)(address);  // NOLINT
  unsigned long long old = *address_as_ui;                              // NOLINT
  unsigned long long newval;                                            // NOLINT
  unsigned long long assumed;                                           // NOLINT
  do {
    assumed = old;
    newval = val + (int64_t)old;
    old = atomicCAS(address_as_ui, assumed, newval);
  } while (assumed != old);
  return (int64_t)old;
}
__device__ static inline bool MsAtomicAdd(bool *address, bool val) {
  *address = address && val;
  return address[0];
}
__device__ static inline unsigned char MsAtomicAdd(short *address, short val) {  // NOLINT
  bool is_4_byte_aligned = ((size_t)address & 2) == 0;
  unsigned int *aligned = (unsigned int *)((size_t)address & ~2);
  unsigned int old = *aligned;
  unsigned int assumed;
@@ -91,16 +93,16 @@ __device__ static inline unsigned char MsAtomicAdd(short *address, short val) {
    if (is_4_byte_aligned) {
      replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
    } else {
      replacement = old + ((unsigned int)val << 16);
    }
    old = atomicCAS(aligned, assumed, replacement);
  } while (assumed != old);
  if (is_4_byte_aligned) {
    return (short)(old & 0xffff);  // NOLINT
  } else {
    return (short)(old >> 16);  // NOLINT
  }
}
@@ -112,7 +114,8 @@ __device__ static inline half MsAtomicAdd(half *address, half val) {
  unsigned short old_as_us;  // NOLINT
  do {
    assumed = old;
    old_as_us =
      static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff);  // NOLINT
    half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast<float>(val));
    unsigned short sum_as_us = __half_as_ushort(sum);  // NOLINT
    unsigned int sum_as_ui =
@@ -123,16 +126,16 @@ __device__ static inline half MsAtomicAdd(half *address, half val) {
  return half(raw);
}
__device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
  // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
  // implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
  // unsigned int* must be 4 byte aligned. This variable contains the offset,
  // in bytes, of the beginning of address, within the 4 byte aligned space that
  // contains it.
  size_t address_offset = (size_t)address & 3;
  // Address of the 4 byte aligned space that contains address.
  unsigned int *aligned = (unsigned int *)((unsigned char *)address - address_offset);
  // Constants which will be used later with __byte_perm. __byte_perm is a cuda
  // function which takes 3 unsigned int's (x, y, selector) as parameters and
@@ -166,9 +169,9 @@ __device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsigned char val) {
  return __byte_perm(old, 0, address_offset);
}
__device__ static inline char MsAtomicAdd(char *address, char val) {
  size_t address_offset = (size_t)address & 3;
  unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
  unsigned int selector = selectors[address_offset];
  unsigned int old = *aligned;
@@ -185,4 +188,12 @@ __device__ static inline char MsAtomicAdd(char* address, char val) {
  return __byte_perm(old, 0, address_offset);
}
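// Warp-level helpers used by the ReLUV2 mask kernels. They assume 1-D thread blocks whose size
// is a multiple of warp_size (kThreadsPerBlock = 256 satisfies this); WarpId maps a global
// element index to the 32-bit mask word that stores its predicate bit.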
__device__ __forceinline__ unsigned BallotSync(int predicate, unsigned mask = 0xffffffff) {
  return __ballot_sync(mask, predicate);
}
enum : unsigned { warp_size = 32, log_wap_size = 5 };
__device__ __forceinline__ unsigned LaneId() { return threadIdx.x & (warp_size - 1); }
__device__ __forceinline__ unsigned WarpId(const unsigned &tid) { return tid >> log_wap_size; }
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      FusedAddReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeFloat16),
                      FusedAddReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      FusedAddReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeUInt32)
                        .AddOutputAttr(kNumberTypeInt64),
                      FusedAddReluGradV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,86 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluGradV2GpuKernel : public GpuKernel {
 public:
  FusedAddReluGradV2GpuKernel() { ResetResource(); }
  ~FusedAddReluGradV2GpuKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x1 = GetDeviceAddress<T>(inputs, 0);
    auto x2 = GetDeviceAddress<T>(inputs, 1);
    auto mask = GetDeviceAddress<uint32_t>(inputs, 2);
    auto dx = GetDeviceAddress<T>(outputs, 0);
    AddReluGradV2(element_num_, x1, x2, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    // Inputs are x1, x2 (element sized) and the packed ReLU mask (one bit per element,
    // rounded up to whole uint32 words); the single output dx is element sized.
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    input_size_list_.push_back(size);
    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    input_size_list_.push_back(mask_size);
    output_size_list_.push_back(size);
  }
 private:
  size_t element_num_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
@@ -0,0 +1,50 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeUInt32),
                      FusedAddReluV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,85 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedAddReluV2GpuKernel : public GpuKernel {
 public:
  FusedAddReluV2GpuKernel() { ResetResource(); }
  ~FusedAddReluV2GpuKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x1 = GetDeviceAddress<T>(inputs, 0);
    auto x2 = GetDeviceAddress<T>(inputs, 1);
    auto y = GetDeviceAddress<T>(outputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
    AddReluV2(element_num_, x1, x2, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);
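    // Second output: the ReLU mask consumed by the fused backward kernel, one uint32 word per 32 elements.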
    size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    output_size_list_.push_back(size);
  }
 private:
  size_t element_num_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
  ReluGradV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
  ReluGradV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
  ReluGradV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReluGradV2,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
  ReluGradV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,83 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReluGradV2GpuKernel : public GpuKernel {
 public:
  ReluGradV2GpuKernel() { ResetResource(); }
  ~ReluGradV2GpuKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto dy = GetDeviceAddress<T>(inputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(inputs, 1);
    auto dx = GetDeviceAddress<T>(outputs, 0);
    ReluGradV2(element_num_, dy, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);
    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    input_size_list_.push_back(mask_size);
  }
 private:
  size_t element_num_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
  ReLUV2,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  ReluV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReluV2GpuKernel : public GpuKernel {
 public:
  ReluV2GpuKernel() { ResetResource(); }
  ~ReluV2GpuKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto x = GetDeviceAddress<T>(inputs, 0);
    auto y = GetDeviceAddress<T>(outputs, 0);
    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
    ReluV2(element_num_, x, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    MS_EXCEPTION_IF_NULL(kernel_node);
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    element_num_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    auto size = element_num_ * sizeof(T);
    input_size_list_.push_back(size);
    output_size_list_.push_back(size);
    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
    output_size_list_.push_back(mask_size);
  }
 private:
  size_t element_num_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
@@ -0,0 +1,89 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace
const BaseRef AddReluGradV2Fusion::DefinePattern() const {
  VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_});
  return relu_grad;
}
const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                              const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
  auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto users = GetRealNodeUsedList(graph, tensor_add);
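  // Fuse only when this ReluGradV2 is the sole consumer of the Add; otherwise the
  // standalone Add output is still needed elsewhere in the graph.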
  if (users->size() > 1) {
    return nullptr;
  }
  auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};
  auto add_relugrad = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_relugrad);
  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relugrad.get());
  add_relugrad->set_scope(node->scope());
  auto build_info = GenerateKernelBuildInfo(add_relugrad);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relugrad.get());
  return add_relugrad;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,42 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class AddReluGradV2Fusion : public PatternProcessPass {
 public:
  explicit AddReluGradV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_grad", multigraph) {
    x1_ = std::make_shared<Var>();
    x2_ = std::make_shared<Var>();
    mask_ = std::make_shared<Var>();
  }
  ~AddReluGradV2Fusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  VarPtr x1_;
  VarPtr x2_;
  VarPtr mask_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
@@ -0,0 +1,94 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
  std::vector<TypeId> inputs_type;
  std::vector<TypeId> outputs_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    inputs_format.push_back(kOpFormat_DEFAULT);
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
    outputs_format.push_back(kOpFormat_DEFAULT);
  }
  builder.SetInputsDeviceType(inputs_type);
  builder.SetInputsFormat(inputs_format);
  builder.SetOutputsDeviceType(outputs_type);
  builder.SetOutputsFormat(outputs_format);
  return builder.Build();
}
}  // namespace
const BaseRef AddReluV2Fusion::DefinePattern() const {
  VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})});
  return relu;
}
const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
  MS_EXCEPTION_IF_NULL(tensor_add);
  auto users = GetRealNodeUsedList(graph, tensor_add);
  if (users->size() > 1) {
    return nullptr;
  }
  auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};
  auto add_relu = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(add_relu);
  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
    types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
    shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relu.get());
  add_relu->set_scope(node->scope());
  auto build_info = GenerateKernelBuildInfo(add_relu);
  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relu.get());
  return add_relu;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class AddReluV2Fusion : public PatternProcessPass {
 public:
  explicit AddReluV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_v2_fusion", multigraph) {
    x1_ = std::make_shared<Var>();
    x2_ = std::make_shared<Var>();
  }
  ~AddReluV2Fusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  VarPtr x1_;
  VarPtr x2_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_FUSION_H_
@@ -0,0 +1,151 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/gpu/relu_v2_pass.h"
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <functional>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
const size_t kReluV2OutputNum = 2;
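// ReluGrad(dy, y) keeps the forward Relu as its second input; this pass rewrites that Relu
// into ReLUV2 and feeds the ReLUV2 mask output to a ReluGradV2 node.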
| CNodePtr GetRelu(const CNodePtr &relu_grad) { | |||||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||||
| if (relu_grad->size() != kReluGradInputNum) { | |||||
| MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); | |||||
| } | |||||
| auto relu_anf = relu_grad->input(2); | |||||
| MS_EXCEPTION_IF_NULL(relu_anf); | |||||
| return relu_anf->cast<CNodePtr>(); | |||||
| } | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||||
| std::vector<std::string> inputs_format; | |||||
| std::vector<std::string> outputs_format; | |||||
| std::vector<TypeId> inputs_type; | |||||
| std::vector<TypeId> outputs_type; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { | |||||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { | |||||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||||
| } | |||||
| builder.SetInputsDeviceType(inputs_type); | |||||
| builder.SetInputsFormat(inputs_format); | |||||
| builder.SetOutputsDeviceType(outputs_type); | |||||
| builder.SetOutputsFormat(outputs_format); | |||||
| return builder.Build(); | |||||
| } | |||||
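| // Create a ReluV2 node from Relu(x); besides y it outputs a packed uint32 mask marking the positive elements. | |||||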
| CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(relu); | |||||
| if (relu->size() != kReluInputNum) { | |||||
| MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); | |||||
| } | |||||
| if (AnfAlgo::IsDynamicShape(relu)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto prim = std::make_shared<Primitive>(kReluV2OpName); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)}; | |||||
| auto new_node = graph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | |||||
| new_node->set_scope(relu->scope()); | |||||
| std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(relu, 0); | |||||
| auto element_num = | |||||
| std::accumulate(output_shape.begin(), output_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>()); | |||||
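| // One mask bit per element, packed into 32-bit words. | |||||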
| std::vector<size_t> mask_shape = {(element_num + 31) / 32}; | |||||
| auto shapes = {output_shape, mask_shape}; | |||||
| auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32}; | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); | |||||
| auto build_info = GenerateKernelBuildInfo(new_node); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get()); | |||||
| return new_node; | |||||
| } | |||||
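| // Create ReluGradV2(dy, mask), which replaces ReluGrad(dy, y) and reuses the mask produced by ReluV2. | |||||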
| CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||||
| MS_EXCEPTION_IF_NULL(second_input); | |||||
| auto prim = std::make_shared<Primitive>(kReluGradV2OpName); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; | |||||
| auto new_node = graph->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | |||||
| new_node->set_scope(relu_grad->scope()); | |||||
| new_node->set_abstract(relu_grad->abstract()); | |||||
| std::vector<TypeId> types; | |||||
| std::vector<std::vector<size_t>> shapes; | |||||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) { | |||||
| types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i)); | |||||
| shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i)); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); | |||||
| auto build_info = GenerateKernelBuildInfo(new_node); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get()); | |||||
| return new_node; | |||||
| } | |||||
| } // namespace | |||||
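| // Match ReluGrad(dy, Relu(x)) so the forward and backward ops can be rewritten as a pair. | |||||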
| const BaseRef ReluV2Pass::DefinePattern() const { | |||||
| VectorRef relu_grad({prim::kPrimReluGrad, dy_, VectorRef({prim::kPrimRelu, x_})}); | |||||
| return relu_grad; | |||||
| } | |||||
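| // Rewrite the matched pair: Relu -> ReluV2 (y, mask) and ReluGrad -> ReluGradV2(dy, mask). | |||||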
| const AnfNodePtr ReluV2Pass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto relu_grad = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(relu_grad); | |||||
| auto relu = GetRelu(relu_grad); | |||||
| MS_EXCEPTION_IF_NULL(relu); | |||||
| auto relu_v2 = CreateReluV2(graph, relu); | |||||
| if (relu_v2 == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| std::vector<AnfNodePtr> relu_v2_node_outputs; | |||||
| CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); | |||||
| auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); | |||||
| auto manage = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manage); | |||||
| manage->Replace(relu, relu_v2_node_outputs[0]); | |||||
| return relu_grad_v2; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_ | |||||
| #include <memory> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ReluV2Pass : public PatternProcessPass { | |||||
| public: | |||||
| explicit ReluV2Pass(bool multigraph = true) : PatternProcessPass("relu_v2_fusion", multigraph) { | |||||
| x_ = std::make_shared<Var>(); | |||||
| dy_ = std::make_shared<Var>(); | |||||
| } | |||||
| ~ReluV2Pass() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| VarPtr x_; | |||||
| VarPtr dy_; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_ | |||||
| @@ -35,6 +35,9 @@ | |||||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | #include "backend/optimizer/gpu/remove_format_transform_pair.h" | ||||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | ||||
| #include "backend/optimizer/gpu/reduce_precision_fusion.h" | #include "backend/optimizer/gpu/reduce_precision_fusion.h" | ||||
| #include "backend/optimizer/gpu/relu_v2_pass.h" | |||||
| #include "backend/optimizer/gpu/add_relu_v2_fusion.h" | |||||
| #include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h" | |||||
| #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h" | ||||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | ||||
| #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | #include "backend/optimizer/graph_kernel/basic_ops_fusion.h" | ||||
| @@ -142,6 +145,9 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra | |||||
| pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>()); | ||||
| pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>()); | ||||
| pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>()); | pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>()); | ||||
| pm->AddPass(std::make_shared<opt::ReluV2Pass>()); | |||||
| pm->AddPass(std::make_shared<opt::AddReluV2Fusion>()); | |||||
| pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>()); | |||||
| pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | pm->AddPass(std::make_shared<opt::AllReduceFusion>()); | ||||
| pm->AddPass(std::make_shared<opt::GetitemTuple>()); | pm->AddPass(std::make_shared<opt::GetitemTuple>()); | ||||
| pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision")); | ||||
| @@ -245,6 +245,8 @@ constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad"; | |||||
| constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2"; | constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2"; | ||||
| constexpr auto kMatMulV2OpName = "MatMulV2"; | constexpr auto kMatMulV2OpName = "MatMulV2"; | ||||
| constexpr auto kBroadcastToOpName = "BroadcastTo"; | constexpr auto kBroadcastToOpName = "BroadcastTo"; | ||||
| constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; | |||||
| constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2"; | |||||
| // Hcom Op Type | // Hcom Op Type | ||||
| constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | ||||
| @@ -146,6 +146,7 @@ inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive | |||||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | ||||
| inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | ||||
| inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad"); | ||||
| inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2"); | |||||
| inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad"); | inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad"); | ||||
| inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput"); | ||||
| inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter"); | ||||
| @@ -0,0 +1,142 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
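| """Tests exercising the ReLU patterns targeted by the GPU ReluV2 and fused Add+ReLU optimizer passes.""" | |||||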
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.ops.operations._grad_ops as G | |||||
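| # ReLU forward followed by ReluGrad: the ReluGrad(dy, Relu(x)) pattern rewritten by ReluV2Pass. | |||||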
| class ReluNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ReluNet, self).__init__() | |||||
| self.relu = P.ReLU() | |||||
| self.relu_grad = G.ReluGrad() | |||||
| def construct(self, x, dy): | |||||
| y = self.relu(x) | |||||
| dx = self.relu_grad(dy, y) | |||||
| return y, dx | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ReluV2(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| x = Tensor(np.array([[[[-1, 1, 10], | |||||
| [1, -1, 1], | |||||
| [10, 1, -1]]]]).astype(np.float32)) | |||||
| dy = Tensor(np.array([[[[1, 0, 3], | |||||
| [0, 1, 0], | |||||
| [2, 1, 1]]]]).astype(np.float32)) | |||||
| expect_y = np.array([[[[0, 1, 10], | |||||
| [1, 0, 1], | |||||
| [10, 1, 0]]]]).astype(np.float32) | |||||
| expect_dx = np.array([[[[0, 0, 3], | |||||
| [0, 0, 0], | |||||
| [2, 1, 0]]]]).astype(np.float32) | |||||
| net = ReluNet() | |||||
| y, dx = net(x, dy) | |||||
| assert np.allclose(y.asnumpy(), expect_y) | |||||
| assert np.allclose(dx.asnumpy(), expect_dx) | |||||
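| # TensorAdd feeding ReLU (and its ReluGrad): intended to exercise the AddReluV2Fusion pass. | |||||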
| class AddReluNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(AddReluNet, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| self.relu = P.ReLU() | |||||
| self.relu_grad = G.ReluGrad() | |||||
| def construct(self, x1, x2, dy): | |||||
| y = self.add(x1, x2) | |||||
| y = self.relu(y) | |||||
| dx = self.relu_grad(dy, y) | |||||
| return y, dx | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_AddRelu(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| x1 = Tensor(np.array([[[[-1, 1, 10], | |||||
| [1, -1, 1], | |||||
| [10, 1, -1]]]]).astype(np.float32)) | |||||
| x2 = Tensor(np.array([[[[-1, 1, 10], | |||||
| [1, -1, 1], | |||||
| [10, 1, -1]]]]).astype(np.float32)) | |||||
| dy = Tensor(np.array([[[[1, 0, 3], | |||||
| [0, 1, 0], | |||||
| [2, 1, 1]]]]).astype(np.float32)) | |||||
| expect_y = np.array([[[[0, 2, 20], | |||||
| [2, 0, 2], | |||||
| [20, 2, 0]]]]).astype(np.float32) | |||||
| expect_dx = np.array([[[[0, 0, 3], | |||||
| [0, 0, 0], | |||||
| [2, 1, 0]]]]).astype(np.float32) | |||||
| net = AddReluNet() | |||||
| y, dx = net(x1, x2, dy) | |||||
| assert np.allclose(y.asnumpy(), expect_y) | |||||
| assert np.allclose(dx.asnumpy(), expect_dx) | |||||
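| # TensorAdd feeding ReluGrad's dy input: intended to exercise the AddReluGradV2Fusion pass. | |||||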
| class AddReluGradNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(AddReluGradNet, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| self.relu = P.ReLU() | |||||
| self.relu_grad = G.ReluGrad() | |||||
| def construct(self, x, dy1, dy2): | |||||
| y = self.relu(x) | |||||
| dy = self.add(dy1, dy2) | |||||
| dx = self.relu_grad(dy, y) | |||||
| return y, dx | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_AddReluGrad(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| x = Tensor(np.array([[[[-1, 1, 10], | |||||
| [1, -1, 1], | |||||
| [10, 1, -1]]]]).astype(np.float32)) | |||||
| dy1 = Tensor(np.array([[[[1, 0, 3], | |||||
| [0, 1, 0], | |||||
| [2, 1, 1]]]]).astype(np.float32)) | |||||
| dy2 = Tensor(np.array([[[[1, 0, 3], | |||||
| [0, 1, 0], | |||||
| [2, 1, 1]]]]).astype(np.float32)) | |||||
| expect_y = np.array([[[[0, 1, 10], | |||||
| [1, 0, 1], | |||||
| [10, 1, 0]]]]).astype(np.float32) | |||||
| expect_dx = np.array([[[[0, 0, 6], | |||||
| [0, 0, 0], | |||||
| [4, 2, 0]]]]).astype(np.float32) | |||||
| net = AddReluGradNet() | |||||
| y, dx = net(x, dy1, dy2) | |||||
| assert np.allclose(y.asnumpy(), expect_y) | |||||
| assert np.allclose(dx.asnumpy(), expect_dx) | |||||