Browse Source

!15711 modify clean code for r1.2

From: @changzherui
Reviewed-by: @zhoufeng54
Signed-off-by:
pull/15711/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
0fa0d89c7b
3 changed files with 17 additions and 10 deletions
  1. +13
    -8
      mindspore/core/load_mindir/load_model.cc
  2. +2
    -0
      mindspore/core/load_mindir/load_model.h
  3. +2
    -2
      tests/ut/python/runtest.sh

+ 13
- 8
mindspore/core/load_mindir/load_model.cc View File

@@ -76,10 +76,14 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
return buf;
}

bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
bool get_all_files(const std::string &dir_in, std::vector<std::string> *files, int max_dep) {
if (dir_in.empty()) {
return false;
}
max_dep--;
if (max_dep < 0) {
MS_LOG(EXCEPTION) << "The file is greater than " << max_dep << ", exit the program.";
}
struct stat s;
int ret = stat(dir_in.c_str(), &s);
if (ret != 0) {
@@ -104,7 +108,7 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
return false;
}
if (S_ISDIR(st.st_mode)) {
ret = get_all_files(name, files);
ret = get_all_files(name, files, max_dep);
if (!ret) {
MS_LOG(ERROR) << "Get files failed, ret is : " << ret;
return false;
@@ -118,8 +122,6 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
return true;
}

int endsWith(string s, string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }

std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
char abs_path_buff[PATH_MAX];
@@ -143,17 +145,20 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
return nullptr;
}

char tag_str[] = "_graph.mindir";
std::string str(abs_path_buff + strlen(abs_path_buff) - strlen(tag_str), abs_path_buff + strlen(abs_path_buff));
// Load parameter into graph
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
if (str == tag_str && origin_model.graph().parameter_size() == 0) {
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
memcpy_s(abs_path, sizeof(abs_path), abs_path_buff, path_len);
abs_path[path_len] = '\0';
snprintf(abs_path + path_len, sizeof(abs_path), "variables");
char var[] = "variables";
memcpy_s(abs_path + path_len, PATH_MAX, var, strlen(var));
abs_path[path_len + strlen(var)] = '\0';
std::ifstream ifs(abs_path);
if (ifs.good()) {
MS_LOG(DEBUG) << "MindIR file has variables path, load parameter into graph.";
string path = abs_path;
get_all_files(path, &files);
get_all_files(path, &files, MAX_FILE_DEPTH_RECURSION);
} else {
MS_LOG(ERROR) << "Load graph's variable folder failed, please check the correctness of variable folder.";
return nullptr;


+ 2
- 0
mindspore/core/load_mindir/load_model.h View File

@@ -23,6 +23,8 @@
#include "proto/mind_ir.pb.h"
#include "ir/func_graph.h"

const int MAX_FILE_DEPTH_RECURSION = 1000;

namespace mindspore {
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);


+ 2
- 2
tests/ut/python/runtest.sh View File

@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
CURRPATH=$(cd "$(dirname $0)"; pwd)
CURRPATH=$(cd "$(dirname $0)" || exit; pwd)
IGNORE_EXEC="--ignore=$CURRPATH/exec"
PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd)
PROJECT_PATH=$(cd ${CURRPATH}/../../.. || exit; pwd)

if [ $BUILD_PATH ];then
echo "BUILD_PATH = $BUILD_PATH"


Loading…
Cancel
Save