From 46c3f4e019dba76864bcb27e31c5ef0b325ef958 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Sun, 27 Sep 2020 11:54:18 +0800 Subject: [PATCH] fix sparse_softmax --- .../sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc index b1299f1ae0..c34fd17046 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -35,6 +35,10 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNo void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + if (label_shape.size() > 1) { + MS_LOG(EXCEPTION) << "label shape should be 1D"; + } dnnl::memory::dims mem_dims; mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); if (mem_dims.size() != 2) {