|
|
|
@@ -55,60 +55,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { |
|
|
|
MS_ASSERT(cnode != nullptr); |
|
|
|
bool has_tuple_get_item = false; |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.clear(); |
|
|
|
inputs.emplace_back(cnode->input(0)); |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
AnfNodePtr input_node = cnode->input(i); |
|
|
|
if (!input_node->isa<CNode>()) { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node); |
|
|
|
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) { |
|
|
|
has_tuple_get_item = true; |
|
|
|
inputs.emplace_back(tuple_get_item_node->input(1)); |
|
|
|
AnfNodePtr indexNode = tuple_get_item_node->input(2); |
|
|
|
if (!utils::isa<ValueNode>(indexNode)) { |
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto value_node = utils::cast<ValueNodePtr>(indexNode); |
|
|
|
} else { |
|
|
|
inputs.emplace_back(cnode->input(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (has_tuple_get_item) { |
|
|
|
cnode->set_inputs(inputs); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) { |
|
|
|
MS_ASSERT(meta_graphT != nullptr); |
|
|
|
MS_ASSERT(cnode != nullptr); |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
auto input_anode = cnode->input(i); |
|
|
|
if (!input_anode->isa<CNode>()) { |
|
|
|
MS_LOG(ERROR) << "Node of Return's input is not CNode"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto input_cnode = utils::cast<CNodePtr>(input_anode); |
|
|
|
std::string input_name = input_anode->fullname_with_scope(); |
|
|
|
auto iter = node_id_map_.find(input_name); |
|
|
|
if (iter == node_id_map_.end()) { |
|
|
|
MS_LOG(ERROR) << "Could not find output node"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto graph_output = iter->second; |
|
|
|
meta_graphT->outputIndex.emplace_back(graph_output); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, |
|
|
|
const std::shared_ptr<PrimitiveTValue> primitive, |
|
|
|
const std::unique_ptr<schema::CNodeT> &dst_node) { |
|
|
|
@@ -182,6 +128,28 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, |
|
|
|
schema::CNodeT *return_node) { |
|
|
|
MS_ASSERT(nullptr != meta_graph); |
|
|
|
MS_ASSERT(nullptr != return_node); |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) { |
|
|
|
auto input_node = cnode->input(i); |
|
|
|
if (input_node->isa<CNode>()) { |
|
|
|
auto ret = ConvertInputCNode(input_node, return_node); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "obtain outputs failed"; |
|
|
|
return; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node"; |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t i = 0; i < return_node->inputIndex.size(); ++i) { |
|
|
|
meta_graphT->outputIndex.push_back(return_node->inputIndex[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { |
|
|
|
auto cnodes = func_graph->GetOrderedCnodes(); |
|
|
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>(); |
|
|
|
@@ -202,24 +170,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { |
|
|
|
} |
|
|
|
RemoveIfMakeTuple(cnode); |
|
|
|
|
|
|
|
auto node = std::make_unique<schema::CNodeT>(); |
|
|
|
|
|
|
|
if (primT->value.type == schema::PrimitiveType_Return) { |
|
|
|
AddOutPutIfReturn(meta_graphT, cnode); |
|
|
|
node->name = "return_node"; |
|
|
|
SetGraphoutputIndex(cnode, meta_graphT, node.get()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto node = std::make_unique<schema::CNodeT>(); |
|
|
|
node->name = cnode->fullname_with_scope(); |
|
|
|
node->nodeType = schema::NodeType_CNode; |
|
|
|
|
|
|
|
node->name = cnode->fullname_with_scope(); |
|
|
|
node->primitive = std::unique_ptr<schema::PrimitiveT>(primT); |
|
|
|
auto ret = SetOpInputNode(cnode, meta_graphT, node.get()); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "SetOpInputNode failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
SetOpOutputNode(cnode, meta_graphT, node.get()); |
|
|
|
|
|
|
|
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ConvertQuantParam failed"; |
|
|
|
|