|
|
|
@@ -22,6 +22,11 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
void PoolingCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { |
|
|
|
CPUKernel::InitInputOutputSize(kernel_node); |
|
|
|
workspace_size_list_.emplace_back(workspace_size_); |
|
|
|
} |
|
|
|
|
|
|
|
void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); |
|
|
|
@@ -62,6 +67,7 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
dst_desc, strides_dims, kernels_dims, padding_l, padding_r); |
|
|
|
} |
|
|
|
auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); |
|
|
|
workspace_size_ = prim_desc.workspace_desc().get_size(); |
|
|
|
primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc); |
|
|
|
AddArgument(DNNL_ARG_SRC, src_desc); |
|
|
|
AddArgument(DNNL_ARG_DST, dst_desc); |
|
|
|
@@ -69,13 +75,14 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
} |
|
|
|
|
|
|
|
bool PoolingCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/, |
|
|
|
const std::vector<kernel::AddressPtr> &workspace, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
if (inputs.empty() || outputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "error input output size!"; |
|
|
|
} |
|
|
|
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); |
|
|
|
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); |
|
|
|
SetArgumentHandle(DNNL_ARG_WORKSPACE, workspace[0]->addr); |
|
|
|
ExecutePrimitive(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|