Browse Source

modify codeDEX_C for Master

pull/15152/head
lilei 4 years ago
parent
commit
c35466ffee
2 changed files with 32 additions and 7 deletions
  1. +11
    -2
      mindspore/core/load_mindir/anf_model_parser.cc
  2. +21
    -5
      mindspore/core/load_mindir/load_model.cc

+ 11
- 2
mindspore/core/load_mindir/anf_model_parser.cc View File

@@ -15,6 +15,7 @@
*/ */


#include "load_mindir/anf_model_parser.h" #include "load_mindir/anf_model_parser.h"
#include <limits.h>
#include <functional> #include <functional>
#include <map> #include <map>
#include <memory> #include <memory>
@@ -799,6 +800,10 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) { const mind_ir::GraphProto &importProto, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(cnode_ptr);
if (importProto.output_size() < 0 || importProto.output_size() > INT_MAX) {
MS_LOG(ERROR) << "importProto.output_size is : " << importProto.output_size();
return false;
}
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
if (importProto.output_size() > 1) { if (importProto.output_size() > 1) {
inputs.clear(); inputs.clear();
@@ -836,6 +841,10 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) { const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_EXCEPTION_IF_NULL(outputFuncGraph);
if (importProto.node_size() < 0 || importProto.node_size() > INT_MAX) {
MS_LOG(ERROR) << "importProto.node_size is : " << importProto.node_size();
return false;
}
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr; CNodePtr cnode_ptr = nullptr;
for (int i = 0; i < importProto.node_size(); ++i) { for (int i = 0; i < importProto.node_size(); ++i) {
@@ -843,14 +852,14 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const std::string &node_type = node_proto.op_type(); const std::string &node_type = node_proto.op_type();
if (node_type == kConstantValueNode) { if (node_type == kConstantValueNode) {
if (!BuildValueNodeForFuncGraph(node_proto)) { if (!BuildValueNodeForFuncGraph(node_proto)) {
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: " << i;
return false; return false;
} }
continue; continue;
} }
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto);
if (cnode_ptr == nullptr) { if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: " << i;
return false; return false;
} }
} }


+ 21
- 5
mindspore/core/load_mindir/load_model.cc View File

@@ -81,22 +81,34 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
return false; return false;
} }
struct stat s; struct stat s;
stat(dir_in.c_str(), &s);
int ret = stat(dir_in.c_str(), &s);
if (ret != 0) {
MS_LOG(ERROR) << "stat error, ret is : " << ret;
return false;
}
if (!S_ISDIR(s.st_mode)) { if (!S_ISDIR(s.st_mode)) {
return false; return false;
} }
DIR *open_dir = opendir(dir_in.c_str()); DIR *open_dir = opendir(dir_in.c_str());
if (NULL == open_dir) { if (NULL == open_dir) {
std::exit(EXIT_FAILURE);
MS_LOG(EXCEPTION) << "open dir " << dir_in.c_str() << " failed";
} }
dirent *p = nullptr; dirent *p = nullptr;
while ((p = readdir(open_dir)) != nullptr) { while ((p = readdir(open_dir)) != nullptr) {
struct stat st; struct stat st;
if (p->d_name[0] != '.') { if (p->d_name[0] != '.') {
std::string name = dir_in + std::string("/") + std::string(p->d_name); std::string name = dir_in + std::string("/") + std::string(p->d_name);
stat(name.c_str(), &st);
ret = stat(name.c_str(), &st);
if (ret != 0) {
MS_LOG(ERROR) << "stat error, ret is : " << ret;
return false;
}
if (S_ISDIR(st.st_mode)) { if (S_ISDIR(st.st_mode)) {
get_all_files(name, files);
ret = get_all_files(name, files);
if (!ret) {
MS_LOG(ERROR) << "Get files failed, ret is : " << ret;
return false;
}
} else if (S_ISREG(st.st_mode)) { } else if (S_ISREG(st.st_mode)) {
files->push_back(name); files->push_back(name);
} }
@@ -134,7 +146,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
// Load parameter into graph // Load parameter into graph
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) { if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
int path_len = strlen(abs_path_buff) - strlen("graph.mindir"); int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
memcpy(abs_path, abs_path_buff, path_len);
memcpy_s(abs_path, sizeof(abs_path), abs_path_buff, path_len);
abs_path[path_len] = '\0'; abs_path[path_len] = '\0';
snprintf(abs_path + path_len, sizeof(abs_path), "variables"); snprintf(abs_path + path_len, sizeof(abs_path), "variables");
std::ifstream ifs(abs_path); std::ifstream ifs(abs_path);
@@ -157,6 +169,10 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
return nullptr; return nullptr;
} }


if (param_graph.parameter_size() < 0 || param_graph.parameter_size() > INT_MAX) {
MS_LOG(ERROR) << "param_graph.parameter_size() is : " << param_graph.parameter_size();
return nullptr;
}
for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) { for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
mind_ir::TensorProto *param_proto = mod_graph->add_parameter(); mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
param_proto->set_name(param_graph.parameter(param_index).name()); param_proto->set_name(param_graph.parameter(param_index).name());


Loading…
Cancel
Save