Browse Source

!7332 cpu conv2d support tuple pad

Merge pull request !7332 from baihuawei/conv2dgrad
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
6152bdaf16
3 changed files with 7 additions and 6 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  2. +5
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc
  3. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc

+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h View File

@@ -34,6 +34,7 @@ const char STRIDE[] = "stride";
const char STRIDES[] = "strides";
const char DILATION[] = "dilation";
const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list";
const char PAD_MODE[] = "pad_mode";
const char PADDING[] = "padding";
const char PAD_MODE_LOWER_SAME[] = "same";


+ 5
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc View File

@@ -56,11 +56,11 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
padding_r->emplace_back(0);
padding_r->emplace_back(0);
} else {
int pad = AnfAlgo::GetNodeAttr<int>(kernel_node, PAD);
padding_l->emplace_back(pad);
padding_l->emplace_back(pad);
padding_r->emplace_back(pad);
padding_r->emplace_back(pad);
std::vector<int> pad = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, PAD_LIST);
padding_l->emplace_back(pad[0]);
padding_l->emplace_back(pad[1]);
padding_r->emplace_back(pad[2]);
padding_r->emplace_back(pad[3]);
}
}



+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc View File

@@ -37,7 +37,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &ke
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (label_shape.size() > 1) {
MS_LOG(EXCEPTION) << "label shape should be 1D";
MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1";
}
dnnl::memory::dims mem_dims;
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());


Loading…
Cancel
Save