diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cu
new file mode 100644
index 0000000000..65ddef09ed
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cu
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+
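+// The forward kernels below pack the ReLU predicate into `mask`, one bit per
+// element: lane k of a warp handles element i = 32 * WarpId(i) + k, the warp
+// ballots the per-lane predicates into a 32-bit word, and lane 0 writes that
+// word to mask[WarpId(i)]. This relies on consecutive elements staying on
+// consecutive lanes, which holds because both the block size and the
+// grid-stride (gridDim.x * blockDim.x) are multiples of the warp size.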
+template <typename T>
+__global__ void AddReluV2Kernel(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
+    T sum = x1[i] + x2[i];
+    bool p = sum > static_cast<T>(0);
+    y[i] = p ? sum : static_cast<T>(0);
+
+    auto warp_predict = BallotSync(p, __activemask());
+    if (LaneId() == 0) {
+      mask[WarpId(i)] = warp_predict;
+    }
+  }
+}
+
+template <typename T>
+void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
+  AddReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, y, mask);
+}
+
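+// The backward kernel sums the two incoming gradients and zeroes the result
+// wherever the forward mask bit is clear, so the forward output y never has
+// to be re-read here.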
+template <typename T>
+__global__ void AddReluGradV2Kernel(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
+    bool positive = mask[WarpId(i)] & (1 << LaneId());
+    dx[i] = positive ? x1[i] + x2[i] : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
+  AddReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x1, x2, mask, dx);
+}
+
+template void AddReluV2(const size_t num, const float *x1, const float *x2, float *y, uint32_t *mask,
+                        cudaStream_t cuda_stream);
+template void AddReluV2(const size_t num, const half *x1, const half *x2, half *y, uint32_t *mask,
+                        cudaStream_t cuda_stream);
+template void AddReluV2(const size_t num, const int32_t *x1, const int32_t *x2, int32_t *y, uint32_t *mask,
+                        cudaStream_t cuda_stream);
+template void AddReluV2(const size_t num, const int64_t *x1, const int64_t *x2, int64_t *y, uint32_t *mask,
+                        cudaStream_t cuda_stream);
+
+template void AddReluGradV2(const size_t num, const float *x1, const float *x2, const uint32_t *mask, float *dx,
+                            cudaStream_t cuda_stream);
+template void AddReluGradV2(const size_t num, const half *x1, const half *x2, const uint32_t *mask, half *dx,
+                            cudaStream_t cuda_stream);
+template void AddReluGradV2(const size_t num, const int32_t *x1, const int32_t *x2, const uint32_t *mask, int32_t *dx,
+                            cudaStream_t cuda_stream);
+template void AddReluGradV2(const size_t num, const int64_t *x1, const int64_t *x2, const uint32_t *mask, int64_t *dx,
+                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh
new file mode 100644
index 0000000000..6eef4e0863
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void AddReluV2(const size_t num, const T *x1, const T *x2, T *y, uint32_t *mask, cudaStream_t cuda_stream);
+
+template <typename T>
+void AddReluGradV2(const size_t num, const T *x1, const T *x2, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADD_RELU_V2_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
index 6f3a3ad634..57c2bcdee9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
@@ -15,6 +15,7 @@
  */
 
 #include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
@@ -34,3 +35,47 @@ template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
+
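+// ReluV2 is the mask-producing variant of ReLU: besides y it emits one uint32
+// mask word per 32 elements (callers allocate (num + 31) / 32 words), which
+// ReluGradV2 consumes in the backward pass instead of re-reading y.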
+template <typename T>
+__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
+    T v = x[i];
+    bool p = v > static_cast<T>(0);
+    y[i] = p ? v : static_cast<T>(0);
+
+    auto warp_predict = BallotSync(p, __activemask());
+    if (LaneId() == 0) {
+      mask[WarpId(i)] = warp_predict;
+    }
+  }
+}
+
+template <typename T>
+void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream) {
+  ReluV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, x, y, mask);
+}
+
+template <typename T>
+__global__ void ReluGradV2Kernel(const size_t num, const T *dy, const uint32_t *mask, T *dx) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) {
+    bool p = mask[WarpId(i)] & (1 << LaneId());
+    dx[i] = p ? dy[i] : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream) {
+  ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
+}
+
+template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
+
+template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
+                         cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
+                         cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh
index 19e1022479..7918395f6f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh
@@ -20,4 +20,9 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
 void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
+
+template <typename T>
+void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream);
+template <typename T>
+void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
index 27f1f9cfd9..9d45959f3c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh
@@ -20,16 +20,18 @@
 #include <cuda_fp16.h>
 #include "runtime/device/gpu/cuda_common.h"
 
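+// Launch configuration shared by the (Add)ReluV2 kernels: 256-thread blocks
+// and enough blocks to cover n elements. The grid-stride loops in the kernels
+// still cover every element if the runtime clamps the grid size.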
+#define kThreadsPerBlock (256)
+#define kBlocksPerGrid(n) ((n + kThreadsPerBlock - 1) / kThreadsPerBlock)
+
 __device__ static inline double MsAtomicAdd(double *address, const double val) {
-  unsigned long long int* address_as_ull = (unsigned long long int*)address;  // NOLINT
-  unsigned long long int old = *address_as_ull;  // NOLINT
-  unsigned long long int assumed;  // NOLINT
-  do {
-    assumed = old;
-    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
-  }
-  while (assumed != old);  // NOLINT
-  return __longlong_as_double(old);
+  unsigned long long int *address_as_ull = (unsigned long long int *)address;  // NOLINT
+  unsigned long long int old = *address_as_ull;                                // NOLINT
+  unsigned long long int assumed;                                              // NOLINT
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);  // NOLINT
+  return __longlong_as_double(old);
 }
 
 __device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }
@@ -42,7 +44,7 @@ __device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
 
 __device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
   size_t offset = (size_t)address & 3;
-  uint32_t * address_as_ui = (uint32_t *)((char *)address - offset);  // NOLINT
+  uint32_t *address_as_ui = (uint32_t *)((char *)address - offset);  // NOLINT
   uint32_t old = *address_as_ui;
   uint32_t shift = offset * 8;
   uint32_t old_byte;
@@ -60,27 +62,27 @@ __device__ static inline int8_t MsAtomicAdd(int8_t *address, int8_t val) {
 }
 
 __device__ static inline int64_t MsAtomicAdd(int64_t *address, int64_t val) {
-  unsigned long long * address_as_ui = (unsigned long long *) (address);  // NOLINT
-  unsigned long long old = *address_as_ui;  // NOLINT
-  unsigned long long newval;  // NOLINT
-  unsigned long long assumed;  // NOLINT
+  unsigned long long *address_as_ui = (unsigned long long *)(address);  // NOLINT
+  unsigned long long old = *address_as_ui;                              // NOLINT
+  unsigned long long newval;                                            // NOLINT
+  unsigned long long assumed;                                           // NOLINT
   do {
     assumed = old;
-    newval = val + (int64_t)old;
+    newval = val + (int64_t)old;
     old = atomicCAS(address_as_ui, assumed, newval);
   } while (assumed != old);
   return (int64_t)old;
 }
 
 __device__ static inline bool MsAtomicAdd(bool *address, bool val) {
-  *address = address && val;
-  return address[0];
+  *address = *address && val;
+  return address[0];
 }
 
 __device__ static inline unsigned char MsAtomicAdd(short *address, short val) {  // NOLINT
-  bool is_4_byte_aligned = ((size_t) address & 2) == 0;
-  unsigned int *aligned = (unsigned int *) ((size_t) address & ~2);
+  bool is_4_byte_aligned = ((size_t)address & 2) == 0;
+  unsigned int *aligned = (unsigned int *)((size_t)address & ~2);
   unsigned int old = *aligned;
   unsigned int assumed;
@@ -91,16 +93,16 @@ __device__ static inline unsigned char MsAtomicAdd(short *address, short val) {  // NOLINT
     if (is_4_byte_aligned) {
       replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
     } else {
-      replacement = old + ((unsigned int) val << 16);
+      replacement = old + ((unsigned int)val << 16);
     }
 
     old = atomicCAS(aligned, assumed, replacement);
   } while (assumed != old);
 
   if (is_4_byte_aligned) {
-    return (short) (old & 0xffff);  // NOLINT
+    return (short)(old & 0xffff);  // NOLINT
   } else {
-    return (short) (old >> 16);  // NOLINT
+    return (short)(old >> 16);  // NOLINT
   }
 }
 
@@ -112,7 +114,8 @@ __device__ static inline half MsAtomicAdd(half *address, half val) {
   unsigned short old_as_us;  // NOLINT
   do {
     assumed = old;
-    old_as_us = static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff);  // NOLINT
+    old_as_us =
+      static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff);  // NOLINT
     half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast<float>(val));
     unsigned short sum_as_us = __half_as_ushort(sum);  // NOLINT
     unsigned int sum_as_ui =
@@ -123,16 +126,16 @@
   return half(raw);
 }
 
-__device__ static inline unsigned char MsAtomicAdd(unsigned char* address, unsigned char val) {
+__device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
   // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
   // implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
   // unsigned int* must be 4 byte aligned. This variable contains the offset,
   // in bytes, of the beginning of address, within the 4 byte aligned space that
   // contains it.
-  size_t address_offset = (size_t) address & 3;
+  size_t address_offset = (size_t)address & 3;
 
   // Address of the 4 byte aligned space that contains address.
-  unsigned int* aligned = (unsigned int*) ((unsigned char*) address - address_offset);
+  unsigned int *aligned = (unsigned int *)((unsigned char *)address - address_offset);
 
   // Constants which will be used later with __byte_perm. __byte_perm is a cuda
   // function which takes 3 unsigned int's (x, y, selector) as parameters and
@@ -166,9 +169,9 @@ __device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
   return __byte_perm(old, 0, address_offset);
 }
 
-__device__ static inline char MsAtomicAdd(char* address, char val) {
-  size_t address_offset = (size_t) address & 3;
-  unsigned int* aligned = reinterpret_cast<unsigned int*>(reinterpret_cast<size_t>(address) - address_offset);
+__device__ static inline char MsAtomicAdd(char *address, char val) {
+  size_t address_offset = (size_t)address & 3;
+  unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - address_offset);
   unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
   unsigned int selector = selectors[address_offset];
   unsigned int old = *aligned;
@@ -185,4 +188,12 @@ __device__ static inline char MsAtomicAdd(char *address, char val) {
   return __byte_perm(old, 0, address_offset);
 }
 
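+// Warp helpers used by the ReluV2 kernels. LaneId assumes 1-D thread blocks
+// whose size is a multiple of 32, so threadIdx.x mod 32 is the hardware lane;
+// WarpId maps a linear element index to the mask word that stores its bit.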
+__device__ __forceinline__ unsigned BallotSync(int predicate, unsigned mask = 0xffffffff) {
+  return __ballot_sync(mask, predicate);
+}
+
+enum : unsigned { warp_size = 32, log_warp_size = 5 };
+__device__ __forceinline__ unsigned LaneId() { return threadIdx.x & (warp_size - 1); }
+__device__ __forceinline__ unsigned WarpId(const unsigned &tid) { return tid >> log_warp_size; }
+
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.cc
new file mode 100644
index 0000000000..c937365e54
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeUInt32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedAddReluGradV2GpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeUInt32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      FusedAddReluGradV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeUInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      FusedAddReluGradV2GpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluGradV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeUInt32)
+                        .AddOutputAttr(kNumberTypeInt64),
+                      FusedAddReluGradV2GpuKernel, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h
new file mode 100644
index 0000000000..85424c99da
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_grad_v2_gpu_kernel.h
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
+
+#include <vector>
+#include <numeric>
+#include <functional>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class FusedAddReluGradV2GpuKernel : public GpuKernel {
+ public:
+  FusedAddReluGradV2GpuKernel() { ResetResource(); }
+  ~FusedAddReluGradV2GpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto x1 = GetDeviceAddress<T>(inputs, 0);
+    auto x2 = GetDeviceAddress<T>(inputs, 1);
+    auto mask = GetDeviceAddress<uint32_t>(inputs, 2);
+    auto dx = GetDeviceAddress<T>(outputs, 0);
+
+    AddReluGradV2(element_num_, x1, x2, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    element_num_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    auto size = element_num_ * sizeof(T);
+    input_size_list_.push_back(size);
+    input_size_list_.push_back(size);
+    output_size_list_.push_back(size);
+
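+    // The mask input packs one activation bit per element into 32-bit words,
+    // so its size is element_num_ rounded up to a multiple of 32 bits.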
+    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
+    input_size_list_.push_back(mask_size);
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_GRAD_V2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.cc
new file mode 100644
index 0000000000..ff486c0c46
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeUInt32),
+                      FusedAddReluV2GpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeUInt32),
+                      FusedAddReluV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeUInt32),
+                      FusedAddReluV2GpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(FusedAddReluV2,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddInputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeInt64)
+                        .AddOutputAttr(kNumberTypeUInt32),
+                      FusedAddReluV2GpuKernel, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h
new file mode 100644
index 0000000000..a9549bde9c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_add_relu_v2_gpu_kernel.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
+
+#include <vector>
+#include <numeric>
+#include <functional>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/add_relu_v2_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class FusedAddReluV2GpuKernel : public GpuKernel {
+ public:
+  FusedAddReluV2GpuKernel() { ResetResource(); }
+  ~FusedAddReluV2GpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto x1 = GetDeviceAddress<T>(inputs, 0);
+    auto x2 = GetDeviceAddress<T>(inputs, 1);
+    auto y = GetDeviceAddress<T>(outputs, 0);
+    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
+
+    AddReluV2(element_num_, x1, x2, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    element_num_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    auto size = element_num_ * sizeof(T);
+    input_size_list_.push_back(size);
+    input_size_list_.push_back(size);
+    output_size_list_.push_back(size);
+
+    size = (element_num_ + 31) / 32 * sizeof(uint32_t);
+    output_size_list_.push_back(size);
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_ADD_RELU_V2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
new file mode 100644
index 0000000000..2739eac1e2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
+  ReluGradV2GpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
+  ReluGradV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
+  ReluGradV2GpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
+  ReluGradV2GpuKernel, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h
new file mode 100644
index 0000000000..8f64020854
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.h
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
+
+#include <vector>
+#include <numeric>
+#include <functional>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ReluGradV2GpuKernel : public GpuKernel {
+ public:
+  ReluGradV2GpuKernel() { ResetResource(); }
+  ~ReluGradV2GpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto dy = GetDeviceAddress<T>(inputs, 0);
+    auto mask = GetDeviceAddress<uint32_t>(inputs, 1);
+    auto dx = GetDeviceAddress<T>(outputs, 0);
+
+    ReluGradV2(element_num_, dy, mask, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    element_num_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    auto size = element_num_ * sizeof(T);
+    input_size_list_.push_back(size);
+    output_size_list_.push_back(size);
+
+    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
+    input_size_list_.push_back(mask_size);
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_V2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
new file mode 100644
index 0000000000..566c27f252
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2,
+  KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
new file mode 100644
index 0000000000..c6cf25a2f4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
+
+#include <vector>
+#include <numeric>
+#include <functional>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ReluV2GpuKernel : public GpuKernel {
+ public:
+  ReluV2GpuKernel() { ResetResource(); }
+  ~ReluV2GpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto x = GetDeviceAddress<T>(inputs, 0);
+    auto y = GetDeviceAddress<T>(outputs, 0);
+    auto mask = GetDeviceAddress<uint32_t>(outputs, 1);
+
+    ReluV2(element_num_, x, y, mask, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    element_num_ = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    element_num_ = 0;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    auto size = element_num_ * sizeof(T);
+    input_size_list_.push_back(size);
+    output_size_list_.push_back(size);
+    auto mask_size = (element_num_ + 31) / 32 * sizeof(uint32_t);
+    output_size_list_.push_back(mask_size);
+  }
+
+ private:
+  size_t element_num_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc
new file mode 100644
index 0000000000..c7e05c6b9c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+}  // namespace
+
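+// Rewrites ReluGradV2(TensorAdd(x1, x2), mask) into FusedAddReluGradV2(x1, x2, mask).
+// Process() bails out when the TensorAdd result has other users, since fusing
+// it away would drop a value that is still needed elsewhere in the graph.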
+const BaseRef AddReluGradV2Fusion::DefinePattern() const {
+  VectorRef relu_grad = VectorRef({prim::kPrimReluGradV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_}), mask_});
+  return relu_grad;
+}
+
+const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
+  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
+  auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
+
+  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
+  MS_EXCEPTION_IF_NULL(tensor_add);
+  auto users = GetRealNodeUsedList(graph, tensor_add);
+  if (users->size() > 1) {
+    return nullptr;
+  }
+
+  auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};
+  auto add_relugrad = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(add_relugrad);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relugrad.get());
+  add_relugrad->set_scope(node->scope());
+
+  auto build_info = GenerateKernelBuildInfo(add_relugrad);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relugrad.get());
+  return add_relugrad;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.h
new file mode 100644
index 0000000000..b41f35dfee
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AddReluGradV2Fusion : public PatternProcessPass {
+ public:
+  explicit AddReluGradV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_grad", multigraph) {
+    x1_ = std::make_shared<Var>();
+    x2_ = std::make_shared<Var>();
+    mask_ = std::make_shared<Var>();
+  }
+  ~AddReluGradV2Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr x1_;
+  VarPtr x2_;
+  VarPtr mask_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_GRAD_V2_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc
new file mode 100644
index 0000000000..f78f0d87aa
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+}  // namespace
+
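+// Rewrites ReluV2(TensorAdd(x1, x2)) into FusedAddReluV2(x1, x2), which emits
+// both the activation and its packed bitmask in a single kernel launch.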
+const BaseRef AddReluV2Fusion::DefinePattern() const {
+  VectorRef relu = VectorRef({prim::kPrimReluV2, VectorRef({prim::kPrimTensorAdd, x1_, x2_})});
+  return relu;
+}
+
+const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                          const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto x1 = utils::cast<AnfNodePtr>((*equiv)[x1_]);
+  auto x2 = utils::cast<AnfNodePtr>((*equiv)[x2_]);
+
+  auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
+  MS_EXCEPTION_IF_NULL(tensor_add);
+  auto users = GetRealNodeUsedList(graph, tensor_add);
+  if (users->size() > 1) {
+    return nullptr;
+  }
+
+  auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};
+  auto add_relu = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(add_relu);
+
+  std::vector<TypeId> types;
+  std::vector<std::vector<size_t>> shapes;
+  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
+    types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
+    shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
+  }
+
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, add_relu.get());
+  add_relu->set_scope(node->scope());
+
+  auto build_info = GenerateKernelBuildInfo(add_relu);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, add_relu.get());
+  return add_relu;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.h
new file mode 100644
index 0000000000..c0c3465447
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AddReluV2Fusion : public PatternProcessPass {
+ public:
+  explicit AddReluV2Fusion(bool multigraph = true) : PatternProcessPass("add_relu_v2_fusion", multigraph) {
+    x1_ = std::make_shared<Var>();
+    x2_ = std::make_shared<Var>();
+  }
+  ~AddReluV2Fusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr x1_;
+  VarPtr x2_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_ADD_RELU_V2_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc
new file mode 100644
index 0000000000..c0094d827a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc
@@ -0,0 +1,151 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/relu_v2_pass.h"
+#include <memory>
+#include <vector>
+#include <string>
+#include <numeric>
+#include <functional>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "abstract/abstract_value.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+const size_t kReluV2OutputNum = 2;
+
+CNodePtr GetRelu(const CNodePtr &relu_grad) {
+  MS_EXCEPTION_IF_NULL(relu_grad);
+  if (relu_grad->size() != kReluGradInputNum) {
+    MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
+  }
+  auto relu_anf = relu_grad->input(2);
+  MS_EXCEPTION_IF_NULL(relu_anf);
+  return relu_anf->cast<CNodePtr>();
+}
+
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+
+CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(relu);
+  if (relu->size() != kReluInputNum) {
+    MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
+  }
+
+  auto prim = std::make_shared<Primitive>(kReluV2OpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(relu->scope());
+
+  if (AnfAlgo::IsDynamicShape(relu)) {
+    return nullptr;
+  }
+  std::vector<size_t> output_shape = AnfAlgo::GetOutputInferShape(relu, 0);
+  auto element_num =
+    std::accumulate(output_shape.begin(), output_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+
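+  // The second ReluV2 output is the packed activation bitmask: one uint32
+  // word per 32 elements of the (flattened) relu output.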
+  std::vector<size_t> mask_shape = {(element_num + 31) / 32};
+  auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape};
+  auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), kNumberTypeUInt32};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
+
+  auto build_info = GenerateKernelBuildInfo(new_node);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());
+  return new_node;
+}
+
+CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(relu_grad);
+  MS_EXCEPTION_IF_NULL(second_input);
+
+  auto prim = std::make_shared<Primitive>(kReluGradV2OpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu_grad->input(1), second_input};
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(relu_grad->scope());
+  new_node->set_abstract(relu_grad->abstract());
+
+  std::vector<TypeId> types;
+  std::vector<std::vector<size_t>> shapes;
+  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) {
+    types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i));
+    shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i));
+  }
+
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get());
+
+  auto build_info = GenerateKernelBuildInfo(new_node);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, new_node.get());
+
+  return new_node;
+}
+}  // namespace
+
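+// Splits the backward pattern ReluGrad(dy, Relu(x)) into ReluV2(x) -> (y, mask)
+// and ReluGradV2(dy, mask): the forward pass computes the bitmask once, and
+// the backward pass reads the compact mask instead of the full output y.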
+const BaseRef ReluV2Pass::DefinePattern() const {
+  VectorRef relu_grad({prim::kPrimReluGrad, dy_, VectorRef({prim::kPrimRelu, x_})});
+  return relu_grad;
+}
+
+const AnfNodePtr ReluV2Pass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto relu_grad = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(relu_grad);
+  auto relu = GetRelu(relu_grad);
+  MS_EXCEPTION_IF_NULL(relu);
+
+  auto relu_v2 = CreateReluV2(graph, relu);
+  if (relu_v2 == nullptr) {
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> relu_v2_node_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);
+
+  auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]);
+  auto manage = graph->manager();
+  MS_EXCEPTION_IF_NULL(manage);
+  manage->Replace(relu, relu_v2_node_outputs[0]);
+  return relu_grad_v2;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.h b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.h
new file mode 100644
index 0000000000..abaa51d90c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ReluV2Pass : public PatternProcessPass {
+ public:
+  explicit ReluV2Pass(bool multigraph = true) : PatternProcessPass("relu_v2_fusion", multigraph) {
+    x_ = std::make_shared<Var>();
+    dy_ = std::make_shared<Var>();
+  }
+  ~ReluV2Pass() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr x_;
+  VarPtr dy_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_RELU_V2_PASS_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 994c5c5f6d..95418386f9 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -35,6 +35,9 @@
 #include "backend/optimizer/gpu/remove_format_transform_pair.h"
 #include "backend/optimizer/gpu/remove_redundant_format_transform.h"
 #include "backend/optimizer/gpu/reduce_precision_fusion.h"
+#include "backend/optimizer/gpu/relu_v2_pass.h"
+#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
+#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
 #include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
 #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
@@ -142,6 +145,9 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
   pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
   pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
+  pm->AddPass(std::make_shared<opt::ReluV2Pass>());
+  pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
+  pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());
   pm->AddPass(std::make_shared<opt::AllReduceFusion>());
   pm->AddPass(std::make_shared<opt::GetitemTuple>());
   pm->AddPass(std::make_shared<opt::ReducePrecisionFusion>("reduce_precision"));
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index dd50c95c53..aa7d62277b 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -245,6 +245,8 @@ constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
 constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
 constexpr auto kMatMulV2OpName = "MatMulV2";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
+constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
+constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2";
 
 // Hcom Op Type
 constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 8f433a324c..3e0017c3f9 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -146,6 +146,7 @@ inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
 inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
 inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
 inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
+inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2");
 inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad");
 inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
 inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
diff --git a/tests/st/ops/gpu/test_relu_v2.py b/tests/st/ops/gpu/test_relu_v2.py
new file mode 100644
index 0000000000..cefc3007b5
--- /dev/null
+++ b/tests/st/ops/gpu/test_relu_v2.py
@@ -0,0 +1,142 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.ops.operations._grad_ops as G
+
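+# The nets below are written to match the GPU fusion patterns added in this
+# change (ReluV2Pass, AddReluV2Fusion, AddReluGradV2Fusion); the expected
+# values check that fusion leaves the numerical results unchanged.
+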
+
+class ReluNet(nn.Cell):
+    def __init__(self):
+        super(ReluNet, self).__init__()
+        self.relu = P.ReLU()
+        self.relu_grad = G.ReluGrad()
+
+    def construct(self, x, dy):
+        y = self.relu(x)
+        dx = self.relu_grad(dy, y)
+        return y, dx
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_ReluV2():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[1, 0, 3],
+                            [0, 1, 0],
+                            [2, 1, 1]]]]).astype(np.float32))
+    expect_y = np.array([[[[0, 1, 10],
+                           [1, 0, 1],
+                           [10, 1, 0]]]]).astype(np.float32)
+    expect_dx = np.array([[[[0, 0, 3],
+                            [0, 0, 0],
+                            [2, 1, 0]]]]).astype(np.float32)
+    net = ReluNet()
+    y, dx = net(Tensor(x), Tensor(dy))
+
+    assert np.allclose(y.asnumpy(), expect_y)
+    assert np.allclose(dx.asnumpy(), expect_dx)
+
+
+class AddReluNet(nn.Cell):
+    def __init__(self):
+        super(AddReluNet, self).__init__()
+        self.add = P.TensorAdd()
+        self.relu = P.ReLU()
+        self.relu_grad = G.ReluGrad()
+
+    def construct(self, x1, x2, dy):
+        y = self.add(x1, x2)
+        y = self.relu(y)
+        dx = self.relu_grad(dy, y)
+        return y, dx
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_AddRelu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    x1 = Tensor(np.array([[[[-1, 1, 10],
+                            [1, -1, 1],
+                            [10, 1, -1]]]]).astype(np.float32))
+    x2 = Tensor(np.array([[[[-1, 1, 10],
+                            [1, -1, 1],
+                            [10, 1, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[1, 0, 3],
+                            [0, 1, 0],
+                            [2, 1, 1]]]]).astype(np.float32))
+    expect_y = np.array([[[[0, 2, 20],
+                           [2, 0, 2],
+                           [20, 2, 0]]]]).astype(np.float32)
+    expect_dx = np.array([[[[0, 0, 3],
+                            [0, 0, 0],
+                            [2, 1, 0]]]]).astype(np.float32)
+    net = AddReluNet()
+    y, dx1 = net(Tensor(x1), Tensor(x2), Tensor(dy))
+
+    assert np.allclose(y.asnumpy(), expect_y)
+    assert np.allclose(dx1.asnumpy(), expect_dx)
+
+
+class AddReluGradNet(nn.Cell):
+    def __init__(self):
+        super(AddReluGradNet, self).__init__()
+        self.add = P.TensorAdd()
+        self.relu = P.ReLU()
+        self.relu_grad = G.ReluGrad()
+
+    def construct(self, x, dy1, dy2):
+        y = self.relu(x)
+        dy = self.add(dy1, dy2)
+        dx = self.relu_grad(dy, y)
+        return y, dx
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_AddReluGrad():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.float32))
+    dy1 = Tensor(np.array([[[[1, 0, 3],
+                             [0, 1, 0],
+                             [2, 1, 1]]]]).astype(np.float32))
+    dy2 = Tensor(np.array([[[[1, 0, 3],
+                             [0, 1, 0],
+                             [2, 1, 1]]]]).astype(np.float32))
+    expect_y = np.array([[[[0, 1, 10],
+                           [1, 0, 1],
+                           [10, 1, 0]]]]).astype(np.float32)
+    expect_dx = np.array([[[[0, 0, 6],
+                            [0, 0, 0],
+                            [4, 2, 0]]]]).astype(np.float32)
+    net = AddReluGradNet()
+    y, dx1 = net(Tensor(x), Tensor(dy1), Tensor(dy2))
+
+    assert np.allclose(y.asnumpy(), expect_y)
+    assert np.allclose(dx1.asnumpy(), expect_dx)