|
|
|
@@ -35,6 +35,10 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNo |
|
|
|
void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); |
|
|
|
std::vector<size_t> label_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); |
|
|
|
if (label_shape.size() > 1) { |
|
|
|
MS_LOG(EXCEPTION) << "label shape should be 1D"; |
|
|
|
} |
|
|
|
dnnl::memory::dims mem_dims; |
|
|
|
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); |
|
|
|
if (mem_dims.size() != 2) { |
|
|
|
|