|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" |
|
|
|
#include <string> |
|
|
|
#include <map> |
|
|
|
#include <algorithm> |
|
|
|
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" |
|
|
|
#include "runtime/device/cpu/cpu_device_address.h" |
|
|
|
@@ -22,6 +23,7 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}}; |
|
|
|
void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
std::vector<size_t> src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); |
|
|
|
@@ -47,7 +49,16 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
std::vector<int> dilation_ori; |
|
|
|
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE); |
|
|
|
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION); |
|
|
|
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori), |
|
|
|
auto format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT); |
|
|
|
auto iter = kFormatIndexMap.find(format_me); |
|
|
|
if (iter == kFormatIndexMap.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}."; |
|
|
|
} |
|
|
|
size_t h_index = iter->second; |
|
|
|
if (stride_me.size() < h_index + 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size(); |
|
|
|
} |
|
|
|
(void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_ori), |
|
|
|
[](const int64_t &value) { return static_cast<int>(value); }); |
|
|
|
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori), |
|
|
|
[](const int64_t &value) { return static_cast<int>(value); }); |
|
|
|
|