|
|
@@ -49,6 +49,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A |
|
|
auto manager = graph->manager();
|
|
|
auto manager = graph->manager();
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
auto outlist = GetRealNodeUsedList(graph, fbn2);
|
|
|
auto outlist = GetRealNodeUsedList(graph, fbn2);
|
|
|
|
|
|
bool changed = false;
|
|
|
for (size_t i = 0; i < outlist->size(); i++) {
|
|
|
for (size_t i = 0; i < outlist->size(); i++) {
|
|
|
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
|
|
|
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
|
|
|
auto value_node = index_node->cast<ValueNodePtr>();
|
|
|
auto value_node = index_node->cast<ValueNodePtr>();
|
|
|
@@ -63,8 +64,12 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A |
|
|
outputs_type.push_back(kNumberTypeFloat16);
|
|
|
outputs_type.push_back(kNumberTypeFloat16);
|
|
|
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
|
|
|
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
|
|
|
|
|
|
changed = true;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if (!changed) {
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
}
|
|
|
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
|
|
|
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
|
|
|
outputs_type.clear();
|
|
|
outputs_type.clear();
|
|
|
outputs_shape.clear();
|
|
|
outputs_shape.clear();
|
|
|
|