From: @ding_fei_fei (tags/v1.2.0-rc1)
| @@ -16,6 +16,10 @@ Previously MakeRefKey is an external interface that is not used, now make it an | |||
| Previously, these operators produced a different number of outputs on different backends. To unify their definition, their output on the Ascend backend has been changed from multiple outputs to a single output. | |||
| ##### `P.FusedBatchNorm`, `P.FusedBatchNormEx` deleted ([!12115](https://gitee.com/mindspore/mindspore/pulls/12115)) | |||
| The FusedBatchNorm and FusedBatchNormEx interfaces have been deleted. Please use the BatchNorm operator instead, as shown in the sketch below. | |||
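
A minimal migration sketch on the Python side (not part of this PR; the exact operator signature should be checked against the 1.2 API documentation):

```python
# Hedged migration sketch: replacing the removed P.FusedBatchNorm /
# P.FusedBatchNormEx with BatchNorm. Signatures below are assumptions
# based on this release and should be verified against the API docs.
import mindspore.nn as nn
import mindspore.ops as ops

# Typical use: the nn layer wraps the BatchNorm primitive and manages
# the moving statistics itself.
bn_layer = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.9)

# Raw operator form (assumed signature): is_training selects the
# training-statistics path that FusedBatchNorm previously implied.
batch_norm = ops.BatchNorm(is_training=True, epsilon=1e-5)
# y, batch_mean, batch_variance, reserve_1, reserve_2 = batch_norm(x, scale, bias, mean, variance)
```
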
| # MindSpore 1.1.1 Release Notes | |||
| ## MindSpore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,14 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| void BatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| size_t type_size = sizeof(float); | |||
| @@ -30,16 +30,13 @@ void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| workspace_size_list_.emplace_back(tensor_size); | |||
| } | |||
| void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (node_name == "FusedBatchNorm") { | |||
| momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum"); | |||
| is_train = true; | |||
| } | |||
| is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training"); | |||
| momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum"); | |||
| std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (x_shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!"; | |||
| MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!"; | |||
| } | |||
| batch_size = x_shape[0]; | |||
| channel = x_shape[1]; | |||
| @@ -66,9 +63,9 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| AddArgument(DNNL_ARG_DST, x_desc); | |||
| } | |||
| bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| bool BatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() < 5 || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Error input output size!"; | |||
| } | |||
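
For reference, a small NumPy sketch (illustration only, not the MKL-DNN call path used by the kernel above) of the computation the renamed `BatchNormCPUKernel` performs, with the `is_training` attribute deciding between batch and running statistics:

```python
# Reference batch normalization over an NCHW tensor; names are illustrative.
import numpy as np

def batch_norm_nchw(x, scale, bias, mean, var, eps=1e-5, is_train=True):
    if is_train:
        # per-channel statistics over the N, H, W axes (training path)
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
    inv_std = 1.0 / np.sqrt(var + eps)
    # broadcast the per-channel parameters over the NCHW layout
    x_hat = (x - mean[None, :, None, None]) * inv_std[None, :, None, None]
    y = scale[None, :, None, None] * x_hat + bias[None, :, None, None]
    return y, mean, var

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
y, m, v = batch_norm_nchw(x, np.ones(3), np.zeros(3), np.zeros(3), np.ones(3))
```
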
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,18 +13,18 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FusedBatchNormCPUKernel : public MKLCPUKernel { | |||
| class BatchNormCPUKernel : public MKLCPUKernel { | |||
| public: | |||
| FusedBatchNormCPUKernel() = default; | |||
| ~FusedBatchNormCPUKernel() override = default; | |||
| BatchNormCPUKernel() = default; | |||
| ~BatchNormCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| @@ -43,20 +43,6 @@ class FusedBatchNormCPUKernel : public MKLCPUKernel { | |||
| size_t nhw_size{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(FusedBatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormCPUKernel) | |||
| MS_REG_CPU_KERNEL(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -69,7 +55,7 @@ MS_REG_CPU_KERNEL(BatchNorm, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormCPUKernel) | |||
| BatchNormCPUKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_gard_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/mkldnn/batch_norm_gard_cpu_kernel.h" | |||
| #include <string> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" | |||
| @@ -22,19 +22,20 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void FusedBatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { | |||
| CPUKernel::InitInputOutputSize(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| size_t type_size = sizeof(float); | |||
| std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| size_t tensor_size = shape[1] * 2 * type_size; | |||
| input_size_list_.pop_back(); | |||
| // [2, c] to store scale and bias | |||
| workspace_size_list_.emplace_back(tensor_size); | |||
| // [2, c] to store diff_scale and diff_bias | |||
| workspace_size_list_.emplace_back(tensor_size); | |||
| } | |||
| void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (x_shape.size() != 4) { | |||
| @@ -72,25 +73,25 @@ void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| AddArgument(DNNL_ARG_DIFF_SCALE_SHIFT, scale_bias_desc); | |||
| } | |||
| bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() < 5 || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Error input output size!"; | |||
| } | |||
| auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr); | |||
| auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size); | |||
| auto max_size = workspace[0]->size - inputs[2]->size; | |||
| auto bias_ret = memcpy_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, inputs[3]->addr, inputs[3]->size); | |||
| if (scale_ret != 0 || bias_ret != 0) { | |||
| auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0., max_size); | |||
| if (scale_ret != 0 && bias_ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Memcpy_s error."; | |||
| return false; | |||
| } | |||
| SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_MEAN, inputs[4]->addr); | |||
| SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[5]->addr); | |||
| SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr); | |||
| SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr); | |||
| SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SCALE_SHIFT, workspace[1]->addr); | |||
| @@ -99,7 +100,7 @@ bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> & | |||
| auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr); | |||
| auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size); | |||
| auto diff_bias_ret = | |||
| memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), inputs[3]->size); | |||
| memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size); | |||
| if (diff_scale_ret != 0 || diff_bias_ret != 0) { | |||
| MS_LOG(EXCEPTION) << "Memcpy_s error."; | |||
| return false; | |||
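
A hypothetical sketch of the `[2, C]` scale/shift packing that the workspace comments above describe; in the new `BatchNormGrad` path the bias row is zero-filled (the `memset_s` call) because the operator no longer carries a bias input. Names are illustrative:

```python
import numpy as np

def pack_scale_shift(scale):
    # [2, C] workspace: row 0 holds scale (the memcpy_s in the kernel),
    # row 1 holds bias and is simply left at zero (the memset_s).
    packed = np.zeros((2, scale.size), dtype=np.float32)
    packed[0] = scale
    return packed

def unpack_diff_scale_shift(packed):
    # the second workspace is read back the same way: row 0 = diff_scale,
    # row 1 = diff_bias
    return packed[0], packed[1]
```
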
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,18 +13,18 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class FusedBatchNormGradCPUKernel : public MKLCPUKernel { | |||
| class BatchNormGradCPUKernel : public MKLCPUKernel { | |||
| public: | |||
| FusedBatchNormGradCPUKernel() = default; | |||
| ~FusedBatchNormGradCPUKernel() override = default; | |||
| BatchNormGradCPUKernel() = default; | |||
| ~BatchNormGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| @@ -42,7 +42,7 @@ class FusedBatchNormGradCPUKernel : public MKLCPUKernel { | |||
| size_t nhw_size{0}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(FusedBatchNormGradCPU, | |||
| MS_REG_CPU_KERNEL(BatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -53,7 +53,7 @@ MS_REG_CPU_KERNEL(FusedBatchNormGradCPU, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGradCPUKernel) | |||
| BatchNormGradCPUKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,11 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -29,10 +29,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx, | |||
| BatchNormGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -43,11 +42,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, half) | |||
| BatchNormGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -58,10 +56,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation, | |||
| BatchNormGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -72,11 +69,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, half) | |||
| BatchNormGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -88,10 +84,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation, | |||
| BatchNormGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| @@ -103,8 +98,7 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormExGpuKernel, half) | |||
| BatchNormGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -27,10 +27,10 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| class BatchNormGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedBatchNormExGpuKernel() { ResetResource(); } | |||
| ~FusedBatchNormExGpuKernel() override { DestroyResource(); } | |||
| BatchNormGpuKernel() { ResetResource(); } | |||
| ~BatchNormGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| @@ -46,30 +46,38 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| auto x = GetDeviceAddress<T>(inputs, 0); | |||
| auto scale = GetDeviceAddress<float>(inputs, 1); | |||
| auto bias = GetDeviceAddress<float>(inputs, 2); | |||
| auto runing_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto runnig_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto running_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto running_variance = GetDeviceAddress<float>(inputs, 4); | |||
| T *z = nullptr; | |||
| if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) { | |||
| z = GetDeviceAddress<T>(inputs, 5); | |||
| } | |||
| auto y = GetDeviceAddress<T>(outputs, 0); | |||
| auto save_mean = GetDeviceAddress<float>(outputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(outputs, 4); | |||
| auto reserve_addr = GetDeviceAddress<float>(outputs, 5); | |||
| auto reserve_addr = GetDeviceAddress<float>(outputs, 2); | |||
| T *workspace_addr = nullptr; | |||
| if (workspace_size_ != 0) { | |||
| workspace_addr = GetDeviceAddress<T>(workspace, 0); | |||
| } | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationForwardTrainingEx(handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_, | |||
| y, scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, | |||
| runnig_variance, epsilon_, save_mean, save_variance, activation_desc_, | |||
| workspace_addr, workspace_size_, reserve_addr, reserve_size_), | |||
| "Kernel launch failed"); | |||
| if (is_train_) { | |||
| auto save_mean = GetDeviceAddress<float>(outputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(outputs, 4); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationForwardTrainingEx( | |||
| handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_, y, scale_bias_mean_var_desc_, scale, | |||
| bias, exp_avg_factor_, running_mean, running_variance, epsilon_, save_mean, save_variance, activation_desc_, | |||
| workspace_addr, workspace_size_, reserve_addr, reserve_size_), | |||
| "Kernel launch failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnBatchNormalizationForwardInference( | |||
| handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, scale_bias_mean_var_desc_, | |||
| scale, bias, running_mean, running_variance, epsilon_), | |||
| "Kernel launch failed"); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -77,18 +85,22 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| kernel_node_ = kernel_node; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == kFusedBatchNormEx) { | |||
| if (kernel_name == kBatchNorm) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN; | |||
| } else if (kernel_name == kFusedBatchNormExWithActivation) { | |||
| } else if (kernel_name == kBatchNormWithActivation) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; | |||
| } else if (kernel_name == kFusedBatchNormExWithAddAndActivation) { | |||
| } else if (kernel_name == kBatchNormWithAddAndActivation) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name; | |||
| } | |||
| InitResource(); | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; | |||
| if (is_train_) { | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; | |||
| } else { | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||
| } | |||
| epsilon_ = GetAttr<float>(kernel_node, "epsilon"); | |||
| exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum"); | |||
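
A small sketch of the running-statistics update that `exp_avg_factor_` (cuDNN's `exponentialAverageFactor`, filled here from the `momentum` attribute) implies; names are illustrative:

```python
# running <- (1 - factor) * running + factor * batch_statistic,
# applied per channel to both the mean and the variance.
def update_running_stats(running_mean, running_var, batch_mean, batch_var, factor):
    running_mean = (1.0 - factor) * running_mean + factor * batch_mean
    running_var = (1.0 - factor) * running_var + factor * batch_var
    return running_mean, running_var
```
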
| @@ -106,11 +118,11 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormExGpuKernel should be 4"; | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4"; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "FusedBatchNormExGpuKernel input is null"; | |||
| MS_LOG(WARNING) << "BatchNormGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -121,6 +133,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| } | |||
| SetTensorDescriptor(format, shape); | |||
| InitSizeLists(); | |||
| is_train_ = GetAttr<bool>(kernel_node, "is_training"); | |||
| return true; | |||
| } | |||
| @@ -135,6 +148,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN; | |||
| epsilon_ = 10e-5; | |||
| exp_avg_factor_ = 0.1; | |||
| is_train_ = false; | |||
| is_null_input_ = false; | |||
| x_desc_ = nullptr; | |||
| y_desc_ = nullptr; | |||
| @@ -215,11 +229,10 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| } | |||
| output_size_list_.push_back(output_size_); // output | |||
| output_size_list_.push_back(reserve_size_); // reserve space | |||
| output_size_list_.push_back(para_size_); // save scale | |||
| output_size_list_.push_back(para_size_); // save bias | |||
| output_size_list_.push_back(para_size_); // save mean | |||
| output_size_list_.push_back(para_size_); // save variance | |||
| output_size_list_.push_back(reserve_size_); // reserve space | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| } | |||
| @@ -280,6 +293,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel { | |||
| cudnnBatchNormOps_t bn_ops_; | |||
| double epsilon_; | |||
| double exp_avg_factor_; | |||
| bool is_train_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t y_desc_; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,11 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // dy | |||
| .AddInputAttr(kNumberTypeFloat32) // x | |||
| @@ -29,8 +29,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx, | |||
| .AddOutputAttr(kNumberTypeFloat32) // dx | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32), // dbias | |||
| FusedBatchNormGradExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx, | |||
| BatchNormGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) // dy | |||
| .AddInputAttr(kNumberTypeFloat16) // x | |||
| @@ -41,9 +41,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx, | |||
| .AddOutputAttr(kNumberTypeFloat16) // dx | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32), // dbias | |||
| FusedBatchNormGradExGpuKernel, half) | |||
| BatchNormGradGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // dy | |||
| .AddInputAttr(kNumberTypeFloat32) // x | |||
| @@ -56,8 +56,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) // dx | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32), // dbias | |||
| FusedBatchNormGradExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation, | |||
| BatchNormGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) // dy | |||
| .AddInputAttr(kNumberTypeFloat16) // x | |||
| @@ -70,9 +70,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation, | |||
| .AddOutputAttr(kNumberTypeFloat16) // dx | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32), // dbias | |||
| FusedBatchNormGradExGpuKernel, half) | |||
| BatchNormGradGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation, | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) // dy | |||
| .AddInputAttr(kNumberTypeFloat32) // x | |||
| @@ -86,8 +86,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32) // dbias | |||
| .AddOutputAttr(kNumberTypeFloat32), // dz | |||
| FusedBatchNormGradExGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation, | |||
| BatchNormGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) // dy | |||
| .AddInputAttr(kNumberTypeFloat16) // x | |||
| @@ -101,6 +101,6 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation, | |||
| .AddOutputAttr(kNumberTypeFloat32) // dscale | |||
| .AddOutputAttr(kNumberTypeFloat32) // dbias | |||
| .AddOutputAttr(kNumberTypeFloat16), // dz | |||
| FusedBatchNormGradExGpuKernel, half) | |||
| BatchNormGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -24,13 +24,14 @@ | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| class BatchNormGradGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedBatchNormGradExGpuKernel() | |||
| BatchNormGradGpuKernel() | |||
| : x_size_(0), | |||
| para_size_(0), | |||
| workspace_size_(0), | |||
| @@ -38,6 +39,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||
| bn_ops_(CUDNN_BATCHNORM_OPS_BN), | |||
| epsilon_(10e-5), | |||
| is_train_(false), | |||
| is_null_input_(false), | |||
| x_desc_(nullptr), | |||
| y_desc_(nullptr), | |||
| @@ -49,7 +51,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| handle_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| beta_data_diff_(0) {} | |||
| ~FusedBatchNormGradExGpuKernel() override { DestroyResource(); } | |||
| ~BatchNormGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| @@ -88,17 +90,22 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| if (workspace_size_ != 0) { | |||
| workspace_addr = GetDeviceAddress<T>(workspace, 0); | |||
| } | |||
| const float alpha_data_diff = 1; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnBatchNormalizationBackwardEx( | |||
| handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, | |||
| &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, | |||
| scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance, | |||
| activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_), | |||
| "Kernel launch failed"); | |||
| if (is_train_) { | |||
| const float alpha_data_diff = 1; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationBackwardEx(handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, | |||
| &alpha_param_diff, &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, | |||
| dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, scale, bias, dscale, dbias, | |||
| epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, | |||
| workspace_size_, reserve_addr, reserve_size_), | |||
| "Kernel launch failed"); | |||
| } else { | |||
| CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, dscale, dbias, epsilon_, batch_, channel_, height_, | |||
| width_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
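
For reference, the standard batch-norm gradient formulas; the frozen-statistics form is presumably what the `CalBatchNormGrad` fallback in the inference branch above computes:

```latex
% Per channel, with m = N \cdot H \cdot W elements contributing to that channel.
\[
\begin{aligned}
\hat{x}_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
d\beta = \sum_i dy_i, \qquad
d\gamma = \sum_i dy_i\,\hat{x}_i \\
\text{training:}\quad
dx_i &= \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}
        \left(dy_i - \frac{d\beta}{m} - \hat{x}_i\,\frac{d\gamma}{m}\right) \\
\text{frozen statistics:}\quad
dx_i &= \frac{\gamma\, dy_i}{\sqrt{\sigma^2 + \varepsilon}}
\end{aligned}
\]
```
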
| @@ -106,11 +113,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| kernel_node_ = kernel_node; | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == kFusedBatchNormGradEx) { | |||
| if (kernel_name == kBatchNormGradOpName) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN; | |||
| } else if (kernel_name == kFusedBatchNormGradExWithActivation) { | |||
| } else if (kernel_name == kBatchNormGradWithActivation) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; | |||
| } else if (kernel_name == kFusedBatchNormGradExWithAddAndActivation) { | |||
| } else if (kernel_name == kBatchNormGradWithAddAndActivation) { | |||
| bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name; | |||
| @@ -134,11 +141,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradExGpuKernel should be 4"; | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4"; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "FusedBatchNormGradExGpuKernel input is null"; | |||
| MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| @@ -150,6 +157,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1; | |||
| SetTensorDescriptor(format, shape); | |||
| InitSizeLists(); | |||
| is_train_ = GetAttr<bool>(kernel_node, "is_training"); | |||
| return true; | |||
| } | |||
| @@ -225,50 +233,52 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| private: | |||
| void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) { | |||
| cudnnTensorFormat_t cudnn_format; | |||
| int batch, channel, height, width; | |||
| if (format == kOpFormat_NHWC) { | |||
| batch = SizeToInt(shape[0]); | |||
| height = SizeToInt(shape[1]); | |||
| width = SizeToInt(shape[2]); | |||
| channel = SizeToInt(shape[3]); | |||
| batch_ = SizeToInt(shape[0]); | |||
| height_ = SizeToInt(shape[1]); | |||
| width_ = SizeToInt(shape[2]); | |||
| channel_ = SizeToInt(shape[3]); | |||
| cudnn_format = CUDNN_TENSOR_NHWC; | |||
| } else { | |||
| batch = SizeToInt(shape[0]); | |||
| channel = SizeToInt(shape[1]); | |||
| height = SizeToInt(shape[2]); | |||
| width = SizeToInt(shape[3]); | |||
| batch_ = SizeToInt(shape[0]); | |||
| channel_ = SizeToInt(shape[1]); | |||
| height_ = SizeToInt(shape[2]); | |||
| width_ = SizeToInt(shape[3]); | |||
| cudnn_format = CUDNN_TENSOR_NCHW; | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width), | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set x desc failed"); | |||
| if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width), | |||
| cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set z desc failed"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width), | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width), | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dx desc failed"); | |||
| if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width), | |||
| cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set z desc failed"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1), | |||
| cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), | |||
| "Set para desc failed"); | |||
| if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) { | |||
| @@ -278,7 +288,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| "cudnnSetActivationDescriptor failed"); | |||
| } | |||
| } | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| size_t x_size_; | |||
| size_t para_size_; | |||
| size_t workspace_size_; | |||
| @@ -286,6 +299,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { | |||
| cudnnBatchNormMode_t mode_; | |||
| cudnnBatchNormOps_t bn_ops_; | |||
| double epsilon_; | |||
| bool is_train_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| @@ -1,48 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BatchNormGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| BatchNormGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,202 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BatchNormGradGpuKernel : public GpuKernel { | |||
| public: | |||
| BatchNormGradGpuKernel() | |||
| : batch_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||
| epsilon_(10e-5), | |||
| is_null_input_(false), | |||
| x_desc_(nullptr), | |||
| dy_desc_(nullptr), | |||
| dx_desc_(nullptr), | |||
| scale_bias_desc_(nullptr), | |||
| handle_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT) {} | |||
| ~BatchNormGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| VARIABLE_NOT_USED(stream_ptr); | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto dy = GetDeviceAddress<T>(inputs, 0); | |||
| auto x = GetDeviceAddress<T>(inputs, 1); | |||
| auto scale = GetDeviceAddress<float>(inputs, 2); | |||
| auto save_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto dx = GetDeviceAddress<T>(outputs, 0); | |||
| auto bn_scale = GetDeviceAddress<float>(outputs, 1); | |||
| auto bn_bias = GetDeviceAddress<float>(outputs, 2); | |||
| auto reserve_1 = GetDeviceAddress<T>(outputs, 3); | |||
| auto reserve_2 = GetDeviceAddress<T>(outputs, 4); | |||
| // For CI only, reserved vars can not be unused. | |||
| MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2); // NOLINT | |||
| if (is_training_) { | |||
| const float alpha_data_diff = 1; | |||
| const float beta_data_diff = 0; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, | |||
| &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, | |||
| scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance), | |||
| "Kernel Launch Failed."); | |||
| } else { | |||
| CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_, | |||
| height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| InitResource(); | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 5) { | |||
| MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", BatchNormGradGpuKernel should be 5"; | |||
| } | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4"; | |||
| return false; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| batch_ = SizeToInt(shape[0]); | |||
| channel_ = SizeToInt(shape[1]); | |||
| height_ = SizeToInt(shape[2]); | |||
| width_ = SizeToInt(shape[3]); | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||
| is_training_ = GetAttr<bool>(kernel_node, "is_training"); | |||
| epsilon_ = GetAttr<float>(kernel_node, "epsilon"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), | |||
| "Set para desc failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_), | |||
| "Destroy para desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_), | |||
| "Create para desc failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| size_t input_size = 0; | |||
| size_t para_size = 0; | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size), | |||
| "Get input size failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), | |||
| "Get input size failed"); | |||
| } | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(para_size); | |||
| input_size_list_.push_back(para_size); | |||
| input_size_list_.push_back(para_size); | |||
| output_size_list_.push_back(input_size); | |||
| output_size_list_.push_back(para_size); | |||
| output_size_list_.push_back(para_size); | |||
| output_size_list_.push_back(input_size); | |||
| output_size_list_.push_back(input_size); | |||
| } | |||
| private: | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| cudnnBatchNormMode_t mode_; | |||
| bool is_training_; | |||
| double epsilon_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t dy_desc_; | |||
| cudnnTensorDescriptor_t dx_desc_; | |||
| cudnnTensorDescriptor_t scale_bias_desc_; | |||
| cudnnHandle_t handle_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| @@ -1,74 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BatchNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,204 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class FusedBatchNormGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedBatchNormGpuKernel() | |||
| : batch_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||
| epsilon_(10e-5), | |||
| exp_avg_factor_(0.1), | |||
| is_train_(false), | |||
| is_null_input_(false), | |||
| x_desc_(nullptr), | |||
| y_desc_(nullptr), | |||
| scale_bias_mean_var_desc_(nullptr), | |||
| handle_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT) {} | |||
| ~FusedBatchNormGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| VARIABLE_NOT_USED(stream_ptr); | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto x = GetDeviceAddress<T>(inputs, 0); | |||
| auto scale = GetDeviceAddress<float>(inputs, 1); | |||
| auto bias = GetDeviceAddress<float>(inputs, 2); | |||
| auto runing_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto runnig_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto y = GetDeviceAddress<T>(outputs, 0); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if (is_train_) { | |||
| auto save_mean = GetDeviceAddress<float>(outputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(outputs, 4); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, | |||
| scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, | |||
| runnig_variance, epsilon_, save_mean, save_variance), | |||
| "Kernel launch failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, | |||
| cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x, | |||
| y_desc_, y, scale_bias_mean_var_desc_, scale, | |||
| bias, runing_mean, runnig_variance, epsilon_), | |||
| "Kernel launch failed"); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| InitResource(); | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 5) { | |||
| MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; | |||
| } | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be >= 4"; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW; | |||
| auto format = AnfAlgo::GetInputFormat(kernel_node, 0); | |||
| auto format_attr = GetAttr<std::string>(kernel_node, "format"); | |||
| if (format_attr == kOpFormat_NHWC) { | |||
| format = kOpFormat_NHWC; | |||
| cudnn_format = CUDNN_TENSOR_NHWC; | |||
| } | |||
| SetNCHW(shape, &batch_, &channel_, &height_, &width_, format); | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||
| epsilon_ = GetAttr<float>(kernel_node, "epsilon"); | |||
| // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference | |||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (node_name == "FusedBatchNorm") { | |||
| is_train_ = true; | |||
| exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum"); | |||
| } | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set y desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, cudnn_format, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), | |||
| "Set para desc failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), | |||
| "Destroy para desc failed"); | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), | |||
| "Create para desc failed"); | |||
| } | |||
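| // Query the byte size of each descriptor so the framework can pre-allocate device memory for every input and output before Launch. | |||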
| void InitSizeLists() override { | |||
| size_t input_size = 0; | |||
| size_t para_size = 0; | |||
| size_t output_size = 0; | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size), | |||
| "Get input size failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size), | |||
| "Get para size failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size), | |||
| "Get para size failed"); | |||
| } | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(para_size); // scale | |||
| input_size_list_.push_back(para_size); // bias | |||
| input_size_list_.push_back(para_size); // mean | |||
| input_size_list_.push_back(para_size); // variance | |||
| output_size_list_.push_back(output_size); | |||
| output_size_list_.push_back(para_size); // running mean | |||
| output_size_list_.push_back(para_size); // running variance | |||
| output_size_list_.push_back(para_size); // save mean | |||
| output_size_list_.push_back(para_size); // save variance | |||
| return; | |||
| } | |||
| private: | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| cudnnBatchNormMode_t mode_; | |||
| double epsilon_; | |||
| double exp_avg_factor_; | |||
| bool is_train_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t y_desc_; | |||
| cudnnTensorDescriptor_t scale_bias_mean_var_desc_; | |||
| cudnnHandle_t handle_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ | |||
| @@ -1,44 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| FusedBatchNormGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,188 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class FusedBatchNormGradGpuKernel : public GpuKernel { | |||
| public: | |||
| FusedBatchNormGradGpuKernel() | |||
| : batch_(0), | |||
| channel_(0), | |||
| height_(0), | |||
| width_(0), | |||
| mode_(CUDNN_BATCHNORM_SPATIAL), | |||
| epsilon_(10e-5), | |||
| is_null_input_(false), | |||
| x_desc_(nullptr), | |||
| dy_desc_(nullptr), | |||
| dx_desc_(nullptr), | |||
| scale_bias_desc_(nullptr), | |||
| handle_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT) {} | |||
| ~FusedBatchNormGradGpuKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| VARIABLE_NOT_USED(stream_ptr); | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| auto dy = GetDeviceAddress<T>(inputs, 0); | |||
| auto x = GetDeviceAddress<T>(inputs, 1); | |||
| auto scale = GetDeviceAddress<float>(inputs, 2); | |||
| auto save_mean = GetDeviceAddress<float>(inputs, 3); | |||
| auto save_variance = GetDeviceAddress<float>(inputs, 4); | |||
| auto dx = GetDeviceAddress<T>(outputs, 0); | |||
| auto bn_scale = GetDeviceAddress<float>(outputs, 1); | |||
| auto bn_bias = GetDeviceAddress<float>(outputs, 2); | |||
| const float alpha_data_diff = 1; | |||
| const float beta_data_diff = 0; | |||
| const float alpha_param_diff = 1; | |||
| const float beta_param_diff = 0; | |||
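| // With blend factors of 1/0, cudnnBatchNormalizationBackward writes dx and the scale/bias gradients directly instead of accumulating into them. | |||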
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, | |||
| &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, | |||
| bn_scale, bn_bias, epsilon_, save_mean, save_variance), | |||
| "Kernel Launch Failed."); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| InitResource(); | |||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 5) { | |||
| MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; | |||
| } | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4"; | |||
| return false; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| batch_ = SizeToInt(shape[0]); | |||
| channel_ = SizeToInt(shape[1]); | |||
| height_ = SizeToInt(shape[2]); | |||
| width_ = SizeToInt(shape[3]); | |||
| mode_ = CUDNN_BATCHNORM_SPATIAL; | |||
| epsilon_ = GetAttr<float>(kernel_node, "epsilon"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), | |||
| "Set dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), | |||
| "Set para desc failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void DestroyResource() noexcept override { | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_), | |||
| "Destroy para desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); | |||
| } | |||
| protected: | |||
| void InitResource() override { | |||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_), | |||
| "Create para desc failed"); | |||
| } | |||
| void InitSizeLists() override { | |||
| size_t input_size = 0; | |||
| size_t para_size = 0; | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size), | |||
| "Get input size failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), | |||
| "Get input size failed"); | |||
| } | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(input_size); | |||
| input_size_list_.push_back(para_size); | |||
| input_size_list_.push_back(para_size); | |||
| input_size_list_.push_back(para_size); | |||
| output_size_list_.push_back(input_size); | |||
| output_size_list_.push_back(para_size); | |||
| output_size_list_.push_back(para_size); | |||
| } | |||
| private: | |||
| int batch_; | |||
| int channel_; | |||
| int height_; | |||
| int width_; | |||
| cudnnBatchNormMode_t mode_; | |||
| double epsilon_; | |||
| bool is_null_input_; | |||
| cudnnTensorDescriptor_t x_desc_; | |||
| cudnnTensorDescriptor_t dy_desc_; | |||
| cudnnTensorDescriptor_t dx_desc_; | |||
| cudnnTensorDescriptor_t scale_bias_desc_; | |||
| cudnnHandle_t handle_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ | |||
| @@ -34,7 +34,12 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| auto bn_grad_inputs = bn_grad_node->inputs(); | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) { | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| } else { | |||
| CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum); | |||
| } | |||
| std::vector<AnfNodePtr> bn_update_grad_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], | |||
| bn_grad_inputs[4], bn_grad_inputs[5]}; | |||
| @@ -57,7 +62,12 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_grad_node); | |||
| auto bn_grad_inputs = bn_grad_node->inputs(); | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) { | |||
| CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); | |||
| } else { | |||
| CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum); | |||
| } | |||
| if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; | |||
| } | |||
| @@ -110,6 +120,7 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> bn_update_grad_outputs; | |||
| CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); | |||
| if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -38,7 +38,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(bn_cnode); | |||
| if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) { | |||
| MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString(); | |||
| MS_LOG(INFO) << "BatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString(); | |||
| return false; | |||
| } | |||
| std::vector<AnfNodePtr> bn_training_reduce_inputs = { | |||
| @@ -51,7 +51,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & | |||
| bn_training_reduce->set_kernel_info(kernel_info); | |||
| std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0); | |||
| if (bn_shape_i0.size() < kShape2dDims) { | |||
| MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims; | |||
| MS_LOG(INFO) << "The BatchNorm's first input's shape dims less than " << kShape2dDims; | |||
| return false; | |||
| } | |||
| std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]}; | |||
| @@ -33,7 +33,7 @@ CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchn | |||
| MS_EXCEPTION_IF_NULL(batchnorm_grad); | |||
| auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||
| for (size_t i = 1; i < batchnorm_grad->size(); ++i) { | |||
| for (size_t i = 1; i < batchnorm_grad->size() - 1; ++i) { | |||
| inputs.push_back(batchnorm_grad->input(i)); | |||
| } | |||
| auto new_node = graph->NewCNode(inputs); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -56,7 +56,8 @@ constexpr size_t kBN1OutputNum = 2; | |||
| constexpr size_t kBN2OutputNum = 3; | |||
| constexpr size_t kBN3OutputNum = 1; | |||
| constexpr size_t kBNGradInputTensorNum = 5; | |||
| constexpr size_t kBNGradInputTensorNum = 6; | |||
| constexpr size_t kSyncBNGradInputTensorNum = 5; | |||
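| // kBNGradInputTensorNum: BatchNormGrad takes six inputs (including the reserve tensor); kSyncBNGradInputTensorNum: SyncBatchNormGrad takes five. | |||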
| constexpr size_t kBNGradOutputNum = 3; | |||
| constexpr size_t kBNGrad1OutputNum = 3; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -28,8 +28,8 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef BatchNormAddReluFusion::DefinePattern() const { | |||
| VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | |||
| VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); | |||
| VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_}); | |||
| VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); | |||
| return relu; | |||
| @@ -44,24 +44,24 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons | |||
| MS_EXCEPTION_IF_NULL(tensor_add); | |||
| auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0); | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm_ex); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format"); | |||
| auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(format_attr); | |||
| auto format = GetValue<std::string>(format_attr); | |||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| return nullptr; | |||
| } | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0); | |||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||
| return nullptr; | |||
| } | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4); | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4); | |||
| auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| @@ -71,7 +71,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons | |||
| MS_EXCEPTION_IF_NULL(var); | |||
| MS_EXCEPTION_IF_NULL(z); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation); | |||
| auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z}; | |||
| auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs); | |||
| @@ -79,17 +79,17 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i)); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| manager->Replace(batch_norm, fused_batch_norm_with_add_relu); | |||
| device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu); | |||
| return tuple_get_item; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -85,14 +85,14 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A | |||
| std::vector<AnfNodePtr> bn_add_relu_grad_output; | |||
| CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output); | |||
| if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) { | |||
| MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be " | |||
| MS_LOG(EXCEPTION) << "The output size of node " << kBatchNormGradWithAddAndActivation << " must be " | |||
| << kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size(); | |||
| } | |||
| // Get bn outputs | |||
| std::vector<AnfNodePtr> bn_outputs; | |||
| if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) { | |||
| MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx | |||
| MS_LOG(INFO) << "The " << prim::kPrimBatchNormGrad | |||
| << " node should only have output 0, 1 and 2. The node should not be changed"; | |||
| return; | |||
| } | |||
| @@ -139,7 +139,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0); | |||
| if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) { | |||
| if (AnfAlgo::GetCNodeName(forward_node) != kBatchNormWithAddAndActivation) { | |||
| return false; | |||
| } | |||
| @@ -150,7 +150,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| const BaseRef BatchNormAddReluGradFusion::DefinePattern() const { | |||
| VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_}); | |||
| VectorRef batch_norm_grad = | |||
| VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_}); | |||
| VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_}); | |||
| return batch_norm_grad; | |||
| } | |||
| @@ -184,7 +184,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| MS_EXCEPTION_IF_NULL(bias); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation); | |||
| auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y}; | |||
| auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -28,8 +28,8 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef BatchNormReluFusion::DefinePattern() const { | |||
| VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | |||
| VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); | |||
| VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get}); | |||
| return relu; | |||
| } | |||
| @@ -41,24 +41,24 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A | |||
| auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0); | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm_ex); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format"); | |||
| auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(format_attr); | |||
| auto format = GetValue<std::string>(format_attr); | |||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| return nullptr; | |||
| } | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0); | |||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||
| return nullptr; | |||
| } | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4); | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(scale); | |||
| @@ -66,7 +66,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A | |||
| MS_EXCEPTION_IF_NULL(mean); | |||
| MS_EXCEPTION_IF_NULL(var); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithActivation); | |||
| auto prim = std::make_shared<Primitive>(kBatchNormWithActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var}; | |||
| auto fused_batch_norm_with_relu = graph->NewCNode(inputs); | |||
| @@ -74,17 +74,17 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i)); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get()); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_relu); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(batch_norm_ex, fused_batch_norm_with_relu); | |||
| manager->Replace(batch_norm, fused_batch_norm_with_relu); | |||
| device::gpu::SetKernelInfo(fused_batch_norm_with_relu); | |||
| return tuple_get_item; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -31,7 +31,7 @@ namespace opt { | |||
| const BaseRef BatchNormReluGradFusion::DefinePattern() const { | |||
| VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_}); | |||
| VectorRef batch_norm_grad = | |||
| VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_}); | |||
| VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_}); | |||
| return batch_norm_grad; | |||
| } | |||
| @@ -82,7 +82,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| MS_EXCEPTION_IF_NULL(bias); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithActivation); | |||
| auto prim = std::make_shared<Primitive>(kBatchNormGradWithActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y}; | |||
| auto fused_batch_norm_grad_with_relu = graph->NewCNode(inputs); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,8 +42,7 @@ struct AnfNodeIndex { | |||
| }; | |||
| // opname, output idx | |||
| std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, | |||
| {kFusedBatchNormGradExWithAddAndActivation, 3}}; | |||
| std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, {kBatchNormGradWithAddAndActivation, 3}}; | |||
| std::set<string> kSkipOpNames = { | |||
| kTensorAddOpName, | |||
| @@ -51,7 +50,7 @@ std::set<string> kSkipOpNames = { | |||
| // opname, input idx | |||
| std::map<string, uint32_t> kAggregatesOpNames = { | |||
| {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}}; | |||
| {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kBatchNormGradWithAddAndActivation, 0}}; | |||
| constexpr size_t inplace_node_size = 2; | |||
| @@ -28,8 +28,8 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef PostBatchNormAddReluFusion::DefinePattern() const { | |||
| VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); | |||
| VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_}); | |||
| VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_}); | |||
| VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item}); | |||
| VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); | |||
| return relu; | |||
| @@ -44,24 +44,24 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, | |||
| MS_EXCEPTION_IF_NULL(tensor_add); | |||
| auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1); | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm_ex); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format"); | |||
| auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0); | |||
| MS_EXCEPTION_IF_NULL(batch_norm); | |||
| auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(format_attr); | |||
| auto format = GetValue<std::string>(format_attr); | |||
| if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") { | |||
| return nullptr; | |||
| } | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); | |||
| auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0); | |||
| if (shape.back() % kBNChannelMultipleFactor != 0) { | |||
| return nullptr; | |||
| } | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4); | |||
| auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0); | |||
| auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1); | |||
| auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2); | |||
| auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3); | |||
| auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4); | |||
| auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| @@ -71,7 +71,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, | |||
| MS_EXCEPTION_IF_NULL(var); | |||
| MS_EXCEPTION_IF_NULL(z); | |||
| auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation); | |||
| auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z}; | |||
| auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs); | |||
| @@ -79,17 +79,17 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, | |||
| std::vector<TypeId> outputs_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i)); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu); | |||
| manager->Replace(batch_norm, fused_batch_norm_with_add_relu); | |||
| device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu); | |||
| return tuple_get_item; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -445,9 +445,10 @@ size_t BestFitMemReuse::GetAllocatedSize() { | |||
| bool BestFitMemReuse::IsRelease() { | |||
| // unable_used_node contains the node types whose output tensors cannot be released, | |||
| // even if their refcount is zero. | |||
| std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(), | |||
| prim::kPrimFusedBatchNorm->name(), | |||
| prim::kPrimFusedBatchNormGrad->name()}; | |||
| std::unordered_set<std::string> unable_used_node = { | |||
| prim::kPrimBatchNorm->name(), | |||
| prim::kPrimBatchNormGrad->name(), | |||
| }; | |||
| return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end(); | |||
| } | |||
| @@ -494,7 +495,7 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { | |||
| #endif | |||
| for (const auto &op_def_ptr : op_ptr_list_) { | |||
| current_kernel_ = op_def_ptr; | |||
| // releas pre_op_def | |||
| // release pre_op_def | |||
| if (pre_op != nullptr) { | |||
| ReleasePreNodeWorkspace(pre_op.get()); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -450,12 +450,7 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, | |||
| // Experimental support for 3D data (input_size == 3). | |||
| if (input_size >= 1 && input_size <= 4) { | |||
| if (dim == 0) { | |||
| // Currently GPU version does not support partitioning ‘FusedBatchNormEx’ in its param tensors. | |||
| if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) { | |||
| s.push_back(1); | |||
| } else { | |||
| s.push_back(std::min(max_device_num, target_tensor_batch)); | |||
| } | |||
| s.push_back(std::min(max_device_num, target_tensor_batch)); | |||
| } else { | |||
| s.push_back(1); | |||
| } | |||
| @@ -533,8 +528,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector | |||
| return PrepareOneHot(graph, ops, iter_graph, iter_ops); | |||
| } else if ((type == SOFTMAX) || (type == LAYER_NORM)) { | |||
| return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops); | |||
| } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || | |||
| (type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) { | |||
| } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout") || | |||
| (type == BATCH_MATMUL)) { | |||
| return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); | |||
| } else { | |||
| return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -46,7 +46,6 @@ const std::map<std::string, OperatorType> DictOpType{ | |||
| {RESHAPE, OperatorType::kRecReshape}, | |||
| {BIAS_ADD, OperatorType::kRecBiasAdd}, | |||
| {BATCH_NORM, OperatorType::kRecBatchNorm}, | |||
| {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, | |||
| {LAYER_NORM, OperatorType::kRecBatchNorm}, | |||
| {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, | |||
| {ONEHOT, OperatorType::kRecOneHot}, | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -53,9 +53,7 @@ enum MatchCountPriority : int { | |||
| MATCH_COUNT_PRIORITY_END | |||
| }; | |||
| const std::map<std::string, std::vector<std::string>> kNextOpFormatList = { | |||
| {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}, | |||
| {prim::kPrimFusedBatchNorm->name(), | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}}}; | |||
| {prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}}; | |||
| bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -233,6 +231,24 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype( | |||
| return result; | |||
| } | |||
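| // Helper shared by the input and output dtype loops in TagRaiseReduce below: it rejects dtypes that cannot be raised or reduced to the device dtype, and flags int64 -> int32 reduction. | |||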
| bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId &in_dtype, const TypeId &device_dtype, | |||
| bool *flag) { | |||
| auto iter = type_map.find(in_dtype); | |||
| // if the infer dtype is not in type_map and does not equal the kernel info dtype, return false | |||
| if (iter == type_map.end() && in_dtype != device_dtype) { | |||
| return false; | |||
| } | |||
| // the infer dtype is in type_map, but the mapped dst dtype does not match the device dtype, | |||
| // and the infer dtype itself does not equal the kernel info dtype, so return false | |||
| if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { | |||
| return false; | |||
| } | |||
| if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { | |||
| *flag = true; | |||
| } | |||
| return true; | |||
| } | |||
| bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode, | |||
| const std::map<TypeId, TypeId> &type_map) { | |||
| // filter out kernel info that does not support raising or reducing the datatype | |||
| @@ -245,19 +261,9 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build | |||
| if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { | |||
| device_dtype = kNumberTypeFloat32; | |||
| } | |||
| auto iter = type_map.find(in_dtype); | |||
| // if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false | |||
| if (iter == type_map.end() && in_dtype != device_dtype) { | |||
| if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) { | |||
| return false; | |||
| } | |||
| // infer dtype in type_map, but can not find dst dtype that supported raise or reduce, | |||
| // or infer dtype not equal kernel info dtype, return false | |||
| if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { | |||
| return false; | |||
| } | |||
| if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { | |||
| flag = true; | |||
| } | |||
| } | |||
| for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) { | |||
| @@ -266,19 +272,10 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build | |||
| if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { | |||
| device_dtype = kNumberTypeFloat32; | |||
| } | |||
| auto iter = type_map.find(in_dtype); | |||
| // if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false | |||
| if (iter == type_map.end() && in_dtype != device_dtype) { | |||
| return false; | |||
| } | |||
| // infer dtype in type_map, but can not find dst dtype that supported raise or reduce, | |||
| // or infer dtype not equal kernel info dtype, return false | |||
| if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { | |||
| if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) { | |||
| return false; | |||
| } | |||
| if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) { | |||
| flag = true; | |||
| } | |||
| } | |||
| if (flag) { | |||
| auto node_name = AnfAlgo::GetCNodeName(cnode); | |||
| @@ -101,9 +101,9 @@ namespace { | |||
| std::vector<int> CheckRealOutput(const std::string &node_name, const size_t &output_size) { | |||
| // define a vector containing real output number | |||
| std::vector<int> real_outputs; | |||
| // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference | |||
| // P.BatchNorm is used for training and inference | |||
| // can add the filter list for more operators here.... | |||
| if (node_name == "FusedBatchNorm" || node_name == "BatchNorm") { | |||
| if (node_name == "BatchNorm") { | |||
| MS_LOG(INFO) << "loading node named " << node_name; | |||
| real_outputs.insert(real_outputs.end(), {0, 3, 4}); | |||
| } else { | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -374,7 +374,7 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s | |||
| if (kernel_name == prim::kPrimConv2D->name()) { | |||
| conv_cnt++; | |||
| } | |||
| if (kernel_name == prim::kPrimFusedBatchNormEx->name()) { | |||
| if (kernel_name == prim::kPrimBatchNorm->name()) { | |||
| bn_cnt++; | |||
| } | |||
| } | |||
| @@ -46,12 +46,12 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>> | |||
| {prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}}, | |||
| {kAvgPoolOpName, {{0}, {0}}}, | |||
| {kAvgPoolGradOpName, {{0, 1, 2}, {0}}}, | |||
| {kFusedBatchNormEx, {{0}, {0}}}, | |||
| {kFusedBatchNormExWithActivation, {{0}, {0}}}, | |||
| {kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}}, | |||
| {kFusedBatchNormGradEx, {{0, 1}, {0}}}, | |||
| {kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}}, | |||
| {kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | |||
| {kBatchNorm, {{0}, {0}}}, | |||
| {kBatchNormWithActivation, {{0}, {0}}}, | |||
| {kBatchNormWithAddAndActivation, {{0, 5}, {0}}}, | |||
| {kBatchNormGradOpName, {{0, 1}, {0}}}, | |||
| {kBatchNormGradWithActivation, {{0, 1, 7}, {0}}}, | |||
| {kBatchNormGradWithAddAndActivation, {{0, 1, 7}, {0, 3}}}, | |||
| {kBiasAddOpName, {{0}, {0}}}, | |||
| {prim::kPrimBiasAddGrad->name(), {{0}, {}}}, | |||
| // Format insensitive. | |||
| @@ -50,13 +50,12 @@ constexpr auto kFusedBN3OpName = "FusedBN3"; | |||
| constexpr auto kBNGrad1OpName = "BNGrad1"; | |||
| constexpr auto kBNGrad2OpName = "BNGrad2"; | |||
| constexpr auto kBNGrad3OpName = "BNGrad3"; | |||
| constexpr auto kFusedBatchNormEx = "FusedBatchNormEx"; | |||
| constexpr auto kBatchNorm = "BatchNorm"; | |||
| constexpr auto kInstanceNorm = "InstanceNorm"; | |||
| constexpr auto kFusedBatchNormExWithActivation = "FusedBatchNormExWithActivation"; | |||
| constexpr auto kFusedBatchNormExWithAddAndActivation = "FusedBatchNormExWithAddAndActivation"; | |||
| constexpr auto kFusedBatchNormGradEx = "FusedBatchNormGradEx"; | |||
| constexpr auto kFusedBatchNormGradExWithActivation = "FusedBatchNormGradExWithActivation"; | |||
| constexpr auto kFusedBatchNormGradExWithAddAndActivation = "FusedBatchNormGradExWithAddAndActivation"; | |||
| constexpr auto kBatchNormWithActivation = "BatchNormWithActivation"; | |||
| constexpr auto kBatchNormWithAddAndActivation = "BatchNormWithAddAndActivation"; | |||
| constexpr auto kBatchNormGradWithActivation = "BatchNormGradWithActivation"; | |||
| constexpr auto kBatchNormGradWithAddAndActivation = "BatchNormGradWithAddAndActivation"; | |||
| constexpr auto kClearZeroOpName = "ClearZero"; | |||
| constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean"; | |||
| constexpr auto kGetNextOpName = "GetNext"; | |||
| @@ -45,14 +45,10 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -140,77 +140,6 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: five tensors(x, gamma, beta, mean, variance). | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 5); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString() | |||
| << ", arg1:" << args_spec_list[1]->ToString(); | |||
| FusedBatchNormCheckDim(primitive, args_spec_list); | |||
| auto input = args_spec_list[0]; | |||
| auto input_shape = dyn_cast<Shape>(input->GetShapeTrack()); | |||
| MS_EXCEPTION_IF_NULL(input_shape); | |||
| const auto &input_shape_list = input_shape->shape(); | |||
| if (input_shape_list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input shape size should >= 2."; | |||
| } | |||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||
| auto arg_shape = dyn_cast<Shape>(args_spec_list[i]->GetShapeTrack()); | |||
| MS_EXCEPTION_IF_NULL(arg_shape); | |||
| const auto &arg_shape_list = arg_shape->shape(); | |||
| if (arg_shape_list.size() < 1) { | |||
| MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; | |||
| } | |||
| if (arg_shape_list[0] != input_shape_list[1]) { | |||
| MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] | |||
| << ") should match the second dimension of tensor" | |||
| " param[0](which is " | |||
| << input_shape_list[1] << ")."; | |||
| } | |||
| } | |||
| auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s"); | |||
| AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>(); | |||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||
| auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||
| tensorPtrList.push_back(param); | |||
| } | |||
| (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s"); | |||
| // check validity; | |||
| auto epsilon_value = primitive->GetAttr("epsilon"); | |||
| auto momentum_value = primitive->GetAttr("momentum"); | |||
| MS_EXCEPTION_IF_NULL(epsilon_value); | |||
| MS_EXCEPTION_IF_NULL(momentum_value); | |||
| if (!epsilon_value->isa<FP32Imm>() || !momentum_value->isa<FP32Imm>()) { | |||
| MS_LOG(EXCEPTION) << "expect epsilon and momentum be float, but: epsilon: " << epsilon_value->ToString() | |||
| << ", momentum: " << momentum_value->ToString(); | |||
| } | |||
| auto epsilon = epsilon_value->cast<FP32ImmPtr>()->value(); | |||
| auto momentum = momentum_value->cast<FP32ImmPtr>()->value(); | |||
| if (epsilon > 1.0f || epsilon <= 0.0f) { | |||
| MS_LOG(EXCEPTION) << "expect epsilon is greater than 0 and less or equal than 1, but epsilon: " << epsilon; | |||
| } | |||
| if (momentum > 1.0f || momentum < 0.0f) { | |||
| MS_LOG(EXCEPTION) << "expect momentum is great or equal than 0 and less or equal than 1, but epsilon: " << momentum; | |||
| } | |||
| // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance. | |||
| AbstractBasePtr y = input->Broaden(); | |||
| AbstractBasePtr other = args_spec_list[1]->Broaden(); | |||
| MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString(); | |||
| AbstractBasePtrList elements = {y, other, other, other, other}; | |||
| return std::make_shared<AbstractTuple>(elements); | |||
| } | |||
| AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| @@ -228,24 +157,8 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin | |||
| return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape); | |||
| } | |||
| AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[2]); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[3]); | |||
| CheckArgsSize(primitive->name(), args_spec_list, 5); | |||
| auto dx = args_spec_list[1]->Broaden(); | |||
| auto dscale = args_spec_list[2]->Broaden(); | |||
| auto dbias = args_spec_list[3]->Broaden(); | |||
| AbstractBasePtrList rets = {dx, dscale, dbias}; | |||
| return std::make_shared<AbstractTuple>(rets); | |||
| } | |||
| AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: five tensors(x, gamma, beta, mean, variance). | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 5); | |||
| @@ -256,12 +169,21 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi | |||
| ShapeVector x_min_shape = input_x->shape()->min_shape(); | |||
| ShapeVector x_max_shape = input_x->shape()->max_shape(); | |||
| CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); | |||
| if (x_shape.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Input rank should 4."; | |||
| auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be"); | |||
| AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>(); | |||
| for (size_t i = 1; i < args_spec_list.size(); ++i) { | |||
| auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||
| tensorPtrList.push_back(param); | |||
| } | |||
| (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, | |||
| "param gamma, beta, mean, variance of Batchnorm should be"); | |||
| auto data_format_ptr = primitive->GetAttr("format"); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format = GetAndCheckFormat(data_format_ptr); | |||
| int64_t c_axis = 1; | |||
| if (data_format == Format::NHWC) { | |||
| c_axis = 3; | |||
| @@ -275,8 +197,8 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi | |||
| MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size(); | |||
| } | |||
| if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) { | |||
| MS_LOG(EXCEPTION) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis << "]=" << x_shape[c_axis] | |||
| << ", but got " << arg_shape[0]; | |||
| MS_EXCEPTION(ValueError) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis | |||
| << "]=" << x_shape[c_axis] << ", but got " << arg_shape[0]; | |||
| } | |||
| } | |||
| AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| @@ -288,7 +210,7 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi | |||
| AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr); | |||
| ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape); | |||
| AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr); | |||
| AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma, output_gamma}; | |||
| AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma}; | |||
| return std::make_shared<AbstractTuple>(rets); | |||
| } | |||
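
For reference, here is a minimal Python-side sketch of the unified five-output contract that the infer function above implements. The shapes, parameter names, and the PyNative call are illustrative assumptions, not part of the change itself.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

# BatchNorm now yields the same five outputs on every backend:
# y plus four channel-sized statistics tensors used by the backward pass.
bn = P.BatchNorm(is_training=True, epsilon=1e-5, momentum=0.1)
x = Tensor(np.ones([8, 4, 16, 16]), ms.float32)
scale = Parameter(Tensor(np.ones([4]), ms.float32), name="scale")
bias = Parameter(Tensor(np.zeros([4]), ms.float32), name="bias")
mean = Parameter(Tensor(np.zeros([4]), ms.float32), name="mean")
variance = Parameter(Tensor(np.ones([4]), ms.float32), name="variance")

outputs = bn(x, scale, bias, mean, variance)
assert len(outputs) == 5  # y, plus four (C,)-shaped statistics
```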
| @@ -117,9 +117,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| // NN | |||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||
| {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, | |||
| {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, | |||
| {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}}, | |||
| {prim::kPrimBatchNorm, {InferImplBatchNorm, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | |||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, | |||
| @@ -219,13 +219,10 @@ inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoo | |||
| inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | |||
| inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam"); | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | |||
| inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D"); | |||
| inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection"); | |||
| inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose"); | |||
| inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx"); | |||
| inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm"); | |||
| inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad"); | |||
| inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared<Primitive>("SyncBatchNorm"); | |||
| @@ -130,8 +130,6 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC | |||
| {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap}, | |||
| {"BatchNorm", DataFormatMap}, | |||
| {"BatchNormGrad", DataFormatMap}, | |||
| {"FusedBatchNormEx", DataFormatMap}, | |||
| {"FusedBatchNormGradEx", DataFormatMap}, | |||
| {"BiasAdd", DataFormatMap}, | |||
| {"BiasAddGrad", DataFormatMap}, | |||
| {"BinaryCrossEntropy", ReductionMap}, | |||
| @@ -140,25 +140,16 @@ class _BatchNorm(Cell): | |||
| else: | |||
| self.is_ge_backend = False | |||
| if self._target == "Ascend": | |||
| self.bn_train = P.BatchNorm(is_training=True, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| data_format=self.format) | |||
| if self._target == "GPU": | |||
| self.bn_train = P.FusedBatchNormEx(mode=1, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| data_format=self.format) | |||
| if self._target == "CPU": | |||
| self.bn_train = P.FusedBatchNorm(mode=1, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum) | |||
| self.bn_train = P.BatchNorm(is_training=True, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| data_format=self.format) | |||
| if self.is_global: | |||
| self.bn_train = inner.SyncBatchNorm(epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| group=SYNC_BN_GROUP_NAME, | |||
| device_num=self.group_device_num) | |||
| self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) | |||
| data_parallel_strategy = ((1,), (1,)) | |||
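
With the per-backend branches collapsed, the nn-layer behaviour is the same on Ascend, GPU and CPU. A small usage sketch follows; the device setup and shapes are assumptions for illustration only.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# nn.BatchNorm2d builds on the _BatchNorm cell above, which now creates a
# single P.BatchNorm(is_training=True, ...) regardless of the target device.
net = nn.BatchNorm2d(num_features=4, eps=1e-5, momentum=0.9)
net.set_train(True)

x = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
y = net(x)
print(y.shape)  # (2, 4, 8, 8)
```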
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -541,19 +541,9 @@ class Conv2dBnFoldQuantOneConv(Cell): | |||
| channel_axis=channel_axis, | |||
| num_channels=out_channels, | |||
| quant_dtype=quant_dtype) | |||
| if self._target == "Ascend": | |||
| self.bn_train = P.BatchNorm(is_training=True, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum) | |||
| if self._target == "GPU": | |||
| self.bn_train = P.FusedBatchNormEx(mode=1, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum, | |||
| data_format=self.format) | |||
| if self._target == "CPU": | |||
| self.bn_train = P.FusedBatchNorm(mode=1, | |||
| epsilon=self.eps, | |||
| momentum=self.momentum) | |||
| self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, | |||
| momentum=self.momentum, data_format=self.format) | |||
| self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) | |||
| data_parallel_strategy = ((1,), (1,)) | |||
| data_parallel_strategy_one = ((1,), ()) | |||
| @@ -647,49 +647,6 @@ def get_bprop_fast_gelu_2(self): | |||
| return bprop | |||
| @bprop_getters.register(P.FusedBatchNorm) | |||
| def get_bprop_fused_batch_norm(self): | |||
| """Grad definition for `FusedBatchNorm` operation.""" | |||
| input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum) | |||
| target_cpu = False | |||
| if self.target == "CPU": | |||
| input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum) | |||
| target_cpu = True | |||
| def bprop(x, scale, b, mean, variance, out, dout): | |||
| saved_mean = out[3] | |||
| saved_variance = out[4] | |||
| if target_cpu: | |||
| out = input_grad(dout[0], x, scale, b, saved_mean, saved_variance) | |||
| else: | |||
| out = input_grad(dout[0], x, scale, saved_mean, saved_variance) | |||
| dx = out[0] | |||
| dscale = out[1] | |||
| dbias = out[2] | |||
| return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) | |||
| return bprop | |||
| @bprop_getters.register(P.FusedBatchNormEx) | |||
| def get_bprop_fused_batch_norm_ex(self): | |||
| """Grad definition for `FusedBatchNormEx` operation.""" | |||
| input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format) | |||
| def bprop(x, scale, b, mean, variance, out, dout): | |||
| saved_mean = out[3] | |||
| saved_variance = out[4] | |||
| reserve = out[5] | |||
| out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve) | |||
| dx = out[0] | |||
| dscale = out[1] | |||
| dbias = out[2] | |||
| return dx, dscale, dbias, zeros_like(mean), zeros_like(variance) | |||
| return bprop | |||
| @bprop_getters.register(P.InstanceNorm) | |||
| def get_bprop_instance_norm(self): | |||
| """Grad definition for `InstanceNorm` operation.""" | |||
| @@ -715,12 +672,14 @@ def get_bprop_batch_norm(self): | |||
| def bprop(x, scale, b, mean, variance, out, dout): | |||
| if is_training: | |||
| saved_reserve_1 = out[3] | |||
| saved_reserve_2 = out[4] | |||
| saved_mean = out[3] | |||
| saved_variance = out[4] | |||
| reserve = out[2] | |||
| else: | |||
| saved_reserve_1 = mean | |||
| saved_reserve_2 = variance | |||
| out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2) | |||
| saved_mean = mean | |||
| saved_variance = variance | |||
| reserve = out[2] | |||
| out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve) | |||
| dx = out[0] | |||
| dscale = out[1] | |||
| dbias = out[2] | |||
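
Since the FusedBatchNorm/FusedBatchNormEx bprops above are removed, gradients now flow through the single registered BatchNorm bprop. A hedged sketch of differentiating through the layer (shapes assumed, gradients taken only with respect to the input):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C

net = nn.BatchNorm2d(4)
net.set_train(True)

grad_all = C.GradOperation(get_all=True)
x = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
dx = grad_all(net)(x)[0]  # gradient w.r.t. the input, same shape as x
print(dx.shape)  # (2, 4, 8, 8)
```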
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -48,12 +48,6 @@ class BiasAdd: | |||
| pass | |||
| @op_selector | |||
| class FusedBatchNorm: | |||
| def __call__(self, *args): | |||
| pass | |||
| @op_selector | |||
| class ApplyMomentum: | |||
| def __call__(self, *args): | |||
| @@ -65,9 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam | |||
| BiasAdd, Conv2D, | |||
| DepthwiseConv2dNative, | |||
| DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten, | |||
| FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate, | |||
| InstanceNorm, BNTrainingReduce, BNTrainingUpdate, | |||
| GeLU, Gelu, FastGeLU, FastGelu, Elu, | |||
| GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, | |||
| LogSoftmax, MaxPool3D, | |||
| MaxPool, DataFormatDimMap, | |||
| @@ -142,8 +141,6 @@ __all__ = [ | |||
| 'Conv2D', | |||
| 'Flatten', | |||
| 'MaxPoolWithArgmax', | |||
| 'FusedBatchNorm', | |||
| 'FusedBatchNormEx', | |||
| 'BNTrainingReduce', | |||
| 'BNTrainingUpdate', | |||
| 'BatchNorm', | |||
| @@ -197,12 +197,12 @@ class BatchNormGrad(PrimitiveWithInfer): | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name) | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape): | |||
| validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) | |||
| return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) | |||
| return (x_shape, scale_shape, scale_shape) | |||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): | |||
| return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) | |||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type): | |||
| return (x_type, scale_type, scale_type) | |||
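
To make the new contract concrete, here is an illustrative direct call of the infer method above. BatchNormGrad is normally inserted by autodiff rather than called by hand, so this is only a sketch and the argument values are assumptions.

```python
from mindspore.ops.operations import _grad_ops as G

bn_grad = G.BatchNormGrad()

x_shape = [128, 64, 32, 32]
c_shape = [64]
# Six inputs: dy, x, scale, save_mean, save_variance, reserve.
out_shapes = bn_grad.infer_shape(x_shape, x_shape, c_shape, c_shape, c_shape, c_shape)
# Three outputs: dx, dscale, dbias.
assert out_shapes == (x_shape, c_shape, c_shape)
```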
| class SyncBatchNormGrad(PrimitiveWithInfer): | |||
| @@ -708,53 +708,6 @@ class FlattenGrad(PrimitiveWithInfer): | |||
| return out | |||
| class FusedBatchNormGrad(Primitive): | |||
| """Gradients of FusedBatchNorm operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=0.0, momentum=0.1): | |||
| self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'], | |||
| outputs=['dx', 'bn_scale', 'bn_bias']) | |||
| def __call__(self, dy, x, scale, save_mean, save_inv_variance): | |||
| raise NotImplementedError | |||
| class FusedBatchNormGradCPU(PrimitiveWithInfer): | |||
| """Gradients of FusedBatchNorm operation for CPU.""" | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=0.0, momentum=0.1): | |||
| self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'bias', 'save_mean', 'save_inv_variance'], | |||
| outputs=['dx', 'bn_scale', 'bn_bias']) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| def infer_shape(self, dy_shape, x_shape, scale_shape, bias_shape, save_mean_shape, save_inv_variance_shape): | |||
| return (x_shape, scale_shape, bias_shape) | |||
| def infer_dtype(self, dy_type, x_type, scale_type, bias_type, save_mean_type, save_inv_variance_type): | |||
| return (x_type, scale_type, bias_type) | |||
| class FusedBatchNormGradEx(PrimitiveWithInfer): | |||
| """Gradients of FusedBatchNormEx operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, epsilon=0.0, momentum=0.1, data_format="NCHW"): | |||
| self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'], | |||
| outputs=['dx', 'bn_scale', 'bn_bias']) | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| self.add_prim_attr('data_format', self.format) | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape): | |||
| return (x_shape, scale_shape, scale_shape) | |||
| def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type): | |||
| return (x_type, scale_type, scale_type) | |||
| class InstanceNormGrad(PrimitiveWithInfer): | |||
| """Gradients of InstanceNorm operation.""" | |||
| @@ -817,221 +817,20 @@ class Tanh(PrimitiveWithInfer): | |||
| class FusedBatchNorm(Primitive): | |||
| r""" | |||
| FusedBatchNorm is a BatchNorm. Moving mean and moving variance will be computed instead of being loaded. | |||
| Batch Normalization is widely used in convolutional networks. This operation applies | |||
| Batch Normalization over input to avoid internal covariate shift as described in the | |||
| paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal | |||
| Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the | |||
| feature using a mini-batch of data and the learned parameters which can be described | |||
| in the following formula. | |||
| .. math:: | |||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | |||
| Args: | |||
| mode (int): Mode of batch normalization, value is 0 or 1. Default: 0. | |||
| epsilon (float): A small value added for numerical stability. Default: 1e-5. | |||
| momentum (float): The hyper parameter to compute moving average for running_mean and running_var | |||
| (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). | |||
| Momentum value must be [0, 1]. Default: 0.1. | |||
| Inputs: | |||
| - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`. | |||
| - **scale** (Parameter) - Tensor of shape :math:`(C,)`. | |||
| - **bias** (Parameter) - Tensor of shape :math:`(C,)`. | |||
| - **mean** (Parameter) - Tensor of shape :math:`(C,)`. | |||
| - **variance** (Parameter) - Tensor of shape :math:`(C,)`. | |||
| Outputs: | |||
| Tuple of 5 Tensor, the normalized input and the updated parameters. | |||
| - **output_x** (Tensor) - The same type and shape as the `input_x`. | |||
| - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. | |||
| Raises: | |||
| TypeError: If `mode` is not an int. | |||
| TypeError: If `epsilon` or `momentum` is not a float. | |||
| TypeError: If `output_x`, `updated_scale`, `updated_bias`, `updated_moving_mean` or `updated_moving_variance` is | |||
| a Tensor. | |||
| Supported Platforms: | |||
| ``CPU`` | |||
| Examples: | |||
| >>> import mindspore | |||
| >>> import mindspore.nn as nn | |||
| >>> import numpy as np | |||
| >>> from mindspore import Parameter | |||
| >>> from mindspore import Tensor | |||
| >>> from mindspore.ops import operations as ops | |||
| >>> class FusedBatchNormNet(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(FusedBatchNormNet, self).__init__() | |||
| >>> self.fused_batch_norm = ops.FusedBatchNorm() | |||
| >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") | |||
| >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") | |||
| >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") | |||
| >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") | |||
| >>> | |||
| >>> def construct(self, input_x): | |||
| >>> out = self.fused_batch_norm(input_x, self.scale, self.bias, self.mean, self.variance) | |||
| >>> return out | |||
| >>> | |||
| >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | |||
| >>> net = FusedBatchNormNet() | |||
| >>> output = net(input_x) | |||
| >>> result = output[0].shape | |||
| >>> print(result) | |||
| (128, 64, 32, 64) | |||
| The FusedBatchNorm interface is deprecated, please use the BatchNorm interface. | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T1), | |||
| sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) | |||
| self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| self.target = context.get_context("device_target") | |||
| raise TypeError("The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.") | |||
| class FusedBatchNormEx(PrimitiveWithCheck): | |||
| r""" | |||
| FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve) | |||
| than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that | |||
| moving mean and moving variance will be computed instead of being loaded. FusedBatchNormEx currently only | |||
| supports 4D inputs. | |||
| Batch Normalization is widely used in convolutional networks. This operation applies | |||
| Batch Normalization over input to avoid internal covariate shift as described in the | |||
| paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal | |||
| Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the | |||
| feature using a mini-batch of data and the learned parameters which can be described | |||
| in the following formula. | |||
| .. math:: | |||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | |||
| Args: | |||
| mode (int): Mode of batch normalization, value is 0 or 1. Default: 0. | |||
| epsilon (float): A small value added for numerical stability. Default: 1e-5. | |||
| momentum (float): The hyper parameter to compute moving average for running_mean and running_var | |||
| (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). | |||
| Momentum value must be [0, 1]. Default: 0.1. | |||
| data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. | |||
| Default: "NCHW". | |||
| Inputs: | |||
| - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`, | |||
| data type: float16 or float32. | |||
| - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, | |||
| data type: float32. | |||
| - **bias** (Parameter) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`, | |||
| data type: float32. | |||
| - **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32. | |||
| - **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32. | |||
| Outputs: | |||
| Tuple of 6 Tensors, the normalized input, the updated parameters and reserve. | |||
| - **output_x** (Tensor) - The output of FusedBatchNormEx, same type and shape as the `input_x`. | |||
| - **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32. | |||
| - **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32. | |||
| - **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32. | |||
| - **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(C,)`, | |||
| data type: float32. | |||
| - **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32. | |||
| Raises: | |||
| TypeError: If `mode` is not an int. | |||
| TypeError: If neither `epsilon` nor `momentum` is a float. | |||
| TypeError: If `data_format` is not a str. | |||
| TypeError: If `input_x` is not a Tensor. | |||
| TypeError: If dtype of `scale`, `bias`, `mean` or `variance` is not float32. | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| Examples: | |||
| >>> import mindspore | |||
| >>> import mindspore.nn as nn | |||
| >>> import numpy as np | |||
| >>> from mindspore import Parameter | |||
| >>> from mindspore import Tensor | |||
| >>> from mindspore.ops import operations as ops | |||
| >>> class FusedBatchNormExNet(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(FusedBatchNormExNet, self).__init__() | |||
| >>> self.fused_batch_norm_ex = ops.FusedBatchNormEx() | |||
| >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") | |||
| >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") | |||
| >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") | |||
| >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") | |||
| >>> | |||
| >>> def construct(self, input_x): | |||
| >>> out = self.fused_batch_norm_ex(input_x, self.scale, self.bias, self.mean, self.variance) | |||
| >>> return out | |||
| >>> | |||
| >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | |||
| >>> net = FusedBatchNormExNet() | |||
| >>> output = net(input_x) | |||
| >>> result = output[0].shape | |||
| >>> print(result) | |||
| (128, 64, 32, 64) | |||
| The FusedBatchNormEx interface is deprecated, please use the BatchNorm interface. | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T1), | |||
| sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self, mode=0, epsilon=1e-5, momentum=0.1, data_format="NCHW"): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) | |||
| self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| self.add_prim_attr('data_format', self.format) | |||
| def check_shape(self, input_x, scale, bias, mean, variance): | |||
| input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) | |||
| validator.check_equal_int(len(input_shape_norm), 4, "x rank", self.name) | |||
| validator.check_equal_int(len(scale), 1, "scale rank", self.name) | |||
| validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) | |||
| validator.check_equal_int(len(mean), 1, "mean rank", self.name) | |||
| validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) | |||
| validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) | |||
| def check_dtype(self, input_x, scale, bias, mean, variance): | |||
| validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) | |||
| args = {"scale": scale, "bias": bias} | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) | |||
| args_moving = {"mean": mean, "variance": variance} | |||
| valid_dtypes = [mstype.tensor_type(mstype.float32)] | |||
| validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name) | |||
| raise TypeError("FusedBatchnormEx interface is deprecated, please use BatchNorm interface.") | |||
| class InstanceNorm(PrimitiveWithInfer): | |||
| @@ -1420,7 +1219,7 @@ class BatchNorm(PrimitiveWithInfer): | |||
| else: | |||
| args_moving = {"mean": mean, "variance": variance} | |||
| validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) | |||
| return (input_x, scale, bias, input_x, input_x) | |||
| return (input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32) | |||
| class Conv2D(PrimitiveWithCheck): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -41,7 +41,7 @@ class Grad(nn.Cell): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.bn = P.FusedBatchNorm() | |||
| self.bn = P.BatchNorm() | |||
| self.scale = Parameter(initializer('ones', [64]), name='scale') | |||
| self.b = Parameter(initializer('zeros', [64]), name='b') | |||
| self.mean = Parameter(initializer('ones', [64]), name='mean') | |||
| @@ -1,128 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.nn import Cell | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| from mindspore.ops import operations as P | |||
| class NetFusedBatchNormEx(Cell): | |||
| def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None): | |||
| super(NetFusedBatchNormEx, self).__init__() | |||
| self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1) | |||
| self.moving_mean = Parameter(initializer( | |||
| mean_init, num_features), name="mean", requires_grad=False) | |||
| self.moving_variance = Parameter(initializer( | |||
| var_init, num_features), name="variance", requires_grad=False) | |||
| self.gamma = Parameter(initializer( | |||
| gamma_init, num_features), name="gamma", requires_grad=True) | |||
| self.beta = Parameter(initializer( | |||
| beta_init, num_features), name="beta", requires_grad=True) | |||
| self.dynshape = inner.GpuConvertToDynamicShape() | |||
| def construct(self, x): | |||
| x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance) | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fused_bn_ex(): | |||
| x = np.array([[ | |||
| [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], | |||
| [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) | |||
| expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], | |||
| [-0.1471, 0.7706, 1.6882, 2.6059], | |||
| [0.3118, 1.6882, 2.1471, 2.1471], | |||
| [0.7706, 0.3118, 2.6059, -0.1471]], | |||
| [[0.9119, 1.8518, 1.3819, -0.0281], | |||
| [-0.0281, 0.9119, 1.3819, 1.8518], | |||
| [2.7918, 0.4419, -0.4981, 0.9119], | |||
| [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) | |||
| weight = np.ones(2).astype(np.float32) | |||
| bias = np.ones(2).astype(np.float32) | |||
| moving_mean = np.ones(2).astype(np.float32) | |||
| moving_var = np.ones(2).astype(np.float32) | |||
| error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| bn_net = NetFusedBatchNormEx(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var)) | |||
| output_list = bn_net(Tensor(x)) | |||
| output = output_list[0] | |||
| diff = output.asnumpy() - expect_output | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| class NetFusedBatchNormExDynamic(Cell): | |||
| def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None): | |||
| super(NetFusedBatchNormExDynamic, self).__init__() | |||
| self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1) | |||
| self.moving_mean = Parameter(initializer( | |||
| mean_init, num_features), name="mean", requires_grad=False) | |||
| self.moving_variance = Parameter(initializer( | |||
| var_init, num_features), name="variance", requires_grad=False) | |||
| self.gamma = Parameter(initializer( | |||
| gamma_init, num_features), name="gamma", requires_grad=True) | |||
| self.beta = Parameter(initializer( | |||
| beta_init, num_features), name="beta", requires_grad=True) | |||
| self.dynshape = inner.GpuConvertToDynamicShape() | |||
| def construct(self, x): | |||
| x = self.dynshape(x) | |||
| x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance) | |||
| return x | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fused_bn_ex_dynamic(): | |||
| x = np.array([[ | |||
| [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]], | |||
| [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32) | |||
| expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294], | |||
| [-0.1471, 0.7706, 1.6882, 2.6059], | |||
| [0.3118, 1.6882, 2.1471, 2.1471], | |||
| [0.7706, 0.3118, 2.6059, -0.1471]], | |||
| [[0.9119, 1.8518, 1.3819, -0.0281], | |||
| [-0.0281, 0.9119, 1.3819, 1.8518], | |||
| [2.7918, 0.4419, -0.4981, 0.9119], | |||
| [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32) | |||
| weight = np.ones(2).astype(np.float32) | |||
| bias = np.ones(2).astype(np.float32) | |||
| moving_mean = np.ones(2).astype(np.float32) | |||
| moving_var = np.ones(2).astype(np.float32) | |||
| error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| bn_net = NetFusedBatchNormExDynamic(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var)) | |||
| output_list = bn_net(Tensor(x)) | |||
| output = output_list[0] | |||
| diff = output.asnumpy() - expect_output | |||
| assert np.all(diff < error) | |||
| assert np.all(-diff < error) | |||
| @@ -414,25 +414,6 @@ TEST_F(TestOps, ReluTest) { | |||
| ASSERT_EQ(prim->name(), kPrimRelu->name()); | |||
| } | |||
| TEST_F(TestOps, FusedBatchNormTest) { | |||
| auto prim = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| ASSERT_EQ(prim->name(), kPrimFusedBatchNorm->name()); | |||
| } | |||
| TEST_F(TestOps, FusedBatchNormAttrTest) { | |||
| Primitive prim("FusedBatchNorm"); | |||
| prim.SetAttrs({ | |||
| {"epsilon", MakeValue(0.001f)}, | |||
| {"momentum", MakeValue(0.1f)}, | |||
| }); | |||
| ASSERT_EQ(prim.name(), kPrimFusedBatchNorm->name()); | |||
| FP32Imm epsilon(0.001f); | |||
| FP32Imm momentum(0.1f); | |||
| ASSERT_EQ(*prim.GetAttr("epsilon"), epsilon); | |||
| ASSERT_EQ(*prim.GetAttr("momentum"), momentum); | |||
| } | |||
| TEST_F(TestOps, PoolingTest) { | |||
| auto prim = std::make_shared<Primitive>("Pooling"); | |||
| ASSERT_EQ(prim->name(), kPrimPooling->name()); | |||
| @@ -612,65 +612,6 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) { | |||
| ASSERT_TRUE(*res == *expected); | |||
| } | |||
| TEST_F(TestPrim, test_fused_batch_norm) { | |||
| PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm; | |||
| fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f)); | |||
| fused_batch_norm->AddAttr("momentum", MakeValue(0.1f)); | |||
| FuncGraphPtr func_graph = MakeFuncGraph(fused_batch_norm, 5); | |||
| // NCHW | |||
| std::vector<int64_t> inputs_dims = {128, 64, 32, 64}; | |||
| std::vector<int64_t> scale_dims = {64}; | |||
| std::vector<int64_t> offset_dims = {64}; | |||
| std::vector<int64_t> mean_dims = {64}; | |||
| std::vector<int64_t> variance_dims = {64}; | |||
| tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>(); | |||
| inputs->set_data_type(kNumberTypeFloat32); | |||
| inputs->set_shape(inputs_dims); | |||
| tensor::TensorPtr scale = std::make_shared<tensor::Tensor>(); | |||
| scale->set_data_type(kNumberTypeFloat32); | |||
| scale->set_shape(scale_dims); | |||
| tensor::TensorPtr offset = std::make_shared<tensor::Tensor>(); | |||
| offset->set_data_type(kNumberTypeFloat32); | |||
| offset->set_shape(offset_dims); | |||
| tensor::TensorPtr mean = std::make_shared<tensor::Tensor>(); | |||
| mean->set_data_type(kNumberTypeFloat32); | |||
| mean->set_shape(mean_dims); | |||
| tensor::TensorPtr variance = std::make_shared<tensor::Tensor>(); | |||
| variance->set_data_type(kNumberTypeFloat32); | |||
| variance->set_shape(variance_dims); | |||
| AbstractBasePtr abstract_inputs = FromValue(inputs, true); | |||
| AbstractBasePtr abstract_scale = FromValue(scale, true); | |||
| AbstractBasePtr abstract_offset = FromValue(offset, true); | |||
| AbstractBasePtr abstract_mean = FromValue(mean, true); | |||
| AbstractBasePtr abstract_variance = FromValue(variance, true); | |||
| AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_scale, abstract_offset, abstract_mean, | |||
| abstract_variance}; | |||
| AbstractBasePtr expected0 = abstract_inputs->Clone(); | |||
| AbstractBasePtr expected1 = abstract_scale->Clone(); | |||
| AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); | |||
| MS_LOG(INFO) << "result: " << res->ToString(); | |||
| MS_LOG(INFO) << "expected0: " << expected0->ToString(); | |||
| MS_LOG(INFO) << "expected1: " << expected1->ToString(); | |||
| std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res); | |||
| ASSERT_TRUE(abs_tuple != nullptr); | |||
| ASSERT_TRUE(*abs_tuple->elements()[0] == *expected0); | |||
| ASSERT_TRUE(*abs_tuple->elements()[1] == *expected1); | |||
| ASSERT_TRUE(*abs_tuple->elements()[2] == *expected1); | |||
| ASSERT_TRUE(*abs_tuple->elements()[3] == *expected1); | |||
| ASSERT_TRUE(*abs_tuple->elements()[4] == *expected1); | |||
| } | |||
| TEST_F(TestPrim, test_pooling) { | |||
| PrimitivePtr pooling = prim::kPrimPooling; | |||
| pooling->AddAttr("mode", MakeValue(std::string("avg"))); | |||
| @@ -35,7 +35,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) { | |||
| std::vector<int64_t> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| for (size_t i = 0; i < 6; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| @@ -56,7 +56,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1) | |||
| std::vector<int64_t> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| for (size_t i = 0; i < 6; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| @@ -75,7 +75,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2) | |||
| std::vector<int64_t> shp_x{32, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list; | |||
| for (size_t i = 0; i < 5; ++i) { | |||
| for (size_t i = 0; i < 6; ++i) { | |||
| args_spec_list.push_back(x_abstract); | |||
| } | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| @@ -47,7 +47,7 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) { | |||
| std::vector<int64_t> shp_b{64}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract}; | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract, b_abstract}; | |||
| auto kernel_graph = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kernel_graph, nullptr); | |||
| @@ -80,13 +80,17 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) { | |||
| // set kernel for BNGrad | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; | |||
| builder1.SetInputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_NC1HWC0}); | |||
| builder1.SetOutputsFormat( | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, | |||
| kOpFormat_NC1HWC0}); | |||
| builder1.SetInputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, | |||
| kNumberTypeFloat32}); | |||
| builder1.SetOutputsDeviceType( | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32}); | |||
| {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, | |||
| kNumberTypeFloat32}); | |||
| builder1.SetKernelType(TBE_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get()); | |||
| // do bn_grad_split pass | |||
| @@ -37,7 +37,7 @@ TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_fusion) { | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| std::vector<int64_t> shp_y{64}; | |||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract}; | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract, y_abstract}; | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| @@ -57,7 +57,7 @@ TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_no_fusion) { | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| std::vector<int64_t> shp_y{64}; | |||
| auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract}; | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract, y_abstract}; | |||
| auto fg = GetKernelGraph(g, args_spec_list); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*fg); | |||
| @@ -39,28 +39,28 @@ def test_batch_norm_grad_infer_fission(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4): | |||
| batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4) | |||
| def before(input0, input1, input2, input3, input4, input5): | |||
| batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4, input5) | |||
| outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_is_training(input0, input1, input2, input3, input4): | |||
| batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4) | |||
| def before_is_training(input0, input1, input2, input3, input4, input5): | |||
| batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4, input5) | |||
| outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def before_output3_not_null(input0, input1, input2, input3, input4): | |||
| batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4) | |||
| outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3)) | |||
| def before_output3_not_null(input0, input1, input2, input3, input4, input5): | |||
| batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4, input5) | |||
| outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2)) | |||
| output = tuple_getitem(outputs, 0) | |||
| return output | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4): | |||
| def after(input0, input1, input2, input3, input4, input5): | |||
| bn_infer_grad = BNInferGrad(input0, input2, input4) | |||
| bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4) | |||
| outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0), | |||
| @@ -38,19 +38,19 @@ def test_batchnormgrad_to_bninfergrad(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(input0, input1, input2, input3, input4): | |||
| res = batch_norm_grad(input0, input1, input2, input3, input4) | |||
| def before(input0, input1, input2, input3, input4, input5): | |||
| res = batch_norm_grad(input0, input1, input2, input3, input4, input5) | |||
| res = tuple_getitem(res, 0) | |||
| return res | |||
| @fns | |||
| def after(input0, input1, input2, input3, input4): | |||
| def after(input0, input1, input2, input3, input4, input5): | |||
| res = bn_infer_grad(input0, input2, input4) | |||
| return make_tuple(res) | |||
| @fns | |||
| def no_fusion(input0, input1, input2, input3, input4): | |||
| res = batch_norm_grad(input0, input1, input2, input3, input4) | |||
| def no_fusion(input0, input1, input2, input3, input4, input5): | |||
| res = batch_norm_grad(input0, input1, input2, input3, input4, input5) | |||
| item0 = tuple_getitem(res, 0) | |||
| item1 = tuple_getitem(res, 1) | |||
| item2 = tuple_getitem(res, 2) | |||
| @@ -49,8 +49,8 @@ def test_bn_grad_split(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def before(i0, i1, i2, i3, i4): | |||
| bn_grad_output = bn_grad(i0, i1, i2, i3, i4) | |||
| def before(i0, i1, i2, i3, i4, i5): | |||
| bn_grad_output = bn_grad(i0, i1, i2, i3, i4, i5) | |||
| item0 = tuple_getitem(bn_grad_output, 0) | |||
| item1 = tuple_getitem(bn_grad_output, 1) | |||
| item2 = tuple_getitem(bn_grad_output, 2) | |||
| @@ -58,7 +58,7 @@ def test_bn_grad_split(tag): | |||
| return output | |||
| @fns | |||
| def after1(i0, i1, i2, i3, i4): | |||
| def after1(i0, i1, i2, i3, i4, i5): | |||
| bn_grad1_output = bn_grad1(i0, i1, i3) | |||
| bn_grad1_item0 = tuple_getitem(bn_grad1_output, 0) | |||
| bn_grad1_item1 = tuple_getitem(bn_grad1_output, 1) | |||
| @@ -78,7 +78,7 @@ def test_bn_grad_split(tag): | |||
| return make_tuple(output) | |||
| @fns | |||
| def after2(i0, i1, i2, i3, i4): | |||
| def after2(i0, i1, i2, i3, i4, i5): | |||
| bn_update_grad_output = bn_training_update_grad(i0, i1, i3, i4) | |||
| update_item0 = tuple_getitem(bn_update_grad_output, 0) | |||
| update_item1 = tuple_getitem(bn_update_grad_output, 1) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -27,8 +27,6 @@ make_tuple = Primitive('MakeTuple') | |||
| four2five = Primitive('Four2Five') | |||
| five2four = Primitive('Five2Four') | |||
| cast = Primitive('Cast') | |||
| conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1) | |||
| bn = P.FusedBatchNorm() | |||
| relu = P.ReLU() | |||
| @@ -140,25 +138,6 @@ def test_eliminate_depend_input2(tag): | |||
| return fns[tag] | |||
| def test_opt_match(tag): | |||
| fns = FnDict() | |||
| @fns | |||
| def graph1(x, y): | |||
| sum_add = add(x, y) | |||
| output = make_tuple(sum_add) | |||
| return output | |||
| @fns | |||
| def graph2(x, w, scale, b, mean, variance): | |||
| conv_output = conv(x, w) | |||
| bn_output = bn(conv_output, scale, b, mean, variance) | |||
| res = tuple_getitem(bn_output, 0) | |||
| return res | |||
| return fns[tag] | |||
| def test_func_graph_cse(tag): | |||
| """ test_func_graph_cse """ | |||
| fns = FnDict() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops import _constants as Constants | |||
| # pylint: disable=unused-variable | |||
| @@ -25,7 +24,6 @@ allreduce = P.AllReduce() | |||
| allreduce.add_prim_attr('fusion', 1) | |||
| make_tuple = Primitive("MakeTuple") | |||
| conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1) | |||
| bn = P.FusedBatchNorm() | |||
| relu = P.ReLU() | |||
| conv_bn1 = Primitive('ConvBN1') | |||
| bn2_add_relu = Primitive('BN2AddRelu') | |||
| @@ -33,7 +31,6 @@ bn2_relu = Primitive('BN2Relu') | |||
| fused_bn1 = Primitive('FusedBN1') | |||
| fused_bn2 = Primitive('FusedBN2') | |||
| fused_bn3 = Primitive('FusedBN3') | |||
| bn_grad = G.FusedBatchNormGrad() | |||
| bn_grad1 = Primitive('BNGrad1') | |||
| bn_grad2 = Primitive('BNGrad2') | |||
| bn_grad3 = Primitive('BNGrad3') | |||
| @@ -50,73 +47,6 @@ class FnDict: | |||
| return self.fnDict[name] | |||
| def test_bn_split(tag): | |||
| """ test_split_bn_fusion """ | |||
| fns = FnDict() | |||
| @fns | |||
| def before(x, scale, b, mean, variance): | |||
| bn_output = bn(x, scale, b, mean, variance) | |||
| item0 = tuple_getitem(bn_output, 0) | |||
| return item0 | |||
| @fns | |||
| def after(x, scale, b, mean, variance): | |||
| fused_bn1_output = fused_bn1(x) | |||
| fused_bn2_input0 = tuple_getitem(fused_bn1_output, 0) | |||
| fused_bn2_input1 = tuple_getitem(fused_bn1_output, 1) | |||
| fused_bn2_output = fused_bn2(fused_bn2_input0, fused_bn2_input1, mean, variance) | |||
| fused_bn3_input1 = tuple_getitem(fused_bn2_output, 0) | |||
| fused_bn3_input2 = tuple_getitem(fused_bn2_output, 1) | |||
| fused_bn3_output = fused_bn3(x, fused_bn3_input1, fused_bn3_input2, scale, b) | |||
| output1 = tuple_getitem(fused_bn2_output, 2) | |||
| output2 = tuple_getitem(fused_bn2_output, 3) | |||
| output3 = tuple_getitem(fused_bn2_output, 0) | |||
| output4 = tuple_getitem(fused_bn2_output, 1) | |||
| output = make_tuple(fused_bn3_output, output1, output2, output3, output4) | |||
| item0 = tuple_getitem(output, 0) | |||
| return make_tuple(item0) | |||
| return fns[tag] | |||
| def test_bn_grad_split(tag): | |||
| """ test_bn_grad_split """ | |||
| fns = FnDict() | |||
| @fns | |||
| def before(dy, x, scale, save_mean, save_inv_variance): | |||
| bn_grad_output = bn_grad(dy, x, scale, save_mean, save_inv_variance) | |||
| item0 = tuple_getitem(bn_grad_output, 0) | |||
| item1 = tuple_getitem(bn_grad_output, 1) | |||
| item2 = tuple_getitem(bn_grad_output, 2) | |||
| output = make_tuple(item0, item1, item2) | |||
| res = tuple_getitem(output, 0) | |||
| return res | |||
| @fns | |||
| def after(i0, i1, i2, i3, i4): | |||
| bn_grad1_output = bn_grad1(i0, i1, i3) | |||
| bn_grad1_item0 = tuple_getitem(bn_grad1_output, 0) | |||
| bn_grad1_item1 = tuple_getitem(bn_grad1_output, 1) | |||
| bn_grad1_item2 = tuple_getitem(bn_grad1_output, 2) | |||
| bn_grad2_output = bn_grad2(bn_grad1_item0, bn_grad1_item1, i4, i2) | |||
| bn_grad2_item0 = tuple_getitem(bn_grad2_output, 0) | |||
| bn_grad2_item1 = tuple_getitem(bn_grad2_output, 1) | |||
| bn_grad2_item2 = tuple_getitem(bn_grad2_output, 2) | |||
| bn_grad2_item3 = tuple_getitem(bn_grad2_output, 3) | |||
| bn_grad2_item4 = tuple_getitem(bn_grad2_output, 4) | |||
| bn_grad3_output = bn_grad3(i0, bn_grad2_item2, bn_grad2_item3, bn_grad2_item4, bn_grad1_item2) | |||
| bn_grad_make_tuple = make_tuple(bn_grad3_output, bn_grad2_item0, bn_grad2_item1) | |||
| item0 = tuple_getitem(bn_grad_make_tuple, 0) | |||
| item1 = tuple_getitem(bn_grad_make_tuple, 1) | |||
| item2 = tuple_getitem(bn_grad_make_tuple, 2) | |||
| output = make_tuple(item0, item1, item2) | |||
| return make_tuple(tuple_getitem(output, 0)) | |||
| return fns[tag] | |||
| def test_all_reduce_fusion_all(tag): | |||
| """ test_all_reduce_fusion_all """ | |||
| fns = FnDict() | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -221,7 +221,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| // test batch norm as input | |||
| inputs.push_back(NewValueNode(prim::kPrimFusedBatchNorm)); | |||
| inputs.push_back(NewValueNode(prim::kPrimBatchNorm)); | |||
| auto bn = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| std::vector<int64_t> shp{2, 32, 224, 224}; | |||
| @@ -417,7 +417,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { | |||
| TEST_F(AnfRuntimeAlgorithmTest, GetOutputInferDataTypeTest) { | |||
| auto kernel_graph = std::make_shared<KernelGraph>(); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimFusedBatchNorm)); | |||
| inputs.push_back(NewValueNode(prim::kPrimBatchNorm)); | |||
| auto bn = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(bn); | |||
| std::vector<int64_t> shp{2, 32, 224, 224}; | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -690,26 +690,6 @@ test_cases_for_verify_exception = [ | |||
| 'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('FusedBatchNorm_ValueError_1', { | |||
| 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('FusedBatchNorm_ValueError_2', { | |||
| 'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('FusedBatchNorm_ValueError_3', { | |||
| 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('FusedBatchNorm_ValueError_4', { | |||
| 'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('FusedBatchNorm_ValueError_5', { | |||
| 'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}), | |||
| 'desc_inputs': [0], | |||
| }), | |||
| ('Softmax_ValueError_1', { | |||
| 'block': (lambda _: P.Softmax("1"), {'exception': TypeError}), | |||
| 'desc_inputs': [0], | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1749,11 +1749,6 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], | |||
| 'desc_bprop': [[2, 16], [16], [16]], | |||
| 'skip': ['backward']}), | |||
| ('FusedBatchNormGrad', { | |||
| 'block': G.FusedBatchNormGrad(), | |||
| 'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]], | |||
| 'skip': ['backward']}), | |||
| ('BatchNorm', { | |||
| 'block': P.BatchNorm(), | |||
| 'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| @@ -1761,8 +1756,8 @@ test_case_nn_ops = [ | |||
| 'skip': []}), | |||
| ('BatchNormGrad', { | |||
| 'block': G.BatchNormGrad(), | |||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64]], | |||
| 'skip': ['backward']}), | |||
| ('SyncBatchNorm', { | |||
| 'block': inner.SyncBatchNorm(), | |||
| @@ -1,77 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| grad_all = C.GradOperation(get_all=True) | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x, y, b): | |||
| predict = self.network(x, y, b) | |||
| return self.loss(predict) | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x, y, b): | |||
| return grad_all(self.network)(x, y, b) | |||
| # model_parallel test | |||
| def test_two_matmul_batchnorm_ex(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super().__init__() | |||
| self.matmul1 = P.BatchMatMul().shard(strategy1) | |||
| self.norm = P.FusedBatchNormEx() | |||
| self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma") | |||
| self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta") | |||
| self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean") | |||
| self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var") | |||
| self.matmul2 = P.BatchMatMul().shard(strategy2) | |||
| def construct(self, x, y, b): | |||
| out = self.matmul1(x, y) | |||
| out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0] | |||
| out = self.matmul2(out, b) | |||
| return out | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8) | |||
| strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1)) | |||
| strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1)) | |||
| net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) | |||
| net.set_auto_parallel() | |||
| x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32) | |||
| y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32) | |||
| b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32) | |||
| net.set_train() | |||
| _executor.compile(net, x, y, b) | |||
| @@ -1,260 +0,0 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.common.dtype as DT | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.nn import WithLossCell | |||
| from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.model import Model | |||
| from mindspore.context import ParallelMode | |||
| from tests.dataset_mock import MindData | |||
| class Dataset(MindData): | |||
| def __init__(self, predict, label, length=3): | |||
| super(Dataset, self).__init__(size=length) | |||
| self.predict = predict | |||
| self.label = label | |||
| self.index = 0 | |||
| self.length = length | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| if self.index >= self.length: | |||
| raise StopIteration | |||
| self.index += 1 | |||
| return self.predict, self.label | |||
| def reset(self): | |||
| self.index = 0 | |||
| class FusedBatchNorm(nn.Cell): | |||
| """Batch Normalization base class.""" | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.1, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones'): | |||
| super(FusedBatchNorm, self).__init__() | |||
| if num_features < 1: | |||
| raise ValueError("num_features must be at least 1") | |||
| if momentum < 0 or momentum > 1: | |||
| raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) | |||
| self.num_features = num_features | |||
| self.eps = eps | |||
| self.momentum = Tensor(1.0 - momentum, DT.float32) | |||
| self.gamma = Parameter(initializer( | |||
| gamma_init, num_features), name="gamma", requires_grad=affine) | |||
| self.beta = Parameter(initializer( | |||
| beta_init, num_features), name="beta", requires_grad=affine) | |||
| self.moving_mean = Parameter(initializer( | |||
| moving_mean_init, num_features), name="mean", requires_grad=False) | |||
| self.moving_variance = Parameter(initializer( | |||
| moving_var_init, num_features), name="variance", requires_grad=False) | |||
| self.bn_train = P.BatchNorm(is_training=True, | |||
| epsilon=self.eps) | |||
| self.bn_infer = P.BatchNorm(is_training=False, | |||
| epsilon=self.eps) | |||
| self.sub_mean = P.Sub().shard(((1), (1))) | |||
| self.sub_var = P.Sub().shard(((1), (1))) | |||
| self.mul_mean = P.Mul().shard(((1,), ())) | |||
| self.mul_var = P.Mul().shard(((1,), ())) | |||
| self.assign_sub_mean = P.AssignSub().shard(((1,), (1,))) | |||
| self.assign_sub_var = P.AssignSub().shard(((1), (1))) | |||
| self.sub_mean2 = P.Sub().shard(((1), (1))) | |||
| self.sub_var2 = P.Sub().shard(((1), (1))) | |||
| def shard(self, strategy): | |||
| self.bn_train.shard(strategy) | |||
| self.bn_infer.shard(strategy) | |||
| def _check_data_dim(self, x): | |||
| raise NotImplementedError | |||
| def construct(self, x): | |||
| if self.training: | |||
| y, batch_mean, batch_var, _, _ = \ | |||
| self.bn_train(x, | |||
| self.gamma, | |||
| self.beta, | |||
| self.moving_mean, | |||
| self.moving_variance) | |||
| mean_sub = self.sub_mean(self.moving_mean, batch_mean) | |||
| temp_mean = self.mul_mean(mean_sub, self.momentum) | |||
| mean_sub2 = self.sub_var(self.moving_variance, batch_var) | |||
| temp_variance = self.mul_var(mean_sub2, self.momentum) | |||
| y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) | |||
| y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) | |||
| else: | |||
| y = self.bn_infer(x, | |||
| self.gamma, | |||
| self.beta, | |||
| self.moving_mean, | |||
| self.moving_variance)[0] | |||
| return y | |||
| def extend_repr(self): | |||
| return 'num_features={}, eps={}, momentum={}, ' \ | |||
| 'beta={}, gamma={}, ' \ | |||
| 'moving_mean={}, moving_variance={} ' \ | |||
| .format(self.num_features, | |||
| self.eps, | |||
| self.momentum, | |||
| self.beta, | |||
| self.gamma, | |||
| self.moving_mean, | |||
| self.moving_variance) | |||
| class PReLU(nn.Cell): | |||
| """ | |||
| PReLU activation function. | |||
| Computes prelu value of a 4-dim tensor(NCHW). | |||
| PReLU: out = max(0, A) + min(0, wA) | |||
| Args: | |||
| channel: Integer. The dimensionality of w. Default: 1. | |||
| w: Float. The initial value of w. Default: 0.25. | |||
| Returns: | |||
| Tensor, has the same type as features. | |||
| Examples: | |||
| prelu = nn.PReLU(1, [np.float32(0.25)]) # or prelu = nn.PReLU(33, Tensor(np.random.rand(33), ms.float32)]) | |||
| input_data = Tensor(np.random.rand(1, 33, 4, 4), ms.float32) | |||
| output = prelu.construct(input_data) | |||
| """ | |||
| def __init__(self, channel=1, w=0.25): | |||
| super(PReLU, self).__init__() | |||
| if isinstance(w, (np.float32, float)): | |||
| tmp = np.empty((channel,), dtype=np.float32) | |||
| tmp.fill(w) | |||
| w = tmp | |||
| elif isinstance(w, (int, bool, complex, str)): | |||
| raise TypeError("w only support input type float32 and float") | |||
| if not isinstance(w, Tensor): | |||
| w = Tensor(w) | |||
| self.w = Parameter(initializer(w, [channel,]), name='a') | |||
| self.prelu = P.PReLU() | |||
| self.relu = P.ReLU().shard(((1))) | |||
| def construct(self, x): | |||
| self.w = self.relu(self.w) | |||
| return self.prelu(x, self.w) | |||
| class BNNet(nn.Cell): | |||
| def __init__(self): | |||
| super(BNNet, self).__init__() | |||
| self.bn = FusedBatchNorm(512) | |||
| self.prelu = PReLU(512) | |||
| def construct(self, x): | |||
| x = self.bn(x) | |||
| x = self.prelu(x) | |||
| return x | |||
| def bn_net(): | |||
| return BNNet() | |||
| def bn_common(parallel_mode, train_flag, strategy_loss=None): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | |||
| learning_rate = 0.1 | |||
| momentum = 0.9 | |||
| epoch_size = 2 | |||
| rank_size = 8 | |||
| predict = Tensor(np.ones([32, 512]), dtype=ms.float32) | |||
| label = Tensor(np.ones([32]), dtype=ms.int32) | |||
| dataset = Dataset(predict, label, 2) | |||
| net = bn_net() | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| loss.softmax_cross_entropy.shard(strategy_loss) | |||
| opt = Momentum(net.trainable_params(), learning_rate, momentum, 0.0001, 1024 * rank_size) | |||
| if not train_flag: | |||
| net = WithLossCell(net, loss) | |||
| net.set_train() | |||
| if parallel_mode == ParallelMode.DATA_PARALLEL: | |||
| context.set_auto_parallel_context(parameter_broadcast=True) | |||
| model = Model(net, loss, opt) | |||
| if train_flag: | |||
| model.train(epoch_size, dataset, dataset_sink_mode=False) | |||
| else: | |||
| model._predict(predict, label) | |||
| def test_data_parallel(): | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| train_flag = True | |||
| bn_common(parallel_mode, train_flag) | |||
| def auto_parallel(): | |||
| train_flag = True | |||
| parallel_mode = ParallelMode.AUTO_PARALLEL | |||
| bn_common(parallel_mode, train_flag) | |||
| def Xtest_data_parallel_predict(): | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| train_flag = False | |||
| bn_common(parallel_mode, train_flag) | |||
| def Xtest_semi_auto_parallel_predict(): | |||
| train_flag = False | |||
| parallel_mode = ParallelMode.SEMI_AUTO_PARALLEL | |||
| bn_common(parallel_mode, train_flag) | |||
| def Xtest_auto_parallel_predict(): | |||
| train_flag = False | |||
| parallel_mode = ParallelMode.AUTO_PARALLEL | |||
| bn_common(parallel_mode, train_flag) | |||
| if __name__ == '__main__': | |||
| auto_parallel() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -92,27 +92,6 @@ def vm_impl_tanh(self): | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.FusedBatchNorm) | |||
| def vm_impl_fused_batch_norm(self): | |||
| """Generate vm_impl function for FusedBatchNorm""" | |||
| def vm_impl(x, scale, b, mean, variance): | |||
| # pylint: disable=unused-argument | |||
| x = x.asnumpy() | |||
| scale = scale.asnumpy() | |||
| b = b.asnumpy() | |||
| mean = mean.asnumpy() | |||
| variance = variance.asnumpy() | |||
| out, x_mean, x_var, running_mean, running_var = vm.batch_norm(x, scale, b, mean, \ | |||
| variance, \ | |||
| eps=self.epsilon, \ | |||
| momentum=self.momentum) | |||
| return Tensor(out), Tensor(x_mean), Tensor(x_var), \ | |||
| Tensor(running_mean), Tensor(running_var) | |||
| return vm_impl | |||
| @vm_impl_getters.register(P.BatchNorm) | |||
| def vm_impl_batch_norm(self): | |||
| """Generate vm_impl function for BatchNorm""" | |||
| @@ -223,23 +202,6 @@ def vm_impl_avg_pool_grad(self): | |||
| return vm_impl | |||
| # pylint: disable=function-redefined | |||
| @vm_impl_getters.register(G.FusedBatchNormGrad) | |||
| def vm_impl_fused_batch_norm_grad(self): | |||
| """Generate vm_impl function for FusedBatchNormGrad""" | |||
| def vm_impl(dy, x, scale, save_mean, save_inv_variance): | |||
| dy = dy.asnumpy() | |||
| x = x.asnumpy() | |||
| scale = scale.asnumpy() | |||
| save_mean = save_mean.asnumpy() | |||
| save_inv_variance = save_inv_variance.asnumpy() | |||
| dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance) | |||
| return (Tensor(dx), Tensor(dscale), Tensor(dshift)) | |||
| return vm_impl | |||
| # pylint: disable=function-redefined | |||
| @vm_impl_getters.register(G.BatchNormGrad) | |||
| def vm_impl_fused_batch_norm_grad(self): | |||