Browse Source

split_codedex

pull/15567/head
yefeng 4 years ago
parent
commit
1cc60732f7
6 changed files with 58 additions and 28 deletions
  1. +3
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc
  2. +43
    -24
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  3. +4
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  4. +6
    -2
      mindspore/lite/tools/converter/parser/tf/tf_util.cc
  5. +1
    -1
      mindspore/lite/tools/optimizer/graph/functionalize_cond.h
  6. +1
    -1
      mindspore/lite/tools/optimizer/graph/functionalize_while.h

+ 3
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_pad_parser.cc View File

@@ -45,6 +45,9 @@ ops::PrimitiveC *OnnxPadParser::Parse(const onnx::GraphProto &onnx_graph, const
padding_mode = mindspore::PaddingMode::REFLECT;
} else if (mode == "edge") {
padding_mode = mindspore::PaddingMode::SYMMETRIC;
} else {
MS_LOG(ERROR) << "Unsupported pad mode: " << mode;
return nullptr;
}
prim->set_padding_mode(padding_mode);
}


+ 43
- 24
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -40,13 +40,14 @@ bool IsTensorListOp(const AnfNodePtr &anf_node) {
opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve);
}

AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
int index = 0) {
AnfNodePtr ret = nullptr;
auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name);
if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name);
} else if (anf_node_map.find(name + ":0") != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name + ":0");
} else if (anf_node_map.find(name + ":" + to_string(index)) != anf_node_map.end()) {
ret = anf_node_map.at(flat_anf_name + ":" + to_string(index));
}
return ret;
}
@@ -901,6 +902,10 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(ERROR) << "node " << op_type << " parser failed";
return RET_ERROR;
}
node_output_num_[node_def.name()] = output_size;
for (int i = 0; i < output_size; i++) {
node_output_num_[node_def.name() + ":" + to_string(i)] = 1;
}
auto value_node = NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitiveC));
if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr";
@@ -915,6 +920,32 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
// control_depends are not processed currently
auto anf_node = func_graph_ptr->NewCNode(inputs);
anf_node->set_fullname_with_scope(node_def.name());
status = ProcessControlFlowOp(anf_node, op_type, node_def);
if (status != RET_OK) {
MS_LOG(ERROR) << "ProcessControlFlowOp failed.";
return RET_ERROR;
}

if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
return status;
}

status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
return status;
}
return status;
}

STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type,
const tensorflow::NodeDef &node_def) {
if (op_type == "StatelessWhile" || op_type == "While") {
MS_LOG(INFO) << "find while node:" << node_def.name();
tensorflow::AttrValue attr_value;
@@ -944,21 +975,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
MS_LOG(DEBUG) << "parse else name:" << else_name;
}
}

if (!input_name_not_found.empty()) {
RecordNullInput(anf_node, input_name_not_found);
}

status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
}

status = ConvertQuantParams(inputs.size() - 1, output_size, primitiveC);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
}
return status;
return RET_OK;
}

STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
@@ -994,13 +1011,15 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node: " << origin_name;
return RET_ERROR;
for (int i = 0; i < node_output_num_[origin_name]; i++) {
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_, i);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node: " << origin_name;
return RET_ERROR;
}
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
}
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
}
}
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);


+ 4
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -60,6 +60,9 @@ class TFModelParser : public ModelParser {
STATUS ConvertOps(const tensorflow::NodeDef &node_def,
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);

STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def);

STATUS ConvertRootGraphOutputs();

STATUS ConvertSubgraph();
@@ -95,6 +98,7 @@ class TFModelParser : public ModelParser {
std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
std::vector<std::string> while_cond_branch_name_;
std::vector<std::string> if_then_branch_name_;
std::unordered_map<std::string, int> node_output_num_;
};
} // namespace lite
} // namespace mindspore


+ 6
- 2
mindspore/lite/tools/converter/parser/tf/tf_util.cc View File

@@ -70,9 +70,13 @@ TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, c
}

bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
if (str_view == nullptr || value == nullptr) {
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return false;
}
if (str_view == nullptr) {
*value = 0;
MS_LOG(ERROR) << "str_view or value is nullptr";
MS_LOG(ERROR) << "str_view is nullptr";
return false;
}
auto data = str_view->data();


+ 1
- 1
mindspore/lite/tools/optimizer/graph/functionalize_cond.h View File

@@ -36,7 +36,7 @@ class FunctionalizeCond {
public:
FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {}

~FunctionalizeCond() = default;
virtual ~FunctionalizeCond() = default;

STATUS Process();



+ 1
- 1
mindspore/lite/tools/optimizer/graph/functionalize_while.h View File

@@ -35,7 +35,7 @@ class FunctionalizeWhile {
FunctionalizeWhile(std::vector<AnfNodePtr> node_cluster, const CNodePtr &loop_cond_node, FuncGraphPtr fg)
: node_cluster_(node_cluster), loop_cond_node_(loop_cond_node), fg_(fg) {}

~FunctionalizeWhile() = default;
virtual ~FunctionalizeWhile() = default;

// while
STATUS BuildWhileNode();


Loading…
Cancel
Save