|
|
|
@@ -40,13 +40,14 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) { |
|
|
|
opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { |
|
|
|
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, |
|
|
|
int index = 0) { |
|
|
|
AnfNodePtr ret = nullptr; |
|
|
|
auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); |
|
|
|
if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { |
|
|
|
ret = anf_node_map.at(flat_anf_name); |
|
|
|
} else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { |
|
|
|
ret = anf_node_map.at(flat_anf_name + ":0"); |
|
|
|
} else if (anf_node_map.find(name + ":" + to_string(index)) != anf_node_map.end()) { |
|
|
|
ret = anf_node_map.at(flat_anf_name + ":" + to_string(index)); |
|
|
|
} |
|
|
|
return ret; |
|
|
|
} |
|
|
|
@@ -901,6 +902,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
MS_LOG(ERROR) << "node " << op_type << " parser failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
node_output_num_[node_def.name()] = output_size; |
|
|
|
for (int i = 0; i < output_size; i++) { |
|
|
|
node_output_num_[node_def.name() + ":" + to_string(i)] = 1; |
|
|
|
} |
|
|
|
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC)); |
|
|
|
if (value_node == nullptr) { |
|
|
|
MS_LOG(ERROR) << "value_node is nullptr"; |
|
|
|
@@ -915,6 +920,32 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
// control_depends are not processed currently |
|
|
|
auto anf_node = func_graph_ptr->NewCNode(inputs); |
|
|
|
anf_node->set_fullname_with_scope(node_def.name()); |
|
|
|
status = ProcessControlFlowOp(anf_node, op_type, node_def); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ProcessControlFlowOp failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
if (!input_name_not_found.empty()) { |
|
|
|
RecordNullInput(anf_node, input_name_not_found); |
|
|
|
} |
|
|
|
|
|
|
|
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
return status; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type, |
|
|
|
const tensorflow::NodeDef &node_def) { |
|
|
|
if (op_type == "StatelessWhile" || op_type == "While") { |
|
|
|
MS_LOG(INFO) << "find while node:" << node_def.name(); |
|
|
|
tensorflow::AttrValue attr_value; |
|
|
|
@@ -944,21 +975,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
MS_LOG(DEBUG) << "parse else name:" << else_name; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!input_name_not_found.empty()) { |
|
|
|
RecordNullInput(anf_node, input_name_not_found); |
|
|
|
} |
|
|
|
|
|
|
|
status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; |
|
|
|
} |
|
|
|
|
|
|
|
status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; |
|
|
|
} |
|
|
|
return status; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size, |
|
|
|
@@ -994,13 +1011,15 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { |
|
|
|
auto it = all_node_inputs.find(pair.first); |
|
|
|
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity |
|
|
|
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); |
|
|
|
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); |
|
|
|
if (anf_node == nullptr) { |
|
|
|
MS_LOG(ERROR) << "can't find anf node: " << origin_name; |
|
|
|
return RET_ERROR; |
|
|
|
for (int i = 0; i < node_output_num_[origin_name]; i++) { |
|
|
|
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_, i); |
|
|
|
if (anf_node == nullptr) { |
|
|
|
MS_LOG(ERROR) << "can't find anf node: " << origin_name; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
output_nodes.push_back(anf_node); |
|
|
|
graph_output_names_.push_back(anf_node->fullname_with_scope()); |
|
|
|
} |
|
|
|
output_nodes.push_back(anf_node); |
|
|
|
graph_output_names_.push_back(anf_node->fullname_with_scope()); |
|
|
|
} |
|
|
|
} |
|
|
|
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); |
|
|
|
|