Browse Source

!14114 fix lstm in pynative mode

Merge pull request !14114 from baihuawei/fixpylstm1.2
tags/v1.2.0-rc1
lilongfei Gitee 5 years ago
parent
commit
39fbf69674
2 changed files with 9 additions and 2 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc
  2. +6
    -0
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc

+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc View File

@@ -60,9 +60,10 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
 dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
 dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
+if (!kernel_node->HasAttr(kAttrIsTraining)) {
+  MS_LOG(WARNING) << "LSTM has no attr is_training";
+  is_training = true;
+} else {
+  is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
+}
-is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
 auto prop_kind = dnnl::prop_kind::forward_training;
 if (!is_training) {
   prop_kind = dnnl::prop_kind::forward_inference;


+ 6
- 0
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc View File

@@ -189,6 +189,12 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(
 auto shape = AnfAlgo::GetOutputInferShape(node, index);
 ShapeVector temp_shape;
 (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
+size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
+size_t tensor_size = std::accumulate(temp_shape.begin(), temp_shape.end(), type_size, std::multiplies<size_t>());
+if (tensor_size < address->size_) {
+  temp_shape.clear();
+  temp_shape.emplace_back(address->size_);
+}
 tensor = std::make_shared<tensor::Tensor>(infer_type_id, temp_shape);
 bool is_internal_output = kernel_graph->IsInternalOutput(node, index);
 if (is_internal_output) {


Loading…
Cancel
Save