Browse Source

fix getinputformat error when input is not a realnode

tags/v1.0.0
VectorSL 5 years ago
parent
commit
853987da79
2 changed files with 6 additions and 1 deletions
  1. +5
    -0
      mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
  2. +1
    -1
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

+ 5
- 0
mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc View File

@@ -49,6 +49,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto outlist = GetRealNodeUsedList(graph, fbn2); auto outlist = GetRealNodeUsedList(graph, fbn2);
bool changed = false;
for (size_t i = 0; i < outlist->size(); i++) { for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1); auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
auto value_node = index_node->cast<ValueNodePtr>(); auto value_node = index_node->cast<ValueNodePtr>();
@@ -63,8 +64,12 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
outputs_type.push_back(kNumberTypeFloat16); outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0)); outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get()); AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
changed = true;
} }
} }
if (!changed) {
return nullptr;
}
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before)); manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
outputs_type.clear(); outputs_type.clear();
outputs_shape.clear(); outputs_shape.clear();


+ 1
- 1
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -425,7 +425,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
<< node->DebugString() << "]"; << node->DebugString() << "]";
} }
if (!IsRealKernel(node)) { if (!IsRealKernel(node)) {
GetPrevNodeOutputFormat(node, input_idx);
return GetPrevNodeOutputFormat(node, input_idx);
} }
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);


Loading…
Cancel
Save