From 8430ca147cace09b5f9481cd6788bc69ecdca1bd Mon Sep 17 00:00:00 2001
From: baihuawei
Date: Thu, 25 Mar 2021 20:46:22 +0800
Subject: [PATCH] fix lstm in pynative mode

---
 .../backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc | 5 +++--
 mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc  | 6 ++++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc
index 6305046e9e..13c5a6d755 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc
@@ -60,9 +60,10 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
   dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
   if (!kernel_node->HasAttr(kAttrIsTraining)) {
-    MS_LOG(WARNING) << "LSTM has no attr is_training";
+    is_training = true;
+  } else {
+    is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
   }
-  is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
   auto prop_kind = dnnl::prop_kind::forward_training;
   if (!is_training) {
     prop_kind = dnnl::prop_kind::forward_inference;
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
index 20f291a32f..3355548f8e 100644
--- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -189,6 +189,12 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
   auto shape = AnfAlgo::GetOutputInferShape(node, index);
   ShapeVector temp_shape;
   (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
+  size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
+  size_t tensor_size = std::accumulate(temp_shape.begin(), temp_shape.end(), type_size, std::multiplies<size_t>());
+  if (tensor_size < address->size_) {
+    temp_shape.clear();
+    temp_shape.emplace_back(address->size_);
+  }
   tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
   bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
   if (is_internal_output) {