|
|
|
@@ -34,7 +34,6 @@ |
|
|
|
#include "src/tensor.h" |
|
|
|
#include "src/param_value_lite.h" |
|
|
|
#include "src/common/utils.h" |
|
|
|
#include "ops/partial.h" |
|
|
|
#include "tools/common/graph_util.h" |
|
|
|
#include "src/ops/ops_utils.h" |
|
|
|
|
|
|
|
@@ -287,7 +286,6 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m |
|
|
|
|
|
|
|
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, |
|
|
|
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, |
|
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT, |
|
|
|
schema::CNodeT *return_node) { |
|
|
|
MS_ASSERT(nullptr != meta_graphT); |
|
|
|
MS_ASSERT(nullptr != return_node); |
|
|
|
@@ -319,9 +317,15 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) { |
|
|
|
if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, |
|
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, |
|
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT) { |
|
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive) { |
|
|
|
int ret = RET_OK; |
|
|
|
auto cnodes = GetOrderedCNodes(func_graph); |
|
|
|
for (const auto &cnode : cnodes) { |
|
|
|
@@ -334,19 +338,18 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc |
|
|
|
prim = GetValueNode<std::shared_ptr<Primitive>>(partial_cnode->input(0)); |
|
|
|
primT = GetPrimitiveT(partial_cnode->input(0)); |
|
|
|
MS_ASSERT(primT != nullptr); |
|
|
|
auto pos = fg_subgraph_map.find(fg); |
|
|
|
if (pos != fg_subgraph_map.end()) { |
|
|
|
auto pos = fg_subgraph_map_.find(fg); |
|
|
|
if (pos != fg_subgraph_map_.end()) { |
|
|
|
MS_ASSERT(primT->value.AsPartialFusion() != nullptr); |
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map.at(fg); |
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map_.at(fg); |
|
|
|
} else { |
|
|
|
size_t next_subgraph_index = fg_subgraph_map.size() + 1; |
|
|
|
fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index}); |
|
|
|
size_t next_subgraph_index = meta_graphT->subGraph.size(); |
|
|
|
MS_ASSERT(primT->value.AsPartialFusion() != nullptr); |
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = next_subgraph_index; |
|
|
|
ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode); |
|
|
|
ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, cnode); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ExportSubgraph failed"; |
|
|
|
break; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -374,7 +377,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc |
|
|
|
} |
|
|
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) { |
|
|
|
node->name = mindspore::ops::kNameReturn; |
|
|
|
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get()); |
|
|
|
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get()); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "SetOpOutputN failed"; |
|
|
|
break; |
|
|
|
@@ -398,26 +401,28 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc |
|
|
|
break; |
|
|
|
} |
|
|
|
meta_graphT->nodes.push_back(std::move(node)); |
|
|
|
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); |
|
|
|
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++); |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, |
|
|
|
const size_t &subgraph_index, bool keep_graph, bool copy_primitive, |
|
|
|
const std::shared_ptr<AnfNode> &partial_anode) { |
|
|
|
int ret = RET_OK; |
|
|
|
bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode) { |
|
|
|
if (HasExported(func_graph)) { |
|
|
|
MS_LOG(INFO) << "Has been exported."; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>()); |
|
|
|
auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); |
|
|
|
auto subgraph_index = meta_graphT->subGraph.size() - 1; |
|
|
|
fg_subgraph_map_[func_graph] = subgraph_index; |
|
|
|
auto subgraph_name = func_graph->get_attr("graph_name"); |
|
|
|
MS_ASSERT(subgraph_name != nullptr); |
|
|
|
sub_graphT->name = GetValue<std::string>(subgraph_name); |
|
|
|
auto fmk = func_graph->get_attr("fmk"); |
|
|
|
MS_ASSERT(fmk != nullptr); |
|
|
|
meta_graphT->fmkType = GetValue<int>(fmk); |
|
|
|
meta_graphT->subGraph.back()->name = GetValue<std::string>(subgraph_name); |
|
|
|
|
|
|
|
ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive, sub_graphT); |
|
|
|
int ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Anf2Fb failed"; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
@@ -441,11 +446,14 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu |
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, |
|
|
|
bool train_flag) { |
|
|
|
static int subgraph_index = 0; |
|
|
|
this->train_flag = train_flag; |
|
|
|
this->train_flag_ = train_flag; |
|
|
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>(); |
|
|
|
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive); |
|
|
|
auto fmk = func_graph->get_attr("fmk"); |
|
|
|
MS_ASSERT(fmk != nullptr); |
|
|
|
meta_graphT->fmkType = GetValue<int>(fmk); |
|
|
|
int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ExportSubgraph failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -455,7 +463,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee |
|
|
|
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) { |
|
|
|
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr); |
|
|
|
auto input_name = input_anode->fullname_with_scope(); |
|
|
|
if (this->train_flag) { |
|
|
|
if (this->train_flag_) { |
|
|
|
bool found = false; |
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); |
|
|
|
@@ -618,7 +626,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc |
|
|
|
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), |
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); }); |
|
|
|
(*paramTensor)->dims = dims; |
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; |
|
|
|
if (train_flag_ && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1}; |
|
|
|
(*paramTensor)->nodeType = NodeType_ValueNode; |
|
|
|
auto data = value->cast<tensor::TensorPtr>(); |
|
|
|
(*paramTensor)->data.resize(data->Size()); |
|
|
|
@@ -742,7 +750,7 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu |
|
|
|
(*paramTensor)->dataType = valueLite->tensor_type(); |
|
|
|
(*paramTensor)->dims = valueLite->tensor_shape(); |
|
|
|
|
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) { |
|
|
|
if (train_flag_ && (*paramTensor)->dims.empty()) { |
|
|
|
(*paramTensor)->dims = {1}; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -765,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano |
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>(); |
|
|
|
auto value = valueNode->value(); |
|
|
|
int ret = RET_OK; |
|
|
|
if (train_flag) { |
|
|
|
if (train_flag_) { |
|
|
|
paramTensor->name = valueNode->fullname_with_scope(); |
|
|
|
} |
|
|
|
if (value->isa<tensor::Tensor>()) { |
|
|
|
@@ -861,7 +869,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s |
|
|
|
} |
|
|
|
msTensor->nodeType = NodeType_CNode; |
|
|
|
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); |
|
|
|
if (train_flag) { |
|
|
|
if (train_flag_) { |
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i); |
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size(); |
|
|
|
meta_graphT->allTensors.emplace_back(msTensor); |
|
|
|
|