|
|
|
@@ -35,6 +35,20 @@ bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphMana |
|
|
|
static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser));
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update matmul's BuildInfo as last input changed
|
|
|
|
void UpdateBuildInfo(const AnfNodePtr &matmul_node, const AnfNodePtr &cast_node) {
|
|
|
|
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul_node);
|
|
|
|
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node);
|
|
|
|
input_types.pop_back();
|
|
|
|
auto cast_types = AnfAlgo::GetAllInputDeviceTypes(cast_node);
|
|
|
|
input_types.push_back(cast_types.front());
|
|
|
|
std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(matmul_node);
|
|
|
|
std::vector<TypeId> output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node);
|
|
|
|
auto graph_sel_info =
|
|
|
|
BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, matmul_node);
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get());
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
/* MatMul supports fp32 bias, so remove the redundant cast if cast cannot fuse forward
|
|
|
|
@@ -93,6 +107,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { |
|
|
|
// Case1 : Cast is only used by matmul
|
|
|
|
if (user_index_set.size() == 1) {
|
|
|
|
mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
|
|
|
|
UpdateBuildInfo(cnode, cast_node);
|
|
|
|
changed = true;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
@@ -109,6 +124,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { |
|
|
|
cnode->set_input(4, (cast_node->cast<CNodePtr>())->input(1));
|
|
|
|
mng->RemoveRoots();
|
|
|
|
mng->KeepRoots({func_graph});
|
|
|
|
UpdateBuildInfo(cnode, cast_node);
|
|
|
|
changed = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|