|
|
@@ -24,10 +24,16 @@ |
|
|
#include "tools/common/graph_util.h" |
|
|
#include "tools/common/graph_util.h" |
|
|
#include "tools/common/protobuf_utils.h" |
|
|
#include "tools/common/protobuf_utils.h" |
|
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h" |
|
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h" |
|
|
|
|
|
#include "tools/optimizer/common/gllo_utils.h" |
|
|
|
|
|
|
|
|
namespace mindspore { |
|
|
namespace mindspore { |
|
|
namespace lite { |
|
|
namespace lite { |
|
|
namespace { |
|
|
namespace { |
|
|
|
|
|
static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { |
|
|
|
|
|
schema::PrimitiveType_TensorListFromTensor, |
|
|
|
|
|
schema::PrimitiveType_TensorListSetItem, |
|
|
|
|
|
schema::PrimitiveType_TensorListReserve, |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { |
|
|
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { |
|
|
AnfNodePtr ret = nullptr; |
|
|
AnfNodePtr ret = nullptr; |
|
|
@@ -216,7 +222,6 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value |
|
|
param_value->set_tensor_type(type); |
|
|
param_value->set_tensor_type(type); |
|
|
param_value->set_format(schema::Format::Format_NHWC); |
|
|
param_value->set_format(schema::Format::Format_NHWC); |
|
|
parameter->set_default_param(param_value); |
|
|
parameter->set_default_param(param_value); |
|
|
parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter"); |
|
|
|
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -248,8 +253,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa |
|
|
return status; |
|
|
return status; |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size())); |
|
|
|
|
|
graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names |
|
|
|
|
|
|
|
|
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); |
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); |
|
|
@@ -257,8 +261,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa |
|
|
MS_LOG(ERROR) << "abstract_tensor is nullptr"; |
|
|
MS_LOG(ERROR) << "abstract_tensor is nullptr"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
|
|
|
parameter->set_name(node.name()); |
|
|
parameter->set_abstract(abstract_tensor); |
|
|
parameter->set_abstract(abstract_tensor); |
|
|
|
|
|
|
|
|
|
|
|
(*anf_node_map)[node.name()] = parameter; |
|
|
(*anf_node_map)[node.name() + ":0"] = parameter; |
|
|
(*anf_node_map)[node.name() + ":0"] = parameter; |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
@@ -294,43 +300,48 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
tf_root_graph = std::make_unique<tensorflow::GraphDef>(); |
|
|
|
|
|
if (tf_root_graph == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "tf_root_graph is nullptr"; |
|
|
|
|
|
|
|
|
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); |
|
|
|
|
|
if (tf_root_graph_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get()); |
|
|
|
|
|
|
|
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); |
|
|
if (status != RET_OK) { |
|
|
if (status != RET_OK) { |
|
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; |
|
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
anf_root_graph = std::make_shared<FuncGraph>(); |
|
|
|
|
|
if (anf_root_graph == nullptr) { |
|
|
|
|
|
|
|
|
anf_root_graph_ = std::make_shared<FuncGraph>(); |
|
|
|
|
|
if (anf_root_graph_ == nullptr) { |
|
|
MS_LOG(ERROR) << "funGraphPtr is nullptr"; |
|
|
MS_LOG(ERROR) << "funGraphPtr is nullptr"; |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for (int i = 0; i < tf_root_graph->node_size(); i++) { |
|
|
|
|
|
auto &node_def = tf_root_graph->node(i); |
|
|
|
|
|
tf_root_graph_nodes[node_def.name()] = &node_def; |
|
|
|
|
|
|
|
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) { |
|
|
|
|
|
auto &node_def = tf_root_graph_->node(i); |
|
|
|
|
|
tf_root_graph_nodes_[node_def.name()] = &node_def; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map); |
|
|
|
|
|
|
|
|
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); |
|
|
if (status != RET_OK) { |
|
|
if (status != RET_OK) { |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
for (int i = 0; i < tf_root_graph->node_size(); i++) { |
|
|
|
|
|
auto &node_def = tf_root_graph->node(i); |
|
|
|
|
|
if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) { |
|
|
|
|
|
MS_LOG(ERROR) << "Convert ops failed."; |
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
|
|
return nullptr; |
|
|
|
|
|
|
|
|
bool success_flag = true; |
|
|
|
|
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) { |
|
|
|
|
|
auto &node_def = tf_root_graph_->node(i); |
|
|
|
|
|
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); |
|
|
|
|
|
if (status != RET_OK) { |
|
|
|
|
|
success_flag = false; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if (!success_flag) { |
|
|
|
|
|
MS_LOG(ERROR) << "Convert ops failed."; |
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
|
|
return nullptr; |
|
|
|
|
|
} |
|
|
status = ConvertRootGraphOutputs(); |
|
|
status = ConvertRootGraphOutputs(); |
|
|
if (status != RET_OK) { |
|
|
if (status != RET_OK) { |
|
|
MS_LOG(ERROR) << "Convert graph outputs failed."; |
|
|
MS_LOG(ERROR) << "Convert graph outputs failed."; |
|
|
@@ -345,10 +356,10 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
return anf_root_graph; |
|
|
|
|
|
|
|
|
return anf_root_graph_; |
|
|
} |
|
|
} |
|
|
STATUS TFModelParser::ConvertSubgraph() { |
|
|
STATUS TFModelParser::ConvertSubgraph() { |
|
|
auto graph_def_liarary = tf_root_graph->library(); |
|
|
|
|
|
|
|
|
auto graph_def_liarary = tf_root_graph_->library(); |
|
|
auto subgraph_size = graph_def_liarary.function_size(); |
|
|
auto subgraph_size = graph_def_liarary.function_size(); |
|
|
std::map<CNodePtr, FuncGraphPtr> while_cond_map; |
|
|
std::map<CNodePtr, FuncGraphPtr> while_cond_map; |
|
|
std::map<CNodePtr, FuncGraphPtr> while_body_map; |
|
|
std::map<CNodePtr, FuncGraphPtr> while_body_map; |
|
|
@@ -359,11 +370,11 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
auto input_arg_size = tf_sub_signature.input_arg_size(); |
|
|
auto input_arg_size = tf_sub_signature.input_arg_size(); |
|
|
|
|
|
|
|
|
auto &sub_graph_name = tf_sub_signature.name(); |
|
|
auto &sub_graph_name = tf_sub_signature.name(); |
|
|
if (!function_while_map.count(sub_graph_name)) { |
|
|
|
|
|
|
|
|
if (!function_while_map_.count(sub_graph_name)) { |
|
|
MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; |
|
|
MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
auto while_cnode = function_while_map[sub_graph_name]->cast<CNodePtr>(); |
|
|
|
|
|
|
|
|
auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>(); |
|
|
if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) { |
|
|
if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) { |
|
|
MS_LOG(ERROR) << "while cnode not equal input arg size"; |
|
|
MS_LOG(ERROR) << "while cnode not equal input arg size"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
@@ -441,7 +452,7 @@ STATUS TFModelParser::ConvertSubgraph() { |
|
|
} |
|
|
} |
|
|
// hardcode subgraph inputs name |
|
|
// hardcode subgraph inputs name |
|
|
for (size_t j = 0; j < sub_graph_inputs.size(); j++) { |
|
|
for (size_t j = 0; j < sub_graph_inputs.size(); j++) { |
|
|
sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter"); |
|
|
|
|
|
|
|
|
sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); |
|
|
} |
|
|
} |
|
|
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; |
|
|
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; |
|
|
} |
|
|
} |
|
|
@@ -458,9 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr |
|
|
MS_LOG(ERROR) << "while cond body size error"; |
|
|
MS_LOG(ERROR) << "while cond body size error"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
std::vector<FuncGraphPtr> roots = {anf_root_graph}; |
|
|
|
|
|
|
|
|
std::vector<FuncGraphPtr> roots = {anf_root_graph_}; |
|
|
auto root_func_manager = std::make_shared<FuncGraphManager>(roots); |
|
|
auto root_func_manager = std::make_shared<FuncGraphManager>(roots); |
|
|
anf_root_graph->set_manager(root_func_manager); |
|
|
|
|
|
|
|
|
anf_root_graph_->set_manager(root_func_manager); |
|
|
for (auto &kv : while_cond_map) { |
|
|
for (auto &kv : while_cond_map) { |
|
|
auto while_node = kv.first; |
|
|
auto while_node = kv.first; |
|
|
auto &cond_sub_graph = kv.second; |
|
|
auto &cond_sub_graph = kv.second; |
|
|
@@ -513,10 +524,20 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C |
|
|
MS_ASSERT(op != nullptr); |
|
|
MS_ASSERT(op != nullptr); |
|
|
MS_ASSERT(anf_node != nullptr); |
|
|
MS_ASSERT(anf_node != nullptr); |
|
|
MS_ASSERT(anf_graph != nullptr); |
|
|
MS_ASSERT(anf_graph != nullptr); |
|
|
if (output_size == 1) { |
|
|
|
|
|
|
|
|
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { |
|
|
|
|
|
MS_LOG(ERROR) << "tensorlist output op output_size !=1"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (output_size == 0) { |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} else if (output_size == 1) { |
|
|
|
|
|
auto type = kFloat32; |
|
|
std::vector<int64_t> shape_vector; |
|
|
std::vector<int64_t> shape_vector; |
|
|
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); |
|
|
|
|
|
anf_node_map->insert(std::pair(op.name() + ":0", anf_node)); |
|
|
|
|
|
|
|
|
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { |
|
|
|
|
|
type = TypeIdToType(kObjectTypeTensorType); |
|
|
|
|
|
} |
|
|
|
|
|
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector)); |
|
|
|
|
|
anf_node_map->insert(std::pair(op.name(), anf_node)); |
|
|
} else { |
|
|
} else { |
|
|
AbstractBasePtrList abstractList; |
|
|
AbstractBasePtrList abstractList; |
|
|
for (int output_idx = 0; output_idx < output_size; output_idx++) { |
|
|
for (int output_idx = 0; output_idx < output_size; output_idx++) { |
|
|
@@ -585,12 +606,12 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
tensorflow::AttrValue attr_value; |
|
|
tensorflow::AttrValue attr_value; |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { |
|
|
auto body_name = attr_value.func().name(); |
|
|
auto body_name = attr_value.func().name(); |
|
|
function_while_map[body_name] = anf_node; |
|
|
|
|
|
|
|
|
function_while_map_[body_name] = anf_node; |
|
|
MS_LOG(DEBUG) << "parse body name:" << body_name; |
|
|
MS_LOG(DEBUG) << "parse body name:" << body_name; |
|
|
} |
|
|
} |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { |
|
|
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { |
|
|
auto cond_name = attr_value.func().name(); |
|
|
auto cond_name = attr_value.func().name(); |
|
|
function_while_map[cond_name] = anf_node; |
|
|
|
|
|
|
|
|
function_while_map_[cond_name] = anf_node; |
|
|
MS_LOG(DEBUG) << "parse cond name:" << cond_name; |
|
|
MS_LOG(DEBUG) << "parse cond name:" << cond_name; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -606,31 +627,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, |
|
|
|
|
|
|
|
|
STATUS TFModelParser::ConvertRootGraphOutputs() { |
|
|
STATUS TFModelParser::ConvertRootGraphOutputs() { |
|
|
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in |
|
|
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in |
|
|
// tf_root_graph_nodes but not anf_root_node_map |
|
|
|
|
|
|
|
|
// tf_root_graph_nodes_ but not anf_root_node_map_ |
|
|
std::set<std::string> all_node_inputs; |
|
|
std::set<std::string> all_node_inputs; |
|
|
std::vector<AnfNodePtr> output_nodes; |
|
|
std::vector<AnfNodePtr> output_nodes; |
|
|
for (auto &pair : tf_root_graph_nodes) { |
|
|
|
|
|
|
|
|
for (auto &pair : tf_root_graph_nodes_) { |
|
|
for (int i = 0; i < pair.second->input_size(); ++i) { |
|
|
for (int i = 0; i < pair.second->input_size(); ++i) { |
|
|
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); |
|
|
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
for (auto &pair : tf_root_graph_nodes) { |
|
|
|
|
|
|
|
|
for (auto &pair : tf_root_graph_nodes_) { |
|
|
if (pair.second->op() == "Assert") { |
|
|
if (pair.second->op() == "Assert") { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
auto it = all_node_inputs.find(pair.first); |
|
|
auto it = all_node_inputs.find(pair.first); |
|
|
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity |
|
|
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity |
|
|
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes); |
|
|
|
|
|
auto anf_node = GetAnfNode(origin_name, anf_root_node_map); |
|
|
|
|
|
|
|
|
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); |
|
|
|
|
|
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); |
|
|
if (anf_node == nullptr) { |
|
|
if (anf_node == nullptr) { |
|
|
MS_LOG(ERROR) << "can't find anf node"; |
|
|
MS_LOG(ERROR) << "can't find anf node"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
output_nodes.push_back(anf_node); |
|
|
output_nodes.push_back(anf_node); |
|
|
graph_output_names.push_back(anf_node->fullname_with_scope()); |
|
|
|
|
|
|
|
|
graph_output_names_.push_back(anf_node->fullname_with_scope()); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph); |
|
|
|
|
|
|
|
|
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); |
|
|
if (status != RET_OK) { |
|
|
if (status != RET_OK) { |
|
|
MS_LOG(ERROR) << "make anf graph outputs node error"; |
|
|
MS_LOG(ERROR) << "make anf graph outputs node error"; |
|
|
return status; |
|
|
return status; |
|
|
|