浏览代码

!1929 use VisitKernelWithReturnType instead of VisitKernel to get node's input in mem_reuse

Merge pull request !1929 from laiyongqiang/mem
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
5770df0951
共有 2 个文件被更改,包括 10 次插入3 次删除
  1. +8
    -2
      mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
  2. +2
    -1
      mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc

+ 8
- 2
mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc 查看文件

@@ -226,7 +226,10 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
<< AnfAlgo::GetInputTensorNum(kernel); << AnfAlgo::GetInputTensorNum(kernel);
} }
auto input_node = kernel->input(input_idx + 1); auto input_node = kernel->input(input_idx + 1);
auto kernel_input = AnfAlgo::VisitKernel(input_node, 0);
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
}
auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second)); auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second));
return result; return result;
} }
@@ -264,7 +267,10 @@ void MemReuseUtil::SetKernelDefInputs() {
if (ref_ptr != nullptr) { if (ref_ptr != nullptr) {
// set the inputs of this kernel_def // set the inputs of this kernel_def
auto input_node = AnfAlgo::GetInputNode(kernel, i); auto input_node = AnfAlgo::GetInputNode(kernel, i);
auto input = AnfAlgo::VisitKernel(input_node, 0);
auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
}
auto input_key = (input.first).get(); auto input_key = (input.first).get();
auto input_iter = kernel_map_.find(input_key); auto input_iter = kernel_map_.find(input_key);
if (input_iter == kernel_map_.end()) { if (input_iter == kernel_map_.end()) {


+ 2
- 1
mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc 查看文件

@@ -48,7 +48,8 @@ void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr
auto iter = kernel_refs.find(key); auto iter = kernel_refs.find(key);
auto node_name = AnfAlgo::GetCNodeName(c_node); auto node_name = AnfAlgo::GetCNodeName(c_node);
if (iter == kernel_refs.end()) { if (iter == kernel_refs.end()) {
MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor";
MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString()
<< " output index: " << output_idx;
} }
if (output_idx >= iter->second.size()) { if (output_idx >= iter->second.size()) {
MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str(); MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str();


正在加载...
取消
保存