Browse Source

split_codedex

pull/15567/head
yefeng 4 years ago
parent
commit
1cc60732f7
6 changed files with 58 additions and 28 deletions
  1. +3
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc
  2. +43
    -24
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  3. +4
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  4. +6
    -2
      mindspore/lite/tools/converter/parser/tf/tf_util.cc
  5. +1
    -1
      mindspore/lite/tools/optimizer/graph/functionalize_cond.h
  6. +1
    -1
      mindspore/lite/tools/optimizer/graph/functionalize_while.h

+ 3
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc View File

@@ -45,6 +45,9 @@ ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const
padding_mode = mindspore::PaddingMode::REFLECT; padding_mode = mindspore::PaddingMode::REFLECT;
} else if (mode == "edge") { } else if (mode == "edge") {
padding_mode = mindspore::PaddingMode::SYMMETRIC; padding_mode = mindspore::PaddingMode::SYMMETRIC;
} else {
MS_LOG(ERROR) << "Unsupported pad mode: " << mode;
return nullptr;
} }
prim->set_padding_mode(padding_mode); prim->set_padding_mode(padding_mode);
} }


+ 43
- 24
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -40,13 +40,14 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) {
opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve); opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve);
} }


AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
int index = 0) {
AnfNodePtr ret = nullptr; AnfNodePtr ret = nullptr;
auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name); auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name);
if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) { if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name); ret = anf_node_map.at(flat_anf_name);
} else if (anf_node_map.find(name + ":0") != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name + ":0");
} else if (anf_node_map.find(name + ":" + to_string(index)) != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name + ":" + to_string(index));
} }
return ret; return ret;
} }
@@ -901,6 +902,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(ERROR) << "node " << op_type << " parser failed"; MS_LOG(ERROR) << "node " << op_type << " parser failed";
return RET_ERROR; return RET_ERROR;
} }
node_output_num_[node_def.name()] = output_size;
for (int i = 0; i < output_size; i++) {
node_output_num_[node_def.name() + ":" + to_string(i)] = 1;
}
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC)); auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC));
if (value_node == nullptr) { if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr"; MS_LOG(ERROR) << "value_node is nullptr";
@@ -915,6 +920,32 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
// control_depends are not processed currently // control_depends are not processed currently
auto anf_node = func_graph_ptr->NewCNode(inputs); auto anf_node = func_graph_ptr->NewCNode(inputs);
anf_node->set_fullname_with_scope(node_def.name()); anf_node->set_fullname_with_scope(node_def.name());
status = ProcessControlFlowOp(anf_node, op_type, node_def);
if (status != RET_OK) {
MS_LOG(ERROR) << "ProcessControlFlowOp failed.";
return RET_ERROR;
}

if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
return status;
}

status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
return status;
}
return status;
}

STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type,
const tensorflow::NodeDef &node_def) {
if (op_type == "StatelessWhile" || op_type == "While") { if (op_type == "StatelessWhile" || op_type == "While") {
MS_LOG(INFO) << "find while node:" << node_def.name(); MS_LOG(INFO) << "find while node:" << node_def.name();
tensorflow::AttrValue attr_value; tensorflow::AttrValue attr_value;
@@ -944,21 +975,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(DEBUG) << "parse else name:" << else_name; MS_LOG(DEBUG) << "parse else name:" << else_name;
} }
} }

if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
}

status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
}
return status;
return RET_OK;
} }


STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size, STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
@@ -994,13 +1011,15 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
auto it = all_node_inputs.find(pair.first); auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node: " << origin_name;
return RET_ERROR;
for (int i = 0; i < node_output_num_[origin_name]; i++) {
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_, i);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node: " << origin_name;
return RET_ERROR;
}
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
} }
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
} }
} }
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);


+ 4
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -60,6 +60,9 @@ class TFModelParser : public ModelParser {
STATUS ConvertOps(const tensorflow::NodeDef &node_def, STATUS ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map); const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);

STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def);

STATUS ConvertRootGraphOutputs(); STATUS ConvertRootGraphOutputs();


STATUS ConvertSubgraph(); STATUS ConvertSubgraph();
@@ -95,6 +98,7 @@ class TFModelParser : public ModelParser {
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{}; std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
std::vector<std::string> while_cond_branch_name_; std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_; std::vector<std::string> if_then_branch_name_;
std::unordered_map<std::string, int> node_output_num_;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


+ 6
- 2
mindspore/lite/tools/converter/parser/tf/tf_util.cc View File

@@ -70,9 +70,13 @@ TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, c
} }


bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) { bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
if (str_view == nullptr || value == nullptr) {
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return false;
}
if (str_view == nullptr) {
*value = 0; *value = 0;
MS_LOG(ERROR) << "str_view or value is nullptr";
MS_LOG(ERROR) << "str_view is nullptr";
return false; return false;
} }
auto data = str_view->data(); auto data = str_view->data();


+ 1
- 1
mindspore/lite/tools/optimizer/graph/functionalize_cond.h View File

@@ -36,7 +36,7 @@ class FunctionalizeCond {
public: public:
FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {} FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {}


~FunctionalizeCond() = default;
virtual ~FunctionalizeCond() = default;


STATUS Process(); STATUS Process();




+ 1
- 1
mindspore/lite/tools/optimizer/graph/functionalize_while.h View File

@@ -35,7 +35,7 @@ class FunctionalizeWhile {
FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg) FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg)
: node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {} : node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {}


~FunctionalizeWhile() = default;
virtual ~FunctionalizeWhile() = default;


// while // while
STATUS BuildWhileNode(); STATUS BuildWhileNode();


Loading…
Cancel
Save