|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" |
|
|
|
#include <map> |
|
|
|
#include <tuple> |
|
|
|
#include <unordered_set> |
|
|
|
#include "pipeline/jit/parse/python_adapter.h" |
|
|
|
#include "pipeline/jit/action.h" |
|
|
|
@@ -244,7 +245,7 @@ AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManager |
|
|
|
|
|
|
|
bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, |
|
|
|
const DumpOption &dump_option, nlohmann::json *op_desc, |
|
|
|
std::map<std::string, AnfNodePtr> *address_node_map) { |
|
|
|
std::map<std::string, AnfNodePtr> *address_node_map = nullptr) { |
|
|
|
kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option); |
|
|
|
if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) { |
|
|
|
MS_LOG(ERROR) << "Collect json desc failed."; |
|
|
|
@@ -262,6 +263,90 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const |
|
|
|
MS_LOG(INFO) << "Collect fusion json: " << fused_name; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertComplexTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(inputs_ptr); |
|
|
|
auto nodes = TopoSort(fg->get_return()); |
|
|
|
|
|
|
|
std::map<ValuePtr, AnfNodePtrList> vmap; |
|
|
|
for (const auto &node : nodes) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
auto tnode = inputs[i]; |
|
|
|
auto tensor = GetValueNode<tensor::TensorPtr>(tnode); |
|
|
|
if (tensor && (tensor->DataSize() > 1)) { |
|
|
|
vmap[GetValueNode(tnode)].push_back(tnode); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (vmap.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto mng = fg->manager(); |
|
|
|
if (mng == nullptr) { |
|
|
|
mng = Manage(fg, false); |
|
|
|
fg->set_manager(mng); |
|
|
|
} |
|
|
|
|
|
|
|
auto &inputs = *inputs_ptr; |
|
|
|
for (auto iter : vmap) { |
|
|
|
auto value_nodes = iter.second; |
|
|
|
if (value_nodes.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid value in map!"; |
|
|
|
} |
|
|
|
|
|
|
|
auto vnode = value_nodes[0]; |
|
|
|
auto parameter = fg->add_parameter(); |
|
|
|
parameter->set_abstract(vnode->abstract()); |
|
|
|
parameter->set_kernel_info(vnode->kernel_info_ptr()); |
|
|
|
for (const auto &value_node : value_nodes) { |
|
|
|
mng->Replace(value_node, parameter); |
|
|
|
} |
|
|
|
|
|
|
|
inputs.push_back(vnode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs. |
|
|
|
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes, |
|
|
|
AnfNodePtrList *src_outputs = nullptr) { |
|
|
|
FuncGraphPtr fg; |
|
|
|
AnfNodePtrList inputs; |
|
|
|
AnfNodePtrList outputs; |
|
|
|
AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs; |
|
|
|
std::tie(fg, inputs, *soutputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); |
|
|
|
|
|
|
|
FuncGraphManagerPtr mng = fg->manager(); |
|
|
|
if (mng == nullptr) { |
|
|
|
mng = Manage(fg, false); |
|
|
|
fg->set_manager(mng); |
|
|
|
} |
|
|
|
|
|
|
|
// Inline origin graphkernel |
|
|
|
auto cnodes = fg->GetOrderedCnodes(); |
|
|
|
for (const auto &n : cnodes) { |
|
|
|
if (!AnfAlgo::IsGraphKernel(n)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0)); |
|
|
|
AnfNodePtrList ins; |
|
|
|
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); |
|
|
|
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); |
|
|
|
mng->Replace(n, out); |
|
|
|
} |
|
|
|
|
|
|
|
EliminateMakeTuple(fg, mng); |
|
|
|
ConvertComplexTensorToParameter(fg, &inputs); |
|
|
|
|
|
|
|
outputs.clear(); |
|
|
|
kernel::GetFuncGraphOutputNodes(fg, &outputs); |
|
|
|
return std::make_tuple(fg, inputs, outputs); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, |
|
|
|
@@ -400,6 +485,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> fn_inputs; |
|
|
|
size_t offset = 0; |
|
|
|
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { |
|
|
|
AnfNodePtrList real_outs; |
|
|
|
// not make tuple out, replace |
|
|
|
@@ -427,7 +513,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f |
|
|
|
auto value_node = value_input->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
int item_idx = GetValue<int>(value_node->value()); |
|
|
|
int new_item_idx = SizeToInt(out_idx) + item_idx; |
|
|
|
int new_item_idx = SizeToInt(out_idx) + offset + item_idx; |
|
|
|
fn_inputs.clear(); |
|
|
|
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); |
|
|
|
fn_inputs.push_back(new_fuse_cnode); |
|
|
|
@@ -436,6 +522,8 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f |
|
|
|
new_out->set_abstract(get_item_cnode->abstract()); |
|
|
|
mng->Replace(get_item_cnode, new_out); |
|
|
|
} |
|
|
|
|
|
|
|
offset += real_outs.size() - 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -454,31 +542,17 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, |
|
|
|
|
|
|
|
FuncGraphPtr fg; |
|
|
|
AnfNodePtrList inputs; |
|
|
|
AnfNodePtrList src_outputs; |
|
|
|
AnfNodePtrList outputs; |
|
|
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); |
|
|
|
|
|
|
|
// Remove nest make tuple in outs |
|
|
|
auto expand_out = GetExpandOuts(outputs); |
|
|
|
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select); |
|
|
|
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs); |
|
|
|
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); |
|
|
|
if (!is_before_kernel_select) { |
|
|
|
SetNewKernelInfo(fuse_new_node, fg, inputs, expand_out, AnfAlgo::GetProcessor(fuse_nodes[0])); |
|
|
|
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0])); |
|
|
|
} |
|
|
|
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); |
|
|
|
// Handle get-item problem. |
|
|
|
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs); |
|
|
|
|
|
|
|
// Inline origin graphkernel |
|
|
|
auto cnodes = fg->GetOrderedCnodes(); |
|
|
|
for (const auto &n : cnodes) { |
|
|
|
if (!AnfAlgo::IsGraphKernel(n)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0)); |
|
|
|
AnfNodePtrList ins; |
|
|
|
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); |
|
|
|
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); |
|
|
|
mng->Replace(n, out); |
|
|
|
} |
|
|
|
|
|
|
|
EliminateMakeTuple(fg, mng); |
|
|
|
// set graphKernel attr |
|
|
|
std::string fuse_op_name = ""; |
|
|
|
for (auto &fuse_node : fuse_nodes) { |
|
|
|
@@ -512,32 +586,45 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n |
|
|
|
if (is_single_graph_kernel) { |
|
|
|
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); |
|
|
|
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); |
|
|
|
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); |
|
|
|
} else if (!has_graph_kernel) { |
|
|
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); |
|
|
|
op_nodes = nodes; |
|
|
|
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); |
|
|
|
} else { |
|
|
|
// When there are basic and composite ops, the composite ops should be inline to the basic ones' graph, |
|
|
|
// so a new graph generation should be done (because they may be in the main graph!). |
|
|
|
// If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now. |
|
|
|
MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!"; |
|
|
|
} |
|
|
|
|
|
|
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes); |
|
|
|
auto mng = Manage(fg, false); |
|
|
|
fg->set_manager(mng); |
|
|
|
// Inline origin graph kernel |
|
|
|
auto fg_nodes = fg->GetOrderedCnodes(); |
|
|
|
for (auto const &n : fg_nodes) { |
|
|
|
if (!AnfAlgo::IsGraphKernel(n)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0)); |
|
|
|
AnfNodePtrList ins; |
|
|
|
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); |
|
|
|
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); |
|
|
|
mng->Replace(n, out); |
|
|
|
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc) { |
|
|
|
MS_EXCEPTION_IF_NULL(op_desc); |
|
|
|
if (nodes.empty()) { |
|
|
|
MS_LOG(ERROR) << "Input nodes is empty."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
inputs.clear(); |
|
|
|
outputs.clear(); |
|
|
|
|
|
|
|
FuncGraphPtr fg; |
|
|
|
AnfNodePtrList op_nodes, inputs, outputs; |
|
|
|
if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) { |
|
|
|
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]); |
|
|
|
} else { |
|
|
|
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes); |
|
|
|
inputs.clear(); |
|
|
|
outputs.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs); |
|
|
|
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc, address_node_map); |
|
|
|
|
|
|
|
auto mng = fg->manager(); |
|
|
|
if (mng == nullptr) { |
|
|
|
mng = Manage(fg, false); |
|
|
|
fg->set_manager(mng); |
|
|
|
} |
|
|
|
|
|
|
|
return GenJson(op_nodes, inputs, outputs, dump_option, op_desc); |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) { |
|
|
|
|