|
|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include <functional> |
|
|
|
#include <utility> |
|
|
|
#include <vector> |
|
|
|
#include <algorithm> |
|
|
|
#include "tools/converter/converter_flags.h" |
|
|
|
#include "abstract/abstract_value.h" |
|
|
|
#include "mindspore/core/ir/primitive.h" |
|
|
|
@@ -421,18 +422,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee |
|
|
|
|
|
|
|
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) { |
|
|
|
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr); |
|
|
|
auto input_name = input_anode->fullname_with_scope(); |
|
|
|
if (this->train_flag_) { |
|
|
|
bool found = false; |
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); |
|
|
|
found = true; |
|
|
|
} |
|
|
|
if (!found) { |
|
|
|
auto input_index_key = input_name + "_o:" + std::to_string(0); |
|
|
|
if (node_id_map_.find(input_index_key) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]); |
|
|
|
} |
|
|
|
auto key = std::make_pair(input_anode, 0); |
|
|
|
if (node_id_map_.find(key) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[key]); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -444,20 +437,15 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema |
|
|
|
} |
|
|
|
auto elements = tuple->elements(); |
|
|
|
for (size_t i = 0; i < elements.size(); i++) { |
|
|
|
if (elements.size() == 1) { |
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); |
|
|
|
} |
|
|
|
} else { |
|
|
|
std::string name = input_name + "_o:" + std::to_string(i); |
|
|
|
if (node_id_map_.find(name) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[name]); |
|
|
|
} |
|
|
|
auto key = std::make_pair(input_anode, i); |
|
|
|
if (node_id_map_.find(key) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[key]); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); |
|
|
|
auto key = std::make_pair(input_anode, 0); |
|
|
|
if (node_id_map_.find(key) != node_id_map_.end()) { |
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[key]); |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -490,16 +478,16 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, |
|
|
|
MS_LOG(ERROR) << "cast to ValueNode failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + |
|
|
|
std::to_string(value_node->value()->type()->number_type() == kNumberTypeInt64 |
|
|
|
? GetValue<int64_t>(value_node->value()) |
|
|
|
: GetValue<int>(value_node->value())); |
|
|
|
auto iter = node_id_map_.find(input_index_key); |
|
|
|
auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value()) |
|
|
|
: GetValue<int>(value_node->value()); |
|
|
|
auto key = std::make_pair(get_item_input_cnode, idx); |
|
|
|
auto iter = node_id_map_.find(key); |
|
|
|
if (iter == node_id_map_.end()) { |
|
|
|
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0 |
|
|
|
iter = node_id_map_.find(input_index_key); |
|
|
|
key = std::make_pair(get_item_input_cnode, 0); // try name with 0 |
|
|
|
iter = node_id_map_.find(key); |
|
|
|
if (iter == node_id_map_.end()) { |
|
|
|
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key; |
|
|
|
MS_LOG(ERROR) << "Can not find get_item output tensor " |
|
|
|
<< get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(idx); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -513,9 +501,9 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons |
|
|
|
schema::CNodeT *op_node) { |
|
|
|
auto param_node = cnode->input(index)->cast<ParameterPtr>(); |
|
|
|
MS_ASSERT(param_node != nullptr); |
|
|
|
std::string input_name = param_node->fullname_with_scope(); |
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) { |
|
|
|
op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]); |
|
|
|
auto key = std::make_pair(param_node, 0); |
|
|
|
if (node_id_map_.find(key) != node_id_map_.end()) { |
|
|
|
op_node->inputIndex.emplace_back(node_id_map_[key]); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
DataInfo data_info; |
|
|
|
@@ -532,7 +520,7 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons |
|
|
|
schema_tensor->data = data_info.data_; |
|
|
|
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_; |
|
|
|
|
|
|
|
node_id_map_[input_name] = meta_graphT->allTensors.size(); |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); |
|
|
|
meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); |
|
|
|
return RET_OK; |
|
|
|
@@ -556,7 +544,9 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons |
|
|
|
schema_tensor->dataType = data_info.data_type_; |
|
|
|
schema_tensor->dims = data_info.shape_; |
|
|
|
schema_tensor->data = data_info.data_; |
|
|
|
node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size(); |
|
|
|
|
|
|
|
auto key = std::make_pair(cnode->input(index), 0); |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); |
|
|
|
meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); |
|
|
|
return RET_OK; |
|
|
|
@@ -628,18 +618,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s |
|
|
|
} |
|
|
|
ms_tensor->nodeType = NodeType_CNode; |
|
|
|
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); |
|
|
|
auto key = std::make_pair(cnode, i); |
|
|
|
if (train_flag_) { |
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i); |
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size(); |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
meta_graphT->allTensors.emplace_back(ms_tensor); |
|
|
|
} else { |
|
|
|
if (elements.size() == 1) { |
|
|
|
node_id_map_[cnode_name] = meta_graphT->allTensors.size(); |
|
|
|
key = std::make_pair(cnode, 0); |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
ms_tensor->name = cnode_name; |
|
|
|
} else { |
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i); |
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size(); |
|
|
|
ms_tensor->name = name; |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
ms_tensor->name = cnode_name + "_o:" + std::to_string(i); |
|
|
|
} |
|
|
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { |
|
|
|
@@ -673,7 +663,9 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s |
|
|
|
ms_tensor->nodeType = NodeType_CNode; |
|
|
|
ms_tensor->name = cnode_name; |
|
|
|
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); |
|
|
|
node_id_map_[cnode_name] = meta_graphT->allTensors.size(); |
|
|
|
|
|
|
|
auto key = std::make_pair(cnode, 0); |
|
|
|
node_id_map_[key] = meta_graphT->allTensors.size(); |
|
|
|
meta_graphT->allTensors.emplace_back(ms_tensor); |
|
|
|
} |
|
|
|
} |
|
|
|
|