@@ -50,11 +50,13 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel});
   auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
   auto prop_kind = dnnl::prop_kind::forward_inference;
+  auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats;
   if (is_train) {
     prop_kind = dnnl::prop_kind::forward_training;
+    normalization_flags = dnnl::normalization_flags::use_scale_shift;
   }
   dnnl::batch_normalization_forward::desc desc =
-    dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, dnnl::normalization_flags::use_scale_shift);
+    dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags);
   auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
   primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
   AddArgument(DNNL_ARG_SRC, x_desc);
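For context, a minimal standalone sketch of the flag selection this hunk introduces (an illustration only, assuming oneDNN 1.x/2.x where use_scale_shift is still a single flag and the desc-based API exists; the shape and epsilon below are made up, not the kernel's): inference binds the moving statistics through use_global_stats, while training lets the primitive compute batch statistics itself.

#include <dnnl.hpp>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // Illustrative NCHW input; the kernel derives x_desc from the node's actual input shape.
  dnnl::memory::desc x_desc({2, 3, 4, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  const float epsilon = 1e-5f;
  const bool is_train = false;  // toggle to build the training variant

  // Inference: user-supplied scale/shift, mean/variance taken from inputs (moving stats).
  auto prop_kind = dnnl::prop_kind::forward_inference;
  auto flags = dnnl::normalization_flags::use_scale_shift |
               dnnl::normalization_flags::use_global_stats;
  if (is_train) {
    // Training: the primitive computes batch mean/variance itself.
    prop_kind = dnnl::prop_kind::forward_training;
    flags = dnnl::normalization_flags::use_scale_shift;
  }

  dnnl::batch_normalization_forward::desc desc(prop_kind, x_desc, epsilon, flags);
  dnnl::batch_normalization_forward::primitive_desc prim_desc(desc, eng);
  dnnl::batch_normalization_forward bnorm(prim_desc);
  (void)bnorm;
  return 0;
}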
@@ -74,14 +76,14 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
   auto wksp = reinterpret_cast<float *>(workspace[0]->addr);
   memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
   memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size);
-  SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
-  SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
-  SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr);
-  SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
-  SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
-  ExecutePrimitive();
   if (is_train) {
+    SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
+    SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
+    SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr);
+    SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
+    SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+    ExecutePrimitive();
     auto moving_mean = reinterpret_cast<float *>(inputs[3]->addr);
     auto moving_variance = reinterpret_cast<float *>(inputs[4]->addr);
     auto mean = reinterpret_cast<float *>(outputs[3]->addr);
@@ -90,6 +92,13 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
       moving_mean[i] = moving_mean[i] * (1 - momentum) + mean[i] * momentum;
       moving_variance[i] = moving_variance[i] * (1 - momentum) + variance[i] * momentum;
     }
+  } else {
+    SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
+    SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr);
+    SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr);
+    SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
+    SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+    ExecutePrimitive();
   }
   return true;
 }
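A companion sketch (same oneDNN 1.x/2.x assumptions; shapes, fill values, and variable names are illustrative only, not the kernel's) of why the argument binding in Launch() now differs by mode: with forward_training, DNNL_ARG_MEAN / DNNL_ARG_VARIANCE are outputs the primitive fills (hence outputs[3]/outputs[4] above), whereas with use_global_stats they are inputs carrying the moving statistics (hence inputs[3]/inputs[4]).

#include <dnnl.hpp>
#include <cstddef>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  const int N = 2, C = 3, H = 4, W = 4;
  const bool is_train = false;  // inference path: moving statistics are inputs

  dnnl::memory::desc x_desc({N, C, H, W}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  auto prop_kind = is_train ? dnnl::prop_kind::forward_training
                            : dnnl::prop_kind::forward_inference;
  auto flags = is_train ? dnnl::normalization_flags::use_scale_shift
                        : dnnl::normalization_flags::use_scale_shift |
                              dnnl::normalization_flags::use_global_stats;

  dnnl::batch_normalization_forward::desc desc(prop_kind, x_desc, 1e-5f, flags);
  dnnl::batch_normalization_forward::primitive_desc pd(desc, eng);
  dnnl::batch_normalization_forward bnorm(pd);

  dnnl::memory src(pd.src_desc(), eng), dst(pd.dst_desc(), eng);
  dnnl::memory scale_shift(pd.weights_desc(), eng);  // packed {2, C}: scale row then shift row
  dnnl::memory mean(pd.mean_desc(), eng), variance(pd.variance_desc(), eng);

  // Fill buffers with neutral values so the inference run is well defined.
  auto fill = [](dnnl::memory &m, float value) {
    float *p = static_cast<float *>(m.get_data_handle());
    size_t n = m.get_desc().get_size() / sizeof(float);
    for (size_t i = 0; i < n; ++i) p[i] = value;
  };
  fill(src, 1.f);
  fill(scale_shift, 1.f);  // scale = 1, shift = 1
  fill(mean, 0.f);         // moving mean: input in inference, output in training
  fill(variance, 1.f);     // moving variance: input in inference, output in training

  bnorm.execute(strm, {{DNNL_ARG_SRC, src},
                       {DNNL_ARG_DST, dst},
                       {DNNL_ARG_SCALE_SHIFT, scale_shift},
                       {DNNL_ARG_MEAN, mean},
                       {DNNL_ARG_VARIANCE, variance}});
  strm.wait();
  return 0;
}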