| @@ -26,3 +26,7 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad | |||
| from .mean import SimpleMean, gpu_schedule_SimpleMean | |||
| from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad | |||
| from .mul import Mul, gpu_schedule_Mul | |||
| from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid | |||
| from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad | |||
| from .hswish import Hswish, gpu_schedule_Hswish | |||
| from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """hsigmoid""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| from _akg.topi import tag | |||
| @tvm.tag_scope(tag=tag.ELEMWISE) | |||
| def topi_nn_hsigmoid(x): | |||
| """ | |||
| topi hsigmoid | |||
| Args: | |||
| x (tvm.tensor.Tensor): input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, hard sigmoid of the input. | |||
| """ | |||
| return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, | |||
| tvm.if_then_else(x(*i) >= 3, 1, | |||
| (x(*i) + 3) / 6))) | |||
| def Hsigmoid(x): | |||
| """ | |||
| Hsigmoid | |||
| Args: | |||
| x (tvm.tensor.Tensor): input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, hard sigmoid of the input. | |||
| """ | |||
| return topi_nn_hsigmoid(x) | |||
| def gpu_schedule_Hsigmoid(outs): | |||
| """ | |||
| gpu schedule Hsigmoid | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of the operator. | |||
| Returns: | |||
| sch: the created CUDA schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with tvm.target.create(device): | |||
| sch = topi.cuda.schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Hsigmoid grad""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| def HsigmoidGrad(y_grad, x): | |||
| """ | |||
| HsigmoidGrad | |||
| Args: | |||
| y_grad (tvm.tensor.Tensor): gradient of the Hsigmoid output. | |||
| x (tvm.tensor.Tensor): forward input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, gradient with respect to x. | |||
| """ | |||
| return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, | |||
| tvm.if_then_else(x(*i) >= 3, 0, | |||
| y_grad(*i) / 6))) | |||
| def gpu_schedule_HsigmoidGrad(outs): | |||
| """ | |||
| gpu schedule HsigmoidGrad | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of the operator. | |||
| Returns: | |||
| sch: the created CUDA schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with tvm.target.create(device): | |||
| sch = topi.cuda.schedule_elemwise(outs) | |||
| return sch | |||
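For reference, a minimal NumPy sketch of the piecewise math the Hsigmoid forward and backward ops above implement; the function names and sample values are illustrative, not part of the patch.

```python
import numpy as np

def hsigmoid_ref(x):
    # 0 for x <= -3, 1 for x >= 3, (x + 3) / 6 in between
    return np.where(x <= -3, 0.0, np.where(x >= 3, 1.0, (x + 3) / 6))

def hsigmoid_grad_ref(y_grad, x):
    # gradient is y_grad / 6 strictly inside (-3, 3) and 0 elsewhere,
    # mirroring the <= / >= branches of HsigmoidGrad
    return np.where((x <= -3) | (x >= 3), 0.0, y_grad / 6)

x = np.linspace(-5.0, 5.0, 11)
print(hsigmoid_ref(x))
print(hsigmoid_grad_ref(np.ones_like(x), x))
```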
| @@ -0,0 +1,63 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """hswish""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| from _akg.topi import tag | |||
| @tvm.tag_scope(tag=tag.ELEMWISE) | |||
| def topi_nn_hswish(x): | |||
| """ | |||
| topi hswish | |||
| Args: | |||
| x (tvm.tensor.Tensor): input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, hard swish of the input. | |||
| """ | |||
| return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, | |||
| tvm.if_then_else(x(*i) >= 3, x(*i), | |||
| x(*i) * (x(*i) + 3) / 6))) | |||
| def Hswish(x): | |||
| """ | |||
| Hswish | |||
| Args: | |||
| x (tvm.tensor.Tensor): input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, hard swish of the input. | |||
| """ | |||
| return topi_nn_hswish(x) | |||
| def gpu_schedule_Hswish(outs): | |||
| """ | |||
| gpu schedule Hswish | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of the operator. | |||
| Returns: | |||
| sch: the created CUDA schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with tvm.target.create(device): | |||
| sch = topi.cuda.schedule_elemwise(outs) | |||
| return sch | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HswishGrad""" | |||
| import _akg.topi as topi | |||
| import _akg.tvm as tvm | |||
| def HswishGrad(y_grad, x): | |||
| """ | |||
| HswishGrad | |||
| Args: | |||
| y_grad (tvm.tensor.Tensor): gradient of the Hswish output. | |||
| x (tvm.tensor.Tensor): forward input tensor. | |||
| Returns: | |||
| tvm.tensor.Tensor, gradient with respect to x. | |||
| """ | |||
| shape = x.shape | |||
| res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, y_grad(*i) * (2 * x(*i) + 3) / 6)) | |||
| res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= 3, y_grad(*i), res0(*i))) | |||
| return res6 | |||
| def gpu_schedule_HswishGrad(outs): | |||
| """ | |||
| gpu schedule HswishGrad | |||
| Args: | |||
| outs (tvm.tensor.Tensor): outputs of the operator. | |||
| Returns: | |||
| sch: the created CUDA schedule. | |||
| """ | |||
| device = 'cuda' | |||
| ctx = tvm.context(device, 0) | |||
| if not ctx.exist: | |||
| raise SystemError("Skip because %s is not enabled" % device) | |||
| with tvm.target.create(device): | |||
| sch = topi.cuda.schedule_elemwise(outs) | |||
| return sch | |||
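Similarly, a small NumPy sketch of the Hswish forward/backward math above (illustrative names; the two-step res0/res6 structure of HswishGrad is folded into two `np.where` calls).

```python
import numpy as np

def hswish_ref(x):
    # 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 in between
    return np.where(x <= -3, 0.0, np.where(x >= 3, x, x * (x + 3) / 6))

def hswish_grad_ref(y_grad, x):
    # res0 handles the x <= -3 cut-off, the outer where applies the x >= 3 pass-through
    res0 = np.where(x <= -3, 0.0, y_grad * (2 * x + 3) / 6)
    return np.where(x >= 3, y_grad, res0)
```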
| @@ -300,6 +300,13 @@ class ParamValidator: | |||
| for arg, value in args.items(): | |||
| ParamValidator.check_subclass(arg, value, mstype.tensor) | |||
| @staticmethod | |||
| def check_bool(arg_name, arg_value): | |||
| """Check arg isintance of bool""" | |||
| if not isinstance(arg_value, bool): | |||
| raise ValueError(f'The `{arg_name}` should be an instance of bool, but got {arg_value}.') | |||
| return arg_value | |||
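A quick, hypothetical call site for the new validator; the argument name is illustrative.

```python
# Accepts real booleans and returns them unchanged.
flag = ParamValidator.check_bool("is_train", True)   # -> True

# Rejects everything else, including 0/1 integers, since isinstance(1, bool) is False.
# ParamValidator.check_bool("is_train", 1)  # raises ValueError
```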
| @staticmethod | |||
| def check_type(arg_name, arg_value, valid_types): | |||
| """Type checking.""" | |||
| @@ -473,6 +473,7 @@ if(ENABLE_GPU) | |||
| gpu_cuda_lib | |||
| gpu_queue | |||
| cublas | |||
| ${CUDA_PATH}/lib64/libcurand.so | |||
| ${CUDNN_PATH}/lib64/libcudnn.so | |||
| ${CUDA_PATH}/lib64/libcudart.so | |||
| ${CUDA_PATH}/lib64/stubs/libcuda.so) | |||
| @@ -0,0 +1,169 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include <thrust/device_ptr.h> | |||
| #include <thrust/fill.h> | |||
| #include <thrust/reduce.h> | |||
| #include <thrust/system/cuda/execution_policy.h> | |||
| #include "batchnorm_fold2_impl.cuh" | |||
| #include "batchnorm_fold_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
| template <typename T> | |||
| __global__ void BatchNormFold2Kernel(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean, | |||
| const T *running_std, const T *running_mean, const int *global_step, T *y, | |||
| int freeze_bn, size_t N, size_t C, size_t H, size_t W) { | |||
| int c = 0; | |||
| size_t num_count = N * C * H * W; | |||
| if (*global_step < freeze_bn) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | |||
| c = i / (H * W) % C; | |||
| y[i] = x[i] * running_std[c] / batch_std[c] + beta[c] - gamma[c] * batch_mean[c] / batch_std[c]; | |||
| } | |||
| } else { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | |||
| c = i / (H * W) % C; | |||
| y[i] = x[i] + beta[c] - gamma[c] * running_mean[c] / running_std[c]; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void BatchNormFold2GradReduce1(const T *dout, T *tmp, const T *x, T *tmp2, size_t N, size_t C, size_t HW) { | |||
| int n = 0; | |||
| int c = 0; | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N * C; i += blockDim.x * gridDim.x) { | |||
| n = i / C; | |||
| c = i % C; | |||
| tmp[c * N + n] = thrust::reduce(thrust::seq, dout + i * HW, dout + (i + 1) * HW, 0.f, thrust::plus<T>()); | |||
| tmp2[c * N + n] = thrust::reduce(thrust::seq, x + i * HW, x + (i + 1) * HW, 0.f, thrust::plus<T>()); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void BatchNormFold2GradReduce2(const T *tmp, T *d_beta, const T *tmp2, T *reduce_x, size_t N, size_t C) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) { | |||
| d_beta[i] = thrust::reduce(thrust::seq, tmp + i * N, tmp + (i + 1) * N, 0.f, thrust::plus<T>()); | |||
| reduce_x[i] = thrust::reduce(thrust::seq, tmp2 + i * N, tmp2 + (i + 1) * N, 0.f, thrust::plus<T>()); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void BatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, | |||
| const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, | |||
| T *d_batch_mean, T *d_batch_std, size_t C) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) { | |||
| d_gamma[i] = -d_beta[i] * batch_mean[i] / batch_std[i]; | |||
| d_batch_mean[i] = -d_beta[i] * gamma[i] / batch_std[i]; | |||
| d_batch_std[i] = | |||
| (d_beta[i] * gamma[i] * batch_mean[i] - reduce_x[i] * running_std[i]) / batch_std[i] / batch_std[i]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void BatchNormFold2GradFreeze(const T *d_beta, const T *running_mean, const T *running_std, T *d_gamma, | |||
| size_t C) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) { | |||
| d_gamma[i] = -d_beta[i] * running_mean[i] / running_std[i]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void BatchNormFold2GradMul(const T *dout, const T *x, T *tmp_x, size_t NCHW) { | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < NCHW; i += blockDim.x * gridDim.x) { | |||
| tmp_x[i] = dout[i] * x[i]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void DxMul(size_t N, size_t C, size_t HW, const T *batch_std, const T *running_std, T *d_x) { | |||
| int c = 0; | |||
| size_t num_count = N * C * HW; | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | |||
| c = (i / HW) % C; | |||
| d_x[i] = d_x[i] * running_std[c] / batch_std[c]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean, | |||
| const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn, | |||
| size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream) { | |||
| auto num_count = N * C * H * W; | |||
| BatchNormFold2Kernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>( | |||
| x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, y, freeze_bn, N, C, H, W); | |||
| } | |||
| template void BatchNormFold2Forward<float>(const float *x, const float *beta, const float *gamma, | |||
| const float *batch_std, const float *batch_mean, const float *running_std, | |||
| const float *running_mean, const int *global_step, float *y, int freeze_bn, | |||
| size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N, | |||
| size_t C, size_t H, size_t W, cudaStream_t cuda_stream) { | |||
| auto hw = H * W; | |||
| auto num_count = N * C * H * W; | |||
| BatchNormFold2GradMul<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dout, x, tmp_x, num_count); | |||
| BatchNormFold2GradReduce1<<<GET_BLOCKS(N * C), GET_THREADS, 0, cuda_stream>>>(dout, tmp, tmp_x, tmp2, N, C, hw); | |||
| BatchNormFold2GradReduce2<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(tmp, d_beta, tmp2, reduce_x, N, C); | |||
| } | |||
| template void BatchNormFold2GradReduce<float>(const float *dout, const float *x, float *d_beta, float *tmp, | |||
| float *reduce_x, float *tmp2, float *tmp_x, size_t N, size_t C, size_t H, | |||
| size_t W, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, | |||
| const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, | |||
| T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) { | |||
| BatchNormFold2GradNotFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>( | |||
| d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, d_batch_mean, d_batch_std, C); | |||
| } | |||
| template void CalBatchNormFold2GradNotFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean, | |||
| const float *batch_std, const float *running_mean, | |||
| const float *running_std, const float *gamma, float *d_gamma, | |||
| float *d_batch_mean, float *d_batch_std, size_t C, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, | |||
| const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, | |||
| T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) { | |||
| BatchNormFold2GradFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(d_beta, running_mean, running_std, d_gamma, | |||
| C); | |||
| ThrustFillWith(d_batch_mean, C, (T)0.f, cuda_stream); | |||
| ThrustFillWith(d_batch_std, C, (T)0.f, cuda_stream); | |||
| } | |||
| template void CalBatchNormFold2GradFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean, | |||
| const float *batch_std, const float *running_mean, | |||
| const float *running_std, const float *gamma, float *d_gamma, | |||
| float *d_batch_mean, float *d_batch_std, size_t C, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H, | |||
| size_t W, cudaStream_t cuda_stream) { | |||
| DxMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N, C, H * W, batch_std, running_std, d_x); | |||
| } | |||
| template void CalBatchNormFold2GradNotFreezeDxMul<float>(const float *batch_std, const float *running_std, float *d_x, | |||
| size_t N, size_t C, size_t H, size_t W, | |||
| cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean, | |||
| const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn, | |||
| size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, | |||
| const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, | |||
| T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, | |||
| const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, | |||
| T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N, | |||
| size_t C, size_t H, size_t W, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H, | |||
| size_t W, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <thrust/device_ptr.h> | |||
| #include <thrust/fill.h> | |||
| #include <thrust/system/cuda/execution_policy.h> | |||
| #include "batchnorm_fold_impl.cuh" | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { | |||
| running_std[i] = sqrtf(running_std[i] + epsilon); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void UpdateBatchStd(int channel_size, T* batch_std) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { | |||
| batch_std[i] = 1 / batch_std[i]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std, | |||
| int batch_size, int channel_size, int height, int width, T* dx) { | |||
| int n = batch_size * channel_size * height * width; | |||
| int normal_size = batch_size * height * width; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { | |||
| int channel_index = i / (height * width) % channel_size; | |||
| dx[i] = d_batch_mean[channel_index] / normal_size + | |||
| d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) { | |||
| UpdateRunningStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, epsilon, running_std); | |||
| return; | |||
| } | |||
| template void CalUpdateRunningStd<float>(int channel_size, double epsilon, float* running_std, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) { | |||
| UpdateBatchStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, batch_std); | |||
| return; | |||
| } | |||
| template void CalUpdateBatchStd<float>(int channel_size, float* batch_std, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, | |||
| const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx, | |||
| cudaStream_t cuda_stream) { | |||
| CalDx<<<GET_BLOCKS(batch_size * channel_size * height * width), GET_THREADS, 0, cuda_stream>>>( | |||
| d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx); | |||
| } | |||
| template void CalBatchNormFoldGrad<float>(const float* d_batch_mean, const float* d_batch_std, const float* x, | |||
| const float* batch_mean, const float* batch_std, int batch_size, | |||
| int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) { | |||
| thrust::device_ptr<T> dev_ptr(array); | |||
| thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill); | |||
| } | |||
| template void ThrustFillWith<float>(float* array, int size, float tofill, cudaStream_t cuda_stream); | |||
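The `CalDx` kernel above distributes the per-channel mean/std gradients back to every element; a NumPy sketch of the same formula (NCHW layout, illustrative names):

```python
import numpy as np

def batchnorm_fold_grad_ref(d_batch_mean, d_batch_std, x, batch_mean, batch_std):
    """x: (N, C, H, W); the other arguments are length-C vectors."""
    n, _, h, w = x.shape
    normal_size = n * h * w
    bc = lambda v: v[None, :, None, None]
    return (bc(d_batch_mean)
            + bc(d_batch_std) * (x - bc(batch_mean)) / bc(batch_std)) / normal_size
```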
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, | |||
| const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream); | |||
| #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_ | |||
| @@ -41,3 +41,4 @@ template void CalConcatV2(const size_t size, const int w1, const int w2, const i | |||
| int* output, cudaStream_t cuda_stream); | |||
| template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, | |||
| half* output, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <thrust/reduce.h> | |||
| #include "correction_mul_impl.cuh" | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw, | |||
| T* output) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) { | |||
| int n = i / chw; | |||
| output[i] = weight[i] * gamma[n] / running_std[n]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void Mul(int N, const T* a, const T* b, T* c) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { | |||
| c[i] = a[i] * b[i]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { | |||
| d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus<T>()); | |||
| d_gamma[i] = d_gamma[i] / running_std[i]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output, | |||
| cudaStream_t cuda_stream) { | |||
| CorrectionMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(weight, gamma, running_std, N, C * H * W, | |||
| output); | |||
| } | |||
| template void CalCorrectionMul<float>(const float* weight, const float* gamma, const float* running_std, int N, int C, | |||
| int H, int W, float* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma, | |||
| T* tmp, cudaStream_t cuda_stream) { | |||
| Mul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N * C * H * W, d_out, weight, tmp); | |||
| Reduce<<<GET_BLOCKS(N), GET_THREADS, 0, cuda_stream>>>(N, C * H * W, tmp, running_std, d_gamma); | |||
| } | |||
| template void CalCorrectionMulGrad<float>(const float* d_out, const float* weight, const float* running_std, int N, | |||
| int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int batch_size, int channel_size, | |||
| int height, int width, T* output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int batch_size, int channel_size, | |||
| int height, int width, T* d_gamma, T* tmp, cudaStream_t cuda_stream); | |||
| #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_ | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include "cross_entropy_cuda_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
| __global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits, | |||
| const float *labels, const int batch_size, const int num_classes, | |||
| float *loss, float *dx) { | |||
| extern __shared__ float loss_shared[]; | |||
| const float mean_scale = 1.0f / static_cast<float>(batch_size); | |||
| loss_shared[threadIdx.x] = 0; | |||
| for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) { | |||
| loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i]; | |||
| dx[i] = (softmax_logits[i] - labels[i]) * mean_scale; | |||
| } | |||
| __syncthreads(); | |||
| if (threadIdx.x == 0) { | |||
| *loss = 0; | |||
| for (int i = 0; i < batch_size; i++) { | |||
| *loss += loss_shared[i]; | |||
| } | |||
| *loss *= mean_scale; | |||
| } | |||
| } | |||
| void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, | |||
| const int batch_size, const int num_classes, float *loss, float *dx, | |||
| cudaStream_t cuda_stream) { | |||
| CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>( | |||
| softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx); | |||
| } | |||
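The launch uses one block of `batch_size` threads and `batch_size * sizeof(float)` bytes of shared memory, so it assumes the batch fits within a single block. The math itself is the usual mean cross entropy plus softmax gradient; a NumPy sketch with illustrative names:

```python
import numpy as np

def cross_entropy_with_grad_ref(softmax_logits, log_softmax_logits, labels):
    """All arguments have shape (batch_size, num_classes)."""
    batch_size = labels.shape[0]
    loss = -(log_softmax_logits * labels).sum(axis=1).mean()
    dx = (softmax_logits - labels) / batch_size
    return loss, dx
```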
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, | |||
| const int batch_size, const int num_classes, float *loss, float *dx, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include "dropout_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
| __global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count, | |||
| float drop_prob) { | |||
| float scale = 1.f / (1.f - drop_prob); | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | |||
| mask[i] = mask[i] > drop_prob; | |||
| output[i] = scale * input[i] * mask[i]; | |||
| } | |||
| } | |||
| void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob, | |||
| cudaStream_t cuda_stream) { | |||
| DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count, | |||
| drop_prob); | |||
| } | |||
| __global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count, | |||
| float drop_prob) { | |||
| float scale = 1.f / (1.f - drop_prob); | |||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) { | |||
| dx[i] = scale * dy[i] * mask[i]; | |||
| } | |||
| } | |||
| void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob, | |||
| cudaStream_t cuda_stream) { | |||
| DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob); | |||
| } | |||
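These kernels expect `mask` to arrive pre-filled with uniform random samples (presumably produced with cuRAND, which the CMake change above now links), and they apply inverted-dropout scaling. A NumPy sketch with illustrative names:

```python
import numpy as np

def dropout_forward_ref(x, uniform_mask, drop_prob):
    """uniform_mask holds U(0, 1) samples; it is turned into the 0/1 keep mask."""
    scale = 1.0 / (1.0 - drop_prob)
    keep = (uniform_mask > drop_prob).astype(x.dtype)
    return scale * x * keep, keep

def dropout_backward_ref(dy, keep, drop_prob):
    scale = 1.0 / (1.0 - drop_prob)
    return scale * dy * keep
```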
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob, | |||
| cudaStream_t cuda_stream); | |||
| void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <thrust/extrema.h> | |||
| #include <thrust/device_vector.h> | |||
| #include <thrust/pair.h> | |||
| #include "device/gpu/cuda_common.h" | |||
| #include "fake_quant_impl.cuh" | |||
| __global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min, | |||
| const float* nudge_max, const float* scale, bool symmetric) { | |||
| float input_x = 0.f; | |||
| int nudge_input = 0; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { | |||
| input_x = input[i]; | |||
| // clamp input x | |||
| if (input_x < nudge_min[0]) { | |||
| input_x = nudge_min[0]; | |||
| } | |||
| if (input_x > nudge_max[0]) { | |||
| input_x = nudge_max[0]; | |||
| } | |||
| // round to the nearest quantization step | |||
| nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f); | |||
| // map back to the nudged range | |||
| output[i] = nudge_input * scale[0] + nudge_min[0]; | |||
| } | |||
| return; | |||
| } | |||
| __global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, | |||
| const float* nudge_min, const float* nudge_max) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { | |||
| if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) { | |||
| output[i] = 0; | |||
| } else { | |||
| output[i] = gradient[i]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| __global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min, | |||
| const float quant_max, float* nudge_min, float* nudge_max, float* scale) { | |||
| float zp_from_min = 0.f; | |||
| if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) { | |||
| *scale = 0.f; | |||
| zp_from_min = 0.f; | |||
| } else { | |||
| *scale = (*input_max - *input_min) / (quant_max - quant_min); | |||
| zp_from_min = quant_min - *input_min / *scale; | |||
| } | |||
| float nudge_zp = 0.f; | |||
| if (zp_from_min <= quant_min) { | |||
| nudge_zp = quant_min; | |||
| } else if (zp_from_min >= quant_max) { | |||
| nudge_zp = quant_max; | |||
| } else { | |||
| nudge_zp = round(zp_from_min); | |||
| } | |||
| *nudge_min = (quant_min - nudge_zp) * (*scale); | |||
| *nudge_max = (quant_max - nudge_zp) * (*scale); | |||
| return; | |||
| } | |||
| __global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max, | |||
| const float decay) { | |||
| *input_min = decay * (min) + (1 - decay) * (*input_min); | |||
| *input_min = *input_min > 0 ? 0 : *input_min; | |||
| *input_max = decay * (max) + (1 - decay) * (*input_max); | |||
| *input_max = *input_max < 0 ? 0 : *input_max; | |||
| return; | |||
| } | |||
| __global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) { | |||
| *input_min = min; | |||
| *input_max = max; | |||
| } | |||
| void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max, | |||
| const float* scale, bool symmetric, cudaStream_t cuda_stream) { | |||
| FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale, | |||
| symmetric); | |||
| return; | |||
| } | |||
| void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, | |||
| const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) { | |||
| FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min, | |||
| nudge_max); | |||
| return; | |||
| } | |||
| void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max, | |||
| float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) { | |||
| NudgeMinMax<<<1, 1>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale); | |||
| return; | |||
| } | |||
| void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema, | |||
| cudaStream_t cuda_stream) { | |||
| float minel = 0.f; | |||
| float maxel = 0.f; | |||
| thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple; | |||
| tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size); | |||
| minel = tuple.first[0]; | |||
| maxel = tuple.second[0]; | |||
| if (ema) { | |||
| UpdateInputMinMaxWithEMA<<<1, 1>>>(input_min, input_max, minel, maxel, ema_decay); | |||
| } else { | |||
| UpdateInputMinMax<<<1, 1>>>(input_min, input_max, minel, maxel); | |||
| } | |||
| return; | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max, | |||
| const float* scale, bool symmetric, cudaStream_t cuda_stream); | |||
| void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, | |||
| const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream); | |||
| void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max, | |||
| float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream); | |||
| void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ | |||
| @@ -0,0 +1,174 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <thrust/extrema.h> | |||
| #include <thrust/device_vector.h> | |||
| #include <thrust/execution_policy.h> | |||
| #include <thrust/reduce.h> | |||
| #include <thrust/pair.h> | |||
| #include "fake_quant_per_channel_impl.cuh" | |||
| #include "device/gpu/cuda_common.h" | |||
| /** | |||
| * Find the nudge min, max and scale value as output. | |||
| * @param input_min array | |||
| * @param input_max array | |||
| * @param quant_min lower bound of the quantized range (typically 0) | |||
| * @param quant_max upper bound of the quantized range (typically (1 << bit_num) - 1) | |||
| * @param nudge_min array | |||
| * @param nudge_max array | |||
| * @param scale array | |||
| * @param channel_num number of channels | |||
| * @return | |||
| */ | |||
| __global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min, | |||
| const float quant_max, float* nudge_min, float* nudge_max, float* scale, | |||
| int channel_num) { | |||
| float zp_from_min = 0.f; | |||
| float nudge_zp = 0.f; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) { | |||
| if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) { | |||
| scale[i] = 0.f; | |||
| zp_from_min = 0.f; | |||
| } else { | |||
| scale[i] = (input_max[i] - input_min[i]) / (quant_max - quant_min); | |||
| zp_from_min = quant_min - input_min[i] / scale[i]; | |||
| } | |||
| if (zp_from_min <= quant_min) { | |||
| nudge_zp = quant_min; | |||
| } else if (zp_from_min >= quant_max) { | |||
| nudge_zp = quant_max; | |||
| } else { | |||
| nudge_zp = round(zp_from_min); | |||
| } | |||
| nudge_min[i] = (quant_min - nudge_zp) * (scale[i]); | |||
| nudge_max[i] = (quant_max - nudge_zp) * (scale[i]); | |||
| } | |||
| } | |||
| void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max, | |||
| float* nudge_min, float* nudge_max, float* scale, const int channel_num, | |||
| cudaStream_t cuda_stream) { | |||
| NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>( | |||
| input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num); | |||
| } | |||
| /** | |||
| * Calculate the fake quant output according to the nudged min, max and scale. | |||
| * @param input - array | |||
| * @param output - array | |||
| * @param total_size - int, total number of elements | |||
| * @param channel_size - int, number of channels (total_size / channel_size gives the per-channel element count) | |||
| * @param nudge_min - array | |||
| * @param nudge_max - array | |||
| * @param scale - array | |||
| * @return | |||
| */ | |||
| __global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size, | |||
| const float* nudge_min, const float* nudge_max, const float* scale, | |||
| bool symmetric) { | |||
| float input_x = 0.f; | |||
| int nudge_input = 0; | |||
| int channel_idx = 0; | |||
| int per_channel_num = total_size / channel_size; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) { | |||
| input_x = input[i]; | |||
| channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num)); | |||
| // clamp input x | |||
| if (input_x < nudge_min[channel_idx]) { | |||
| input_x = nudge_min[channel_idx]; | |||
| } | |||
| if (input_x > nudge_max[channel_idx]) { | |||
| input_x = nudge_max[channel_idx]; | |||
| } | |||
| // round to the nearest quantization step | |||
| nudge_input = floor((input_x - nudge_min[channel_idx]) / scale[channel_idx] + 0.5f); | |||
| // map back to the nudged range | |||
| output[i] = nudge_input * scale[channel_idx] + nudge_min[channel_idx]; | |||
| } | |||
| } | |||
| void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size, | |||
| const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric, | |||
| cudaStream_t cuda_stream) { | |||
| FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>( | |||
| input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric); | |||
| } | |||
| /** | |||
| * UpdateInputMinMaxPerChannel, or UpdateInputMinMaxPerChannelWithEMA when EMA smoothing is enabled. | |||
| * @param input_min per-channel minima | |||
| * @param input_max per-channel maxima | |||
| * @param min new minimum (EMA variant) | |||
| * @param max new maximum (EMA variant) | |||
| * @return | |||
| */ | |||
| __global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels, | |||
| int per_channel_nums, bool ema, float ema_decay) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { | |||
| thrust::pair<float*, float*> sum = | |||
| thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); | |||
| if (ema) { | |||
| input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; | |||
| input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; | |||
| } else { | |||
| input_min[i] = sum.first[0]; | |||
| input_max[i] = sum.second[0]; | |||
| } | |||
| } | |||
| } | |||
| __global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max, | |||
| const float decay) { | |||
| *input_min = decay * (min) + (1 - decay) * (*input_min); | |||
| *input_max = decay * (max) + (1 - decay) * (*input_max); | |||
| } | |||
| void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size, | |||
| const float ema_decay, const bool ema, cudaStream_t cuda_stream) { | |||
| int per_channel_num = total_size / channel_size; | |||
| UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>( | |||
| input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay); | |||
| } | |||
| __global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, | |||
| const int total_size, const int channel_size, const float* nudge_min, | |||
| const float* nudge_max) { | |||
| int channel_idx = 0; | |||
| int per_channel_num = total_size / channel_size; | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) { | |||
| channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num)); | |||
| if (input[i] < nudge_min[channel_idx] || input[i] > nudge_max[channel_idx]) { | |||
| output[i] = 0; | |||
| } else { | |||
| output[i] = gradient[i]; | |||
| } | |||
| } | |||
| } | |||
| void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num, | |||
| const int channel_num, const float* nudge_min, const float* nudge_max, | |||
| cudaStream_t cuda_stream) { | |||
| FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>( | |||
| input, gradient, output, total_num, channel_num, nudge_min, nudge_max); | |||
| } | |||
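The per-channel variant maps element `i` to channel `i / (total_size / channel_size)`, i.e. the channel is the leading axis of the flattened data. A compact NumPy sketch of the quantize step with per-channel parameters (illustrative names):

```python
import numpy as np

def fake_quantize_per_channel_ref(x, nudge_min, nudge_max, scale):
    """x: (C, ...); nudge_min, nudge_max, scale: length-C vectors broadcast along axis 0."""
    shape = (-1,) + (1,) * (x.ndim - 1)
    lo, hi, s = (v.reshape(shape) for v in (nudge_min, nudge_max, scale))
    xc = np.clip(x, lo, hi)
    return np.floor((xc - lo) / s + 0.5) * s + lo
```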
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTPERCHANNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTPERCHANNEL_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max, | |||
| float* nudge_min, float* nudge_max, float* scale, const int channel_num, | |||
| cudaStream_t cuda_stream); | |||
| void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num, | |||
| const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric, | |||
| cudaStream_t cuda_stream); | |||
| void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num, | |||
| const float ema_decay, const bool ema, cudaStream_t cuda_stream); | |||
| void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num, | |||
| const int channel_num, const float* nudge_min, const float* nudge_max, | |||
| cudaStream_t cuda_stream); | |||
| #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTPERCHANNEL_H_ | |||
| @@ -16,7 +16,7 @@ | |||
| #ifndef MINDSPORE_GATHER_GPU_CU_H | |||
| #define MINDSPORE_GATHER_GPU_CU_H | |||
| template <typename T, typename S> | |||
| void Gather(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, | |||
| size_t input_dim1, cudaStream_t stream); | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include "sparse_cross_entropy_cuda_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
| template <typename T> | |||
| __global__ void CalCrossEntropyKernel(const float *logits, T *labels, const int batch_size, const int class_num, | |||
| float *loss) { | |||
| float total_loss = 0.0; | |||
| float epsilon = 1e-6; | |||
| for (int i = 0; i < batch_size; ++i) { | |||
| float logit = logits[i * class_num + labels[i]]; | |||
| if (logit <= 0) { | |||
| logit += epsilon; | |||
| } | |||
| float single_loss = -logf(logit); | |||
| total_loss += single_loss; | |||
| } | |||
| total_loss /= batch_size; | |||
| loss[0] = total_loss; | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void CalCrossEntropyGradKernel(const float *logits, T *labels, const int batch_size, const int class_num, | |||
| float *grad) { | |||
| for (int i = 0; i < batch_size; i++) { | |||
| for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) { | |||
| if (labels[i] == j) { | |||
| grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size; | |||
| } else { | |||
| grad[i * class_num + j] = logits[i * class_num + j] / batch_size; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, | |||
| cudaStream_t cuda_stream) { | |||
| CalCrossEntropyKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, loss); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, | |||
| cudaStream_t cuda_stream) { | |||
| CalCrossEntropyGradKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size, | |||
| class_num, grad); | |||
| return; | |||
| } | |||
| template void CalCrossEntropy<int>(const float *logits, int *labels, const int batch_size, const int class_num, | |||
| float *loss, cudaStream_t cuda_stream); | |||
| template void CalCrossEntropy<uint64_t>(const float *logits, uint64_t *labels, const int batch_size, | |||
| const int class_num, float *loss, cudaStream_t cuda_stream); | |||
| template void CalCrossEntropyGrad<int>(const float *logits, int *labels, const int batch_size, const int class_num, | |||
| float *grad, cudaStream_t cuda_stream); | |||
| template void CalCrossEntropyGrad<uint64_t>(const float *logits, uint64_t *labels, const int batch_size, | |||
| const int class_num, float *grad, cudaStream_t cuda_stream); | |||
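Both kernels above treat `logits` as already-normalized class probabilities: the forward pass averages `-log(p[label])` over the batch on a single thread, and the backward pass writes `(p - one_hot(label)) / batch_size`. The following is a minimal CPU reference of that arithmetic, useful for checking the CUDA results on a toy batch; the sizes and values are illustrative only and not part of the change itself.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// CPU reference for CalCrossEntropyKernel: mean over the batch of -log(p[label]),
// where `probs` is assumed to already hold per-class probabilities (e.g. softmax output).
float CrossEntropyRef(const std::vector<float> &probs, const std::vector<int> &labels,
                      int batch_size, int class_num) {
  const float epsilon = 1e-6f;
  float total_loss = 0.0f;
  for (int i = 0; i < batch_size; ++i) {
    float p = probs[i * class_num + labels[i]];
    if (p <= 0) {
      p += epsilon;  // same clamp as the CUDA kernel
    }
    total_loss += -std::log(p);
  }
  return total_loss / batch_size;
}

// CPU reference for CalCrossEntropyGradKernel: grad = (p - one_hot(label)) / batch_size.
void CrossEntropyGradRef(const std::vector<float> &probs, const std::vector<int> &labels,
                         int batch_size, int class_num, std::vector<float> *grad) {
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < class_num; ++j) {
      float delta = (labels[i] == j) ? 1.0f : 0.0f;
      (*grad)[i * class_num + j] = (probs[i * class_num + j] - delta) / batch_size;
    }
  }
}

int main() {
  // Toy batch of 2 samples with 3 classes each (values are illustrative).
  std::vector<float> probs = {0.7f, 0.2f, 0.1f, 0.1f, 0.8f, 0.1f};
  std::vector<int> labels = {0, 1};
  std::vector<float> grad(probs.size(), 0.0f);
  printf("loss = %f\n", CrossEntropyRef(probs, labels, 2, 3));
  CrossEntropyGradRef(probs, labels, 2, 3, &grad);
  return 0;
}
```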
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ | |||
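A host-side sketch of driving these declarations directly, outside the kernel-factory machinery. It assumes an nvcc build with the MindSpore include paths available and links against the explicit `int` instantiations from the `.cu` above; buffer contents and the lack of error checking are illustrative only.

```cpp
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include "sparse_cross_entropy_cuda_impl.cuh"  // same path as used by the .cu above

int main() {
  const int batch_size = 2, class_num = 3;
  std::vector<float> h_probs = {0.7f, 0.2f, 0.1f, 0.1f, 0.8f, 0.1f};
  std::vector<int> h_labels = {0, 1};

  float *d_probs = nullptr, *d_loss = nullptr, *d_grad = nullptr;
  int *d_labels = nullptr;
  cudaMalloc(&d_probs, h_probs.size() * sizeof(float));
  cudaMalloc(&d_labels, h_labels.size() * sizeof(int));
  cudaMalloc(&d_loss, sizeof(float));
  cudaMalloc(&d_grad, h_probs.size() * sizeof(float));
  cudaMemcpy(d_probs, h_probs.data(), h_probs.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_labels, h_labels.data(), h_labels.size() * sizeof(int), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  CalCrossEntropy(d_probs, d_labels, batch_size, class_num, d_loss, stream);      // T deduced as int
  CalCrossEntropyGrad(d_probs, d_labels, batch_size, class_num, d_grad, stream);  // T deduced as int
  cudaStreamSynchronize(stream);

  float h_loss = 0.0f;
  cudaMemcpy(&h_loss, d_loss, sizeof(float), cudaMemcpyDeviceToHost);
  printf("loss = %f\n", h_loss);

  cudaStreamDestroy(stream);
  cudaFree(d_probs);
  cudaFree(d_labels);
  cudaFree(d_loss);
  cudaFree(d_grad);
  return 0;
}
```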
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/dropout_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/dropout_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| DropoutGpuFwdKernel::DropoutGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), | |||
| is_null_input_(false), | |||
| num_count_(0), | |||
| drop_prob_(0.0), | |||
| states_init_(false), | |||
| mask_generator_(nullptr) {} | |||
| DropoutGpuFwdKernel::~DropoutGpuFwdKernel() { DestroyResource(); } | |||
| const std::vector<size_t> &DropoutGpuFwdKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &DropoutGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &DropoutGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1."; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| num_count_ = 1; | |||
| for (size_t x : input_shape) { | |||
| num_count_ *= x; | |||
| } | |||
| drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void DropoutGpuFwdKernel::InitResource() { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| } | |||
| void DropoutGpuFwdKernel::DestroyResource() noexcept { | |||
| if (states_init_) { | |||
| curandDestroyGenerator(mask_generator_); | |||
| } | |||
| } | |||
| void DropoutGpuFwdKernel::InitSizeLists() { | |||
| size_t input_size = num_count_ * sizeof(float); | |||
| size_t workspace_size = 0; | |||
| input_size_list_.push_back(input_size); | |||
| output_size_list_.push_back(input_size); // output size: the same as the input size | |||
| output_size_list_.push_back(input_size); // mask size: the same as the input size | |||
| workspace_size_list_.push_back(workspace_size); | |||
| } | |||
| bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto *input = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto *output = reinterpret_cast<float *>(outputs[0]->addr); | |||
| auto *mask = reinterpret_cast<float *>(outputs[1]->addr); | |||
| if (!states_init_) { | |||
| curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT); | |||
| curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)); | |||
| states_init_ = true; | |||
| } | |||
| curandGenerateUniform(mask_generator_, mask, num_count_); | |||
| DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
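`DropoutForward` comes from `dropout_impl.cuh`, which is not part of this hunk. Given that `Launch` first fills `mask` with uniform samples via cuRAND, a plausible minimal sketch is an elementwise kernel that thresholds the sample at `drop_prob`, overwrites `mask` with the resulting 0/1 keep-mask, and rescales kept values by `1/(1-drop_prob)` (inverted dropout). The real implementation may differ, including in how it interprets `drop_prob`.

```cpp
#include <cuda_runtime.h>

// Hypothetical sketch of a DropoutForward kernel (the actual one lives in
// kernel/gpu/cuda_impl/dropout_impl.cuh). `mask` holds uniform [0,1) samples on
// entry and is overwritten with the 0/1 keep-mask. Assumes drop_prob is the
// probability of zeroing an element.
__global__ void DropoutForwardSketch(const float *input, float *mask, float *output,
                                     size_t num_count, float drop_prob) {
  const float keep_prob = 1.0f - drop_prob;
  const float scale = 1.0f / keep_prob;  // inverted dropout: rescale kept values at train time
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
    const float keep = (mask[i] >= drop_prob) ? 1.0f : 0.0f;
    mask[i] = keep;
    output[i] = input[i] * keep * scale;
  }
}

// Illustrative launcher; GET_BLOCKS/GET_THREADS from cuda_common.h would normally be used.
void DropoutForwardSketchLaunch(const float *input, float *mask, float *output,
                                size_t num_count, float drop_prob, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((num_count + threads - 1) / threads);
  DropoutForwardSketch<<<blocks, threads, 0, stream>>>(input, mask, output, num_count, drop_prob);
}
```

One detail worth noting in `Launch` above: `curandGenerateUniform` runs on the generator's current stream (the default stream unless changed), while `DropoutForward` runs on `stream_ptr`; calling `curandSetStream(mask_generator_, reinterpret_cast<cudaStream_t>(stream_ptr))` before generating would keep both operations ordered on the same stream.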
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "include/curand.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class DropoutGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| DropoutGpuFwdKernel(); | |||
| ~DropoutGpuFwdKernel() override; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel_node) override; | |||
| protected: | |||
| void InitResource() override; | |||
| void InitSizeLists() override; | |||
| private: | |||
| void DestroyResource() noexcept; | |||
| cudnnHandle_t cudnn_handle_; | |||
| bool is_null_input_; | |||
| size_t num_count_; | |||
| float drop_prob_; | |||
| bool states_init_; | |||
| curandGenerator_t mask_generator_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| MS_REG_GPU_KERNEL(Dropout, DropoutGpuFwdKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/dropout_grad_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/dropout_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {} | |||
| DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); } | |||
| const std::vector<size_t> &DropoutGradGpuFwdKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &DropoutGradGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &DropoutGradGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuFwdKernel needs 2."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| num_count_ = 1; | |||
| for (size_t x : input_shape) { | |||
| num_count_ *= x; | |||
| } | |||
| drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void DropoutGradGpuFwdKernel::InitResource() { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| } | |||
| void DropoutGradGpuFwdKernel::DestroyResource() noexcept {} | |||
| void DropoutGradGpuFwdKernel::InitSizeLists() { | |||
| size_t dy_size = num_count_ * sizeof(float); | |||
| size_t mask_size = dy_size; | |||
| size_t dx_size = dy_size; | |||
| size_t workspace_size = 0; | |||
| input_size_list_.push_back(dy_size); | |||
| input_size_list_.push_back(mask_size); | |||
| output_size_list_.push_back(dx_size); | |||
| workspace_size_list_.push_back(workspace_size); | |||
| } | |||
| bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto *dy = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto *mask = reinterpret_cast<float *>(inputs[1]->addr); | |||
| auto *dx = reinterpret_cast<float *>(outputs[0]->addr); | |||
| DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
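`DropoutBackward` is likewise defined in `dropout_impl.cuh` and not shown here. Assuming it reuses the 0/1 mask saved by the forward pass, the gradient is simply `dx = dy * mask / (1 - drop_prob)`; a minimal sketch under that assumption:

```cpp
#include <cuda_runtime.h>

// Hypothetical sketch of a DropoutBackward kernel: gradients flow only through
// elements that were kept in the forward pass, scaled by the same 1/(1-p) factor.
__global__ void DropoutBackwardSketch(const float *dy, const float *mask, float *dx,
                                      size_t num_count, float drop_prob) {
  const float scale = 1.0f / (1.0f - drop_prob);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
    dx[i] = dy[i] * mask[i] * scale;
  }
}
```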
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class DropoutGradGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| DropoutGradGpuFwdKernel(); | |||
| ~DropoutGradGpuFwdKernel() override; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel_node) override; | |||
| protected: | |||
| void InitResource() override; | |||
| void InitSizeLists() override; | |||
| private: | |||
| void DestroyResource() noexcept; | |||
| cudnnHandle_t cudnn_handle_; | |||
| bool is_null_input_; | |||
| size_t num_count_; | |||
| float drop_prob_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| MS_REG_GPU_KERNEL(DropoutGrad, DropoutGradGpuFwdKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormFold2, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BatchNormFold2GpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,139 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BatchNormFold2GpuKernel : public GpuKernel { | |||
| public: | |||
| BatchNormFold2GpuKernel() | |||
| : cudnn_handle_(nullptr), | |||
| is_null_input_(false), | |||
| batch_size_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| freeze_bn_(0) {} | |||
| ~BatchNormFold2GpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto *input = GetDeviceAddress<T>(inputs, 0); | |||
| auto *beta = GetDeviceAddress<T>(inputs, 1); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 2); | |||
| auto *batch_std = GetDeviceAddress<T>(inputs, 3); | |||
| auto *batch_mean = GetDeviceAddress<T>(inputs, 4); | |||
| auto *running_std = GetDeviceAddress<T>(inputs, 5); | |||
| auto *running_mean = GetDeviceAddress<T>(inputs, 6); | |||
| auto *global_step = GetDeviceAddress<int32_t>(inputs, 7); | |||
| auto *output = GetDeviceAddress<T>(outputs, 0); | |||
| BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output, | |||
| freeze_bn_, batch_size_, channel_, height_, width_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 8) { | |||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W)."; | |||
| return false; | |||
| } | |||
| batch_size_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitSizeLists() { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = channel_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(weight_size); // beta | |||
| input_size_list_.push_back(weight_size); // gamma | |||
| input_size_list_.push_back(weight_size); // batch_std | |||
| input_size_list_.push_back(weight_size); // batch_mean | |||
| input_size_list_.push_back(weight_size); // running_std | |||
| input_size_list_.push_back(weight_size); // running_mean | |||
| input_size_list_.push_back(sizeof(int32_t)); // global_step | |||
| output_size_list_.push_back(input_size); | |||
| size_t workspace_size = 0; | |||
| workspace_size_list_.push_back(workspace_size); | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| cudnnHandle_t cudnn_handle_; | |||
| bool is_null_input_; | |||
| size_t batch_size_; | |||
| size_t channel_; | |||
| size_t height_; | |||
| size_t width_; | |||
| size_t freeze_bn_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BatchNormFold2GradGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,167 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BatchNormFold2GradGpuKernel : public GpuKernel { | |||
| public: | |||
| BatchNormFold2GradGpuKernel() | |||
| : cudnn_handle_(nullptr), | |||
| is_null_input_(false), | |||
| batch_size_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| freeze_bn_(0) {} | |||
| ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto *dout = GetDeviceAddress<T>(inputs, 0); | |||
| auto *x = GetDeviceAddress<T>(inputs, 1); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 2); | |||
| auto *batch_std = GetDeviceAddress<T>(inputs, 3); | |||
| auto *batch_mean = GetDeviceAddress<T>(inputs, 4); | |||
| auto *running_std = GetDeviceAddress<T>(inputs, 5); | |||
| auto *running_mean = GetDeviceAddress<T>(inputs, 6); | |||
| auto *global_step = GetDeviceAddress<int32_t>(inputs, 7); | |||
| auto *d_batch_std = GetDeviceAddress<T>(outputs, 0); | |||
| auto *d_batch_mean = GetDeviceAddress<T>(outputs, 1); | |||
| auto *d_beta = GetDeviceAddress<T>(outputs, 2); | |||
| auto *d_gamma = GetDeviceAddress<T>(outputs, 3); | |||
| auto *d_x = GetDeviceAddress<T>(outputs, 4); | |||
| auto *tmp = GetDeviceAddress<T>(workspace, 0); | |||
| auto *tmp2 = GetDeviceAddress<T>(workspace, 1); | |||
| auto *reduce_x = GetDeviceAddress<T>(workspace, 2); | |||
| auto *tmp_x = GetDeviceAddress<T>(workspace, 3); | |||
| int32_t current_step_host[1]; | |||
| size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost), | |||
| "Failed to copy gpu memory."); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(d_x, dout, x_size, cudaMemcpyDeviceToDevice), "Failed to copy gpu memory."); | |||
| BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (current_step_host[0] < freeze_bn_) { | |||
| CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, | |||
| d_batch_mean, d_batch_std, channel_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, | |||
| d_batch_mean, d_batch_std, channel_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 8) { | |||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "BatchNormFold2GradGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W)."; | |||
| return false; | |||
| } | |||
| batch_size_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||
| void InitSizeLists() { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = channel_ * sizeof(T); | |||
| size_t workspace_size = batch_size_ * channel_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); // dout | |||
| input_size_list_.push_back(input_size); // x | |||
| input_size_list_.push_back(weight_size); // gamma | |||
| input_size_list_.push_back(weight_size); // batch_std | |||
| input_size_list_.push_back(weight_size); // batch_mean | |||
| input_size_list_.push_back(weight_size); // running_std | |||
| input_size_list_.push_back(weight_size); // running_mean | |||
| input_size_list_.push_back(sizeof(int32_t)); // global_step | |||
| output_size_list_.push_back(weight_size); // d_batch_std | |||
| output_size_list_.push_back(weight_size); // d_batch_mean | |||
| output_size_list_.push_back(weight_size); // d_beta | |||
| output_size_list_.push_back(weight_size); // d_gamma | |||
| output_size_list_.push_back(input_size); // d_x | |||
| workspace_size_list_.push_back(workspace_size); // tmp | |||
| workspace_size_list_.push_back(workspace_size); // tmp2 | |||
| workspace_size_list_.push_back(weight_size); // reduce_x | |||
| workspace_size_list_.push_back(input_size); // tmp_x | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| cudnnHandle_t cudnn_handle_; | |||
| bool is_null_input_; | |||
| size_t batch_size_; | |||
| size_t channel_; | |||
| size_t height_; | |||
| size_t width_; | |||
| int32_t freeze_bn_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/batchnorm_fold_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormFold, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BatchNormFoldGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,208 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/kernel_constants.h" | |||
| #include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BatchNormFoldGpuKernel : public GpuKernel { | |||
| public: | |||
| BatchNormFoldGpuKernel() | |||
| : input_size_(0), | |||
| output_size_(0), | |||
| exp_avg_factor_(0.9), | |||
| epsilon_(1e-12), | |||
| is_training_(true), | |||
| freeze_bn_(0), | |||
| batch_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||
| x_desc_(nullptr), | |||
| scale_bias_mean_var_desc_(nullptr), | |||
| handle_(nullptr) {} | |||
| ~BatchNormFoldGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| auto x = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto mean = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto variance = reinterpret_cast<T *>(inputs[2]->addr); | |||
| int *current_step = reinterpret_cast<int *>(inputs[3]->addr); | |||
| int current_step_host[1]; | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost), | |||
| "Copy gpu memoy failed."); | |||
| if (x == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null."; | |||
| return false; | |||
| } | |||
| if (mean == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null."; | |||
| return false; | |||
| } | |||
| if (variance == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null."; | |||
| return false; | |||
| } | |||
| if (current_step == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null."; | |||
| return false; | |||
| } | |||
| auto batch_mean = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto batch_std = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto running_mean = reinterpret_cast<T *>(outputs[2]->addr); | |||
| auto running_std = reinterpret_cast<T *>(outputs[3]->addr); | |||
| auto y = reinterpret_cast<T *>(workspace[0]->addr); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice), | |||
| "Failed to copy gpu memory."); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_std, variance, output_size_, cudaMemcpyDeviceToDevice), | |||
| "Failed to copy gpu memory."); | |||
| CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (!is_training_ || current_step_host[0] >= freeze_bn_) { | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory."); | |||
| ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| const T alpha = 1; | |||
| const T beta = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardTraining( | |||
| handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y, scale_bias_mean_var_desc_, | |||
| mean, mean, exp_avg_factor_, mean, variance, epsilon_, batch_mean, batch_std), | |||
| "Failed to launch kernel.") | |||
| CalUpdateBatchStd(channel_, batch_std, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << " but BatchNormFold GpuKernel OP needs 4 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 4) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output."; | |||
| return false; | |||
| } | |||
| T momentum = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum")); | |||
| exp_avg_factor_ = 1.0 - momentum; | |||
| epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); | |||
| is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); | |||
| freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "Input shape is " << input_shape.size() | |||
| << ", but BatchNormFold GpuKernel OP needs 4DTensor input."; | |||
| return false; | |||
| } | |||
| batch_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; | |||
| output_size_ = sizeof(T) * channel_; | |||
| cudnnDataType_t cudnnDataType = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), | |||
| "Set x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1), | |||
| "Set para desc failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| // x, mean, variance, current_step | |||
| input_size_list_.push_back(input_size_); | |||
| input_size_list_.push_back(output_size_); | |||
| input_size_list_.push_back(output_size_); | |||
| input_size_list_.push_back(sizeof(int)); | |||
| // batch_mean, batch_std, running_mean, running_std | |||
| output_size_list_.push_back(output_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| // store y | |||
| workspace_size_list_.push_back(input_size_); | |||
| } | |||
| void InitResource() { | |||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); | |||
| } | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| double exp_avg_factor_; | |||
| double epsilon_; | |||
| bool is_training_; | |||
| int freeze_bn_; | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| cudnnBatchNormMode_t mode_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t scale_bias_mean_var_desc_; | |||
| cudnnHandle_t handle_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ | |||
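`Init` above converts the `momentum` attribute into `exp_avg_factor_ = 1.0 - momentum` before handing it to `cudnnBatchNormalizationForwardTraining`. cuDNN documents its `exponentialAverageFactor` as `running = (1 - factor) * running + factor * batch_stat`, so this choice yields the conventional `running = momentum * running + (1 - momentum) * batch_stat` update. A small host-side illustration of that bookkeeping (purely illustrative, not part of the kernel):

```cpp
#include <cstdio>

// Illustration of how cuDNN's exponentialAverageFactor relates to the `momentum`
// attribute read in Init() above (exp_avg_factor_ = 1.0 - momentum).
// cuDNN updates: running = (1 - factor) * running + factor * batch_stat.
float UpdateRunning(float running, float batch_stat, float momentum) {
  const float factor = 1.0f - momentum;  // what Init() passes as exponentialAverageFactor
  return (1.0f - factor) * running + factor * batch_stat;  // == momentum*running + (1-momentum)*batch_stat
}

int main() {
  float running_mean = 0.0f;
  const float momentum = 0.9f;
  // Feed a constant batch mean of 1.0; the running mean converges toward 1.0.
  for (int step = 0; step < 5; ++step) {
    running_mean = UpdateRunning(running_mean, 1.0f, momentum);
    printf("step %d: running_mean = %f\n", step, running_mean);
  }
  return 0;
}
```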
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BatchNormFoldGradGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,167 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BatchNormFoldGradGpuKernel : public GpuKernel { | |||
| public: | |||
| BatchNormFoldGradGpuKernel() | |||
| : input_size_(0), | |||
| channel_size_(0), | |||
| workspace_size_(0), | |||
| momentum_(0.1), | |||
| epsilon_(1e-12), | |||
| is_training_(true), | |||
| freeze_bn_(0), | |||
| current_step_(0), | |||
| batch_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0) {} | |||
| ~BatchNormFoldGradGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| (void)workspace; | |||
| // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' | |||
| T *d_batch_mean = GetDeviceAddress<T>(inputs, 0); | |||
| T *d_batch_std = GetDeviceAddress<T>(inputs, 1); | |||
| T *x = GetDeviceAddress<T>(inputs, 2); | |||
| T *batch_mean = GetDeviceAddress<T>(inputs, 3); | |||
| T *batch_std = GetDeviceAddress<T>(inputs, 4); | |||
| int *current_step = GetDeviceAddress<int>(inputs, 5); | |||
| int current_step_host[1]; | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost), | |||
| "Copy gpu memoy failed."); | |||
| if (d_batch_mean == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null."; | |||
| return false; | |||
| } | |||
| if (d_batch_std == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null."; | |||
| return false; | |||
| } | |||
| if (x == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null."; | |||
| return false; | |||
| } | |||
| if (batch_mean == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null."; | |||
| return false; | |||
| } | |||
| if (batch_std == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null."; | |||
| return false; | |||
| } | |||
| if (current_step == nullptr) { | |||
| MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null."; | |||
| return false; | |||
| } | |||
| T *dx = reinterpret_cast<T *>(outputs[0]->addr); | |||
| if (!is_training_ || current_step_host[0] >= freeze_bn_) { | |||
| ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 6) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 4 output."; | |||
| return false; | |||
| } | |||
| epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); | |||
| is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); | |||
| freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "Input shape is " << input_shape.size() | |||
| << ", but BatchNormFoldGrad GpuKernel OP needs 4DTensor input."; | |||
| return false; | |||
| } | |||
| batch_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; | |||
| channel_size_ = sizeof(T) * channel_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' | |||
| input_size_list_.push_back(channel_size_); | |||
| input_size_list_.push_back(channel_size_); | |||
| input_size_list_.push_back(input_size_); | |||
| input_size_list_.push_back(channel_size_); | |||
| input_size_list_.push_back(channel_size_); | |||
| input_size_list_.push_back(sizeof(int)); | |||
| // 'dx' | |||
| output_size_list_.push_back(input_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| size_t channel_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| T momentum_; | |||
| T epsilon_; | |||
| bool is_training_; | |||
| int freeze_bn_; | |||
| int current_step_; | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ | |||
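`CalBatchNormFoldGrad` lives in `batchnorm_fold_impl.cuh` and is not part of this hunk. Assuming `batch_mean` and `batch_std` are the per-channel mean and standard deviation over the `N*H*W` elements of each channel (as the forward kernel computes), the chain rule gives `dx = d_batch_mean/M + d_batch_std*(x - batch_mean)/(M*batch_std)` with `M = N*H*W`. A CPU sketch of that formula:

```cpp
#include <cstddef>
#include <vector>

// CPU sketch of the gradient that CalBatchNormFoldGrad is expected to compute
// (the CUDA implementation is in batchnorm_fold_impl.cuh and may differ).
// Layout is NCHW; M = N*H*W elements are reduced per channel.
// dx = d_batch_mean / M + d_batch_std * (x - batch_mean) / (M * batch_std)
void BatchNormFoldGradRef(const std::vector<float> &d_batch_mean, const std::vector<float> &d_batch_std,
                          const std::vector<float> &x, const std::vector<float> &batch_mean,
                          const std::vector<float> &batch_std, int n, int c, int h, int w,
                          std::vector<float> *dx) {
  const float m = static_cast<float>(n * h * w);
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < c; ++ci) {
      for (int idx = 0; idx < h * w; ++idx) {
        const size_t i = (static_cast<size_t>(ni) * c + ci) * h * w + idx;
        (*dx)[i] = d_batch_mean[ci] / m +
                   d_batch_std[ci] * (x[i] - batch_mean[ci]) / (m * batch_std[ci]);
      }
    }
  }
}
```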
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/correction_mul_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(CorrectionMul, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| CorrectionMulGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class CorrectionMulGpuKernel : public GpuKernel { | |||
| public: | |||
| CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} | |||
| ~CorrectionMulGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| auto *weight = GetDeviceAddress<T>(inputs, 0); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 1); | |||
| auto *running_std = GetDeviceAddress<T>(inputs, 2); | |||
| auto *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W)."; | |||
| return false; | |||
| } | |||
| batch_size_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = batch_size_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); // weight | |||
| input_size_list_.push_back(weight_size); // gamma | |||
| input_size_list_.push_back(weight_size); // running_std | |||
| size_t workspace_size = 0; | |||
| output_size_list_.push_back(input_size); | |||
| workspace_size_list_.push_back(workspace_size); | |||
| } | |||
| void InitResource() {} | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| size_t batch_size_; | |||
| size_t channel_; | |||
| size_t height_; | |||
| size_t width_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ | |||
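The size lists above give `gamma` and `running_std` one element per entry of the weight's first dimension (called `batch_size_` here; for a convolution weight this is the output-channel count), so `CalCorrectionMul` presumably folds the BatchNorm scale into the weight per output channel: `out = weight * gamma / running_std`. The actual kernel is in `correction_mul_impl.cuh`; a CPU sketch under that assumption:

```cpp
#include <cstddef>
#include <vector>

// CPU sketch of the per-output-channel weight correction assumed above:
// out[k, c, h, w] = weight[k, c, h, w] * gamma[k] / running_std[k].
// `batch_size_` in the kernel above plays the role of K (the weight's first dimension).
void CorrectionMulRef(const std::vector<float> &weight, const std::vector<float> &gamma,
                      const std::vector<float> &running_std, int k, int c, int h, int w,
                      std::vector<float> *out) {
  const int per_channel = c * h * w;
  for (int ki = 0; ki < k; ++ki) {
    const float scale = gamma[ki] / running_std[ki];
    for (int j = 0; j < per_channel; ++j) {
      const size_t i = static_cast<size_t>(ki) * per_channel + j;
      (*out)[i] = weight[i] * scale;
    }
  }
}
```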
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/correction_mul_grad_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| CorrectionMulGradGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class CorrectionMulGradGpuKernel : public GpuKernel { | |||
| public: | |||
| CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} | |||
| ~CorrectionMulGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| auto *d_out = GetDeviceAddress<T>(inputs, 0); | |||
| auto *weight = GetDeviceAddress<T>(inputs, 1); | |||
| auto *gamma = GetDeviceAddress<T>(inputs, 2); | |||
| auto *running_std = GetDeviceAddress<T>(inputs, 3); | |||
| auto *d_weight = GetDeviceAddress<T>(outputs, 0); | |||
| auto *d_gamma = GetDeviceAddress<T>(outputs, 1); | |||
| auto *tmp = GetDeviceAddress<T>(workspace, 0); | |||
| CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W)."; | |||
| return false; | |||
| } | |||
| batch_size_ = input_shape[0]; | |||
| channel_ = input_shape[1]; | |||
| height_ = input_shape[2]; | |||
| width_ = input_shape[3]; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() { | |||
| size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); | |||
| size_t weight_size = batch_size_ * sizeof(T); | |||
| input_size_list_.push_back(input_size); // d_out | |||
| input_size_list_.push_back(input_size); // weight | |||
| input_size_list_.push_back(weight_size); // gamma | |||
| input_size_list_.push_back(weight_size); // running_std | |||
| output_size_list_.push_back(input_size); // d_weight | |||
| output_size_list_.push_back(weight_size); // d_gamma | |||
| workspace_size_list_.push_back(input_size); // tmp d_out * weight | |||
| } | |||
| void InitResource() {} | |||
| private: | |||
| void DestroyResource() noexcept {} | |||
| size_t batch_size_; | |||
| size_t channel_; | |||
| size_t height_; | |||
| size_t width_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ | |||
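| As a reading aid: CalCorrectionMul and CalCorrectionMulGrad live in correction_mul_impl.cuh and are not part of this diff. The NumPy sketch below spells out the math that the buffer layout and comments suggest; the exact per-channel reduction is an assumption inferred from the argument names and the "tmp d_out * weight" workspace, not taken from the CUDA implementation. | |||
| import numpy as np | |||
| def correction_mul_grad_reference(d_out, weight, gamma, running_std): | |||
|     # d_out, weight: (N, C, H, W); gamma, running_std: (N,), one value per output channel, | |||
|     # matching the batch_size_-sized gamma/running_std buffers in InitSizeLists(). | |||
|     factor = (gamma / running_std).reshape(-1, 1, 1, 1) | |||
|     d_weight = d_out * factor                        # what CalCorrectionMul appears to produce | |||
|     tmp = d_out * weight                             # the "tmp d_out * weight" workspace buffer | |||
|     d_gamma = tmp.sum(axis=(1, 2, 3)) / running_std  # per-channel reduction for d_gamma | |||
|     return d_weight, d_gamma | |||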
| @@ -0,0 +1,176 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/fake_quant_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/fake_quant_impl.cuh" | |||
| #include <thrust/extrema.h> | |||
| #include <thrust/pair.h> | |||
| #include <thrust/device_vector.h> | |||
| #include <cuda_runtime_api.h> | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| FakeQuantGpuKernel::FakeQuantGpuKernel() | |||
| : input_size_(0), | |||
| min_size_(0), | |||
| max_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| quant_num_(0), | |||
| quant_delay_(0), | |||
| ema_(false), | |||
| ema_decay_(0), | |||
| global_step_(0), | |||
| training_(false), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||
| ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||
| training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||
| } | |||
| quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); | |||
| if (quant_delay_ < 0) { | |||
| MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0."; | |||
| } | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| if (symmetric_) { | |||
| quant_min_ = 0 - (1 << (num_bits_ - 1)); | |||
| quant_max_ = (1 << (num_bits_ - 1)) - 1; | |||
| } else { | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| } | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| if (quant_num_ == 0) { | |||
| quant_num_ = 1; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||
| quant_num_ *= SizeToInt(input_shape[i]); | |||
| } | |||
| input_size_ = sizeof(float); | |||
| min_size_ = sizeof(float); | |||
| max_size_ = sizeof(float); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| output_size_ = input_size_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void FakeQuantGpuKernel::InitSizeLists() { | |||
| input_size_list_.push_back(input_size_); // input | |||
| input_size_list_.push_back(min_size_); // min | |||
| input_size_list_.push_back(max_size_); // max | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| (void)workspace; | |||
| float *output = GetDeviceAddress<float>(outputs, 0); | |||
| float *input = GetDeviceAddress<float>(inputs, 0); | |||
| float *input_min = GetDeviceAddress<float>(inputs, 1); | |||
| float *input_max = GetDeviceAddress<float>(inputs, 2); | |||
| if (input == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null."; | |||
| } | |||
| if (input_min == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null."; | |||
| } | |||
| if (input_max == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null."; | |||
| } | |||
| // Allocate space for device copies | |||
| int size = sizeof(float); | |||
| float *d_scale = nullptr; | |||
| float *d_nudge_min = nullptr; | |||
| float *d_nudge_max = nullptr; | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed"); | |||
| if (training_) { | |||
| // Calculate the input min and max according to the ema and ema_decay attributes. | |||
| CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // control flow for quant_delay | |||
| if (global_step_ >= quant_delay_) { | |||
| // real launch | |||
| CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice), | |||
| "Copy gpu memory failed"); | |||
| } | |||
| global_step_++; | |||
| } else { | |||
| // real launch | |||
| CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| // Cleanup | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); | |||
| return true; | |||
| } | |||
| MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
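| CalMinMax, CalNudge and CalFakeQuantize are defined in fake_quant_impl.cuh, which is not shown in this diff. The sketch below is a hedged NumPy reference of the conventional min/max-nudging fake-quantization scheme that the argument names point to; the actual CUDA kernels may differ in detail. | |||
| import numpy as np | |||
| def nudge_min_max(input_min, input_max, quant_min, quant_max): | |||
|     # Pick a scale from the float range, then shift the range so that zero | |||
|     # lands exactly on the integer grid (the usual "nudging" step). | |||
|     scale = (input_max - input_min) / (quant_max - quant_min) | |||
|     zero_point = np.clip(np.round(quant_min - input_min / scale), quant_min, quant_max) | |||
|     nudge_min = (quant_min - zero_point) * scale | |||
|     nudge_max = (quant_max - zero_point) * scale | |||
|     return nudge_min, nudge_max, scale | |||
| def fake_quantize(x, nudge_min, nudge_max, scale): | |||
|     # Clamp to the nudged range, snap to the integer grid, and map back to float. | |||
|     clamped = np.clip(x, nudge_min, nudge_max) | |||
|     return np.round((clamped - nudge_min) / scale) * scale + nudge_min | |||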
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FakeQuantGpuKernel : public GpuKernel { | |||
| public: | |||
| FakeQuantGpuKernel(); | |||
| ~FakeQuantGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel) override; | |||
| protected: | |||
| void InitSizeLists() override; | |||
| private: | |||
| size_t input_size_; | |||
| size_t min_size_; | |||
| size_t max_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int quant_num_; | |||
| int quant_delay_; | |||
| bool ema_; | |||
| float ema_decay_; | |||
| int global_step_; | |||
| bool training_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ | |||
| @@ -0,0 +1,145 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/fake_quant_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| FakeQuantGradGpuKernel::FakeQuantGradGpuKernel() | |||
| : input_size_(0), | |||
| min_size_(0), | |||
| max_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| quant_size_(0), | |||
| quant_delay_(0), | |||
| global_step_(0) {} | |||
| const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||
| } | |||
| quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); | |||
| if (quant_delay_ < 0) { | |||
| MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; | |||
| } | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| if (quant_size_ == 0) { | |||
| quant_size_ = 1; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||
| quant_size_ *= SizeToInt(input_shape[i]); | |||
| } | |||
| input_size_ = sizeof(float); | |||
| min_size_ = sizeof(float); | |||
| max_size_ = sizeof(float); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| output_size_ = input_size_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void FakeQuantGradGpuKernel::InitSizeLists() { | |||
| input_size_list_.push_back(input_size_); // gradient | |||
| input_size_list_.push_back(input_size_); // input | |||
| input_size_list_.push_back(min_size_); // min | |||
| input_size_list_.push_back(max_size_); // max | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| (void)workspace; | |||
| float *output = GetDeviceAddress<float>(outputs, 0); | |||
| float *gradient = GetDeviceAddress<float>(inputs, 0); | |||
| float *input = GetDeviceAddress<float>(inputs, 1); | |||
| float *input_min = GetDeviceAddress<float>(inputs, 2); | |||
| float *input_max = GetDeviceAddress<float>(inputs, 3); | |||
| if (gradient == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null"; | |||
| } | |||
| if (input == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null."; | |||
| } | |||
| if (input_min == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null."; | |||
| } | |||
| if (input_max == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null."; | |||
| } | |||
| if (global_step_ >= quant_delay_) { | |||
| float *d_scale = nullptr; | |||
| float *d_nudge_min = nullptr; | |||
| float *d_nudge_max = nullptr; | |||
| int size = sizeof(float); | |||
| // Allocate space for device copies | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed"); | |||
| CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // Cleanup | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); | |||
| } else { | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice), | |||
| "Copy gpu memory failed."); | |||
| } | |||
| global_step_++; | |||
| return true; | |||
| } | |||
| MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
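| CalFakeQuantizeGrad is likewise only referenced here. Fake quantization is usually differentiated with a straight-through estimator: the incoming gradient passes wherever the input was not clamped by the nudged range and is zeroed elsewhere. A minimal NumPy sketch under that assumption: | |||
| import numpy as np | |||
| def fake_quantize_grad(gradient, x, nudge_min, nudge_max): | |||
|     # Straight-through estimator: gradient flows only where x stayed inside the nudged range. | |||
|     inside = (x >= nudge_min) & (x <= nudge_max) | |||
|     return np.where(inside, gradient, 0.0) | |||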
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FakeQuantGradGpuKernel : public GpuKernel { | |||
| public: | |||
| FakeQuantGradGpuKernel(); | |||
| ~FakeQuantGradGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel_node) override; | |||
| protected: | |||
| void InitSizeLists() override; | |||
| private: | |||
| size_t input_size_; | |||
| size_t min_size_; | |||
| size_t max_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int quant_size_; | |||
| int quant_delay_; | |||
| int global_step_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ | |||
| @@ -0,0 +1,181 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh" | |||
| #include <thrust/extrema.h> | |||
| #include <thrust/pair.h> | |||
| #include <thrust/device_vector.h> | |||
| #include <cuda_runtime_api.h> | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel() | |||
| : input_size_(0), | |||
| min_size_(0), | |||
| max_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| quant_delay_(0), | |||
| ema_(false), | |||
| ema_decay_(0), | |||
| global_step_(0), | |||
| training_(false), | |||
| channel_out_(0), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } | |||
| bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << " but FakeQuant GpuKernel OP needs 1 output."; | |||
| return false; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||
| ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16."; | |||
| return false; | |||
| } | |||
| quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); | |||
| if (quant_delay_ < 0) { | |||
| MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0."; | |||
| return false; | |||
| } | |||
| training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| if (symmetric_) { | |||
| quant_min_ = 0 - (1 << (num_bits_ - 1)); | |||
| quant_max_ = (1 << (num_bits_ - 1)) - 1; | |||
| } else { | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| } | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| // shape info for gpu | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| channel_out_ = SizeToInt(input_shape[0]); | |||
| min_size_ = sizeof(float) * channel_out_; | |||
| max_size_ = sizeof(float) * channel_out_; | |||
| input_size_ = sizeof(float); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| output_size_ = input_size_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void FakeQuantPerChannelGpuKernel::InitSizeLists() { | |||
| input_size_list_.push_back(input_size_); // input | |||
| input_size_list_.push_back(min_size_); // min | |||
| input_size_list_.push_back(max_size_); // max | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| (void)workspace; | |||
| float *output = GetDeviceAddress<float>(outputs, 0); | |||
| float *input = GetDeviceAddress<float>(inputs, 0); | |||
| float *input_min = GetDeviceAddress<float>(inputs, 1); | |||
| float *input_max = GetDeviceAddress<float>(inputs, 2); | |||
| if (input == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; | |||
| } | |||
| if (input_min == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min is null."; | |||
| } | |||
| if (input_max == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input max is null."; | |||
| } | |||
| // Allocate space for device copies | |||
| float *d_scale = nullptr; | |||
| float *d_nudge_min = nullptr; | |||
| float *d_nudge_max = nullptr; | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), sizeof(float) * channel_out_), | |||
| "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), sizeof(float) * channel_out_), | |||
| "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_), | |||
| "Malloc gpu memory failed"); | |||
| int total_size = input_size_ / sizeof(float); | |||
| bool symmetric = false; | |||
| if (training_) { | |||
| // Calculate the input min and max according to the ema and ema_decay attributes. | |||
| CalMinMaxPerChannel(input, input_min, input_max, total_size, channel_out_, ema_decay_, ema_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // control flow for quant_delay | |||
| if (global_step_ >= quant_delay_) { | |||
| // real launch | |||
| CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice), | |||
| "Copy gpu memory failed."); | |||
| } | |||
| global_step_++; | |||
| } else { | |||
| // real launch | |||
| CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| // Cleanup | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); | |||
| return true; | |||
| } | |||
| MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
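| The per-channel kernel keeps one (min, max) pair per leading channel (channel_out_ = input_shape[0]) and otherwise repeats the per-tensor nudge-and-quantize logic with one scale per channel. The NumPy sketch below broadcasts per-channel values over the remaining axes; the channel-major layout is an assumption taken from how min_size_ and max_size_ are computed, not from the CUDA code. | |||
| import numpy as np | |||
| def fake_quantize_per_channel(x, input_min, input_max, quant_min, quant_max): | |||
|     # x: (C, ...); input_min, input_max: (C,) -- one (min, max) pair per leading channel. | |||
|     shape = (-1,) + (1,) * (x.ndim - 1)  # broadcast per-channel values over the other axes | |||
|     scale = ((input_max - input_min) / (quant_max - quant_min)).reshape(shape) | |||
|     zero_point = np.clip(np.round(quant_min - input_min.reshape(shape) / scale), quant_min, quant_max) | |||
|     nudge_min = (quant_min - zero_point) * scale | |||
|     nudge_max = (quant_max - zero_point) * scale | |||
|     clamped = np.clip(x, nudge_min, nudge_max) | |||
|     return np.round((clamped - nudge_min) / scale) * scale + nudge_min | |||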
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FakeQuantPerChannelGpuKernel : public GpuKernel { | |||
| public: | |||
| FakeQuantPerChannelGpuKernel(); | |||
| ~FakeQuantPerChannelGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel) override; | |||
| protected: | |||
| void InitSizeLists() override; | |||
| private: | |||
| size_t input_size_; | |||
| size_t min_size_; | |||
| size_t max_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int quant_delay_; | |||
| bool ema_; | |||
| float ema_decay_; | |||
| int global_step_; | |||
| bool training_; | |||
| int channel_out_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ | |||
| @@ -0,0 +1,158 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h" | |||
| #include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() | |||
| : input_size_(0), | |||
| min_size_(0), | |||
| max_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| channel_out_(0), | |||
| quant_delay_(0), | |||
| global_step_(0), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } | |||
| const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const { | |||
| return workspace_size_list_; | |||
| } | |||
| bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||
| } | |||
| quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); | |||
| if (quant_delay_ < 0) { | |||
| MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; | |||
| } | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| if (symmetric_) { | |||
| quant_min_ = 0 - (1 << (num_bits_ - 1)); | |||
| quant_max_ = (1 << (num_bits_ - 1)) - 1; | |||
| } else { | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| } | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| channel_out_ = SizeToInt(input_shape[0]); | |||
| min_size_ = sizeof(float) * channel_out_; | |||
| max_size_ = sizeof(float) * channel_out_; | |||
| input_size_ = sizeof(float); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| output_size_ = input_size_; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { | |||
| input_size_list_.push_back(input_size_); // gradient | |||
| input_size_list_.push_back(input_size_); // input | |||
| input_size_list_.push_back(min_size_); // min | |||
| input_size_list_.push_back(max_size_); // max | |||
| output_size_list_.push_back(output_size_); | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) { | |||
| (void)workspace; | |||
| float *output = GetDeviceAddress<float>(outputs, 0); | |||
| float *gradient = GetDeviceAddress<float>(inputs, 0); | |||
| float *input = GetDeviceAddress<float>(inputs, 1); | |||
| float *input_min = GetDeviceAddress<float>(inputs, 2); | |||
| float *input_max = GetDeviceAddress<float>(inputs, 3); | |||
| if (gradient == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; | |||
| } | |||
| if (input == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null"; | |||
| } | |||
| if (input_min == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null"; | |||
| } | |||
| if (input_max == nullptr) { | |||
| MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null"; | |||
| } | |||
| int total_size = input_size_ / sizeof(float); | |||
| if (global_step_ >= quant_delay_) { | |||
| float *d_scale = nullptr; | |||
| float *d_nudge_min = nullptr; | |||
| float *d_nudge_max = nullptr; | |||
| // Allocate space for device copies | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), channel_out_ * sizeof(float)), | |||
| "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), channel_out_ * sizeof(float)), | |||
| "Malloc gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), channel_out_ * sizeof(float)), | |||
| "Malloc gpu memory failed"); | |||
| CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| // Cleanup | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); | |||
| } else { | |||
| CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice), | |||
| "Copy gpu memory failed."); | |||
| } | |||
| global_step_++; | |||
| return true; | |||
| } | |||
| MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FakeQuantPerChannelGradGpuKernel : public GpuKernel { | |||
| public: | |||
| FakeQuantPerChannelGradGpuKernel(); | |||
| ~FakeQuantPerChannelGradGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override; | |||
| const std::vector<size_t> &GetOutputSizeList() const override; | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override; | |||
| bool Init(const CNodePtr &kernel_node) override; | |||
| protected: | |||
| void InitSizeLists() override; | |||
| private: | |||
| size_t input_size_; | |||
| size_t min_size_; | |||
| size_t max_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int channel_out_; | |||
| int quant_delay_; | |||
| int global_step_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HSigmoid op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "HSigmoid", | |||
| "imply_type": "AutoDiff", | |||
| "fusion_type": "OPAQUE", | |||
| "processor": "cuda", | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output" | |||
| } | |||
| ] | |||
| }""") | |||
| def _hsigmoid_akg(): | |||
| """HSigmoid AutoDiff register""" | |||
| return | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HSigmoidGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "HSigmoidGrad", | |||
| "imply_type": "AutoDiff", | |||
| "fusion_type": "OPAQUE", | |||
| "processor": "cuda", | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y_grad" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output" | |||
| } | |||
| ] | |||
| }""") | |||
| def _hsigmoid_grad_akg(): | |||
| """HSigmoidGrad AutoDiff register""" | |||
| return | |||
| @@ -0,0 +1,52 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HSwish op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "HSwish", | |||
| "imply_type": "AutoDiff", | |||
| "fusion_type": "OPAQUE", | |||
| "processor": "cuda", | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output" | |||
| } | |||
| ] | |||
| }""") | |||
| def _hswish_akg(): | |||
| """HSwish AutoDiff register""" | |||
| return | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """HSwishGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
| @op_info_register("""{ | |||
| "op_name": "HSwishGrad", | |||
| "imply_type": "AutoDiff", | |||
| "fusion_type": "OPAQUE", | |||
| "processor": "cuda", | |||
| "attr": [ | |||
| ], | |||
| "inputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "y_grad" | |||
| }, | |||
| { | |||
| "index": 1, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "x" | |||
| } | |||
| ], | |||
| "outputs": [ | |||
| { | |||
| "index": 0, | |||
| "dtype": [ | |||
| "float32", "float16" | |||
| ], | |||
| "format": [ | |||
| "DefaultFormat", "DefaultFormat" | |||
| ], | |||
| "name": "output" | |||
| } | |||
| ] | |||
| }""") | |||
| def _hswish_grad_akg(): | |||
| """HSwishGrad AutoDiff register""" | |||
| return | |||
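| The AKG compute definitions for HSwish and HSwishGrad are not part of this hunk; only their op-info registrations are. For orientation, the usual MobileNetV3-style piecewise formulas are sketched below in plain NumPy; the registered AutoDiff kernels are assumed to implement the same expressions. | |||
| import numpy as np | |||
| def hswish(x): | |||
|     # x * relu6(x + 3) / 6 | |||
|     return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0 | |||
| def hswish_grad(y_grad, x): | |||
|     # 0 for x <= -3, y_grad for x >= 3, and y_grad * (2 * x + 3) / 6 in between. | |||
|     return np.where(x <= -3.0, 0.0, | |||
|                     np.where(x >= 3.0, y_grad, y_grad * (2.0 * x + 3.0) / 6.0)) | |||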