Browse Source

enhance: verify path and return value

tags/v1.6.0
jonyguo 4 years ago
parent
commit
ffa2a9ee85
4 changed files with 47 additions and 84 deletions
  1. +7
    -18
      mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc
  2. +9
    -31
      mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc
  3. +20
    -23
      mindspore/core/utils/file_utils.cc
  4. +11
    -12
      tests/st/dataset/test_chinese_path_on_windows.py

+ 7
- 18
mindspore/ccsrc/minddata/dataset/text/kernels/sentence_piece_tokenizer_op.cc View File

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>

#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/dataset/util/path.h"

@@ -86,26 +87,14 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s
}

Status SentencePieceTokenizerOp::GetModelRealPath(const std::string &model_path, const std::string &filename) {
char real_path[PATH_MAX] = {0};
if (file_path_.size() >= PATH_MAX) {
auto realpath = FileUtils::GetRealPath(model_path.data());
if (!realpath.has_value()) {
RETURN_STATUS_UNEXPECTED(
"SentencePieceTokenizer: Sentence piece model path is invalid for path length longer than 4096.");
"SentencePieceTokenizer: Sentence piece model path is not existed or permission denied. Model path: " +
model_path);
}
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(real_path, common::SafeCStr(model_path), PATH_MAX) == nullptr) {
RETURN_STATUS_UNEXPECTED(
"SentencePieceTokenizer: Sentence piece model path is invalid for path length longer than 4096.");
}
#else
if (realpath(common::SafeCStr(model_path), real_path) == nullptr) {
RETURN_STATUS_UNEXPECTED(
"SentencePieceTokenizer: "
"Sentence piece model path: " +
model_path + " is not existed or permission denied.");
}
#endif
std::string abs_path = real_path;
file_path_ = (Path(abs_path) / Path(filename)).ToString();

file_path_ = (Path(realpath.value()) / Path(filename)).ToString();
return Status::OK();
}



+ 9
- 31
mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc View File

@@ -60,38 +60,16 @@ bool ValidateFieldName(const std::string &str) {

Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr) {
RETURN_UNEXPECTED_IF_NULL(fn_ptr);
char real_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to call securec func [strncpy_s], path: " + path);
}
char tmp[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
std::string(buf));
}
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
}
#else
if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
std::string(buf));
}
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
}
#endif
std::string s = real_path;
size_t i = s.rfind(kPathSeparator, s.length());
if (i != std::string::npos) {
if (i + 1 < s.size()) {
*fn_ptr = std::make_shared<std::string>(s.substr(i + 1));
return Status::OK();
}

std::optional<std::string> prefix_path;
std::optional<std::string> file_name;
FileUtils::SplitDirAndFileName(path, &prefix_path, &file_name);
if (!file_name.has_value()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the filename of mindrecord file. Please check file path: " +
path);
}
*fn_ptr = std::make_shared<std::string>(s);
*fn_ptr = std::make_shared<std::string>(file_name.value());

return Status::OK();
}



+ 20
- 23
mindspore/core/utils/file_utils.cc View File

@@ -163,12 +163,14 @@ std::string FileUtils::UTF_8ToGB2312(const char *text) {
out = text;
return out;
}
char buf[4];
char buf[4] = {0};
int len = strlen(text);
char *new_text = const_cast<char *>(text);
char *rst = new char[len + (len >> 2) + 2];
memset_s(buf, 4, 0, 4);
memset_s(rst, len + (len >> 2) + 2, 0, len + (len >> 2) + 2);
std::unique_ptr<char[]> rst(new char[len + (len >> 2) + 2]);
auto ret2 = memset_s(rst.get(), len + (len >> 2) + 2, 0, len + (len >> 2) + 2);
if (ret2 != 0) {
return "";
}

int i = 0;
int j = 0;
@@ -191,9 +193,7 @@ std::string FileUtils::UTF_8ToGB2312(const char *text) {
}

rst[j] = '\0';
out = rst;
delete[] rst;
rst = nullptr;
out = rst.get();
return out;
}

@@ -208,24 +208,21 @@ std::string FileUtils::GB2312ToUTF_8(const char *gb2312) {
}

int len = MultiByteToWideChar(CP_ACP, 0, gb2312, -1, nullptr, 0);
wchar_t *wstr = new wchar_t[len + 1];
memset_s(wstr, len + 1, 0, len + 1);
MultiByteToWideChar(CP_ACP, 0, gb2312, -1, wstr, len);
len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr);

char *str = new char[len + 1];
memset_s(str, len + 1, 0, len + 1);
WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, nullptr, nullptr);
std::string str_temp(str);

if (wstr != nullptr) {
delete[] wstr;
wstr = nullptr;
std::unique_ptr<wchar_t[]> wstr(new wchar_t[len + 1]);
auto ret = memset_s(wstr.get(), len + 1, 0, len + 1);
if (ret != 0) {
return "";
}
if (str != nullptr) {
delete[] str;
str = nullptr;
MultiByteToWideChar(CP_ACP, 0, gb2312, -1, wstr.get(), len);
len = WideCharToMultiByte(CP_UTF8, 0, wstr.get(), -1, nullptr, 0, nullptr, nullptr);

std::unique_ptr<char[]> str(new char[len + 1]);
auto ret2 = memset_s(str.get(), len + 1, 0, len + 1);
if (ret2 != 0) {
return "";
}
WideCharToMultiByte(CP_UTF8, 0, wstr.get(), -1, str.get(), len, nullptr, nullptr);
std::string str_temp(str.get());

return str_temp;
}


+ 11
- 12
tests/st/dataset/test_chinese_path_on_windows.py View File

@@ -25,12 +25,6 @@ import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore.mindrecord import FileWriter, SUCCESS

FILES_NUM = 4
CV_MINDRECORD_FILE = "../data/test.mindrecord"
CV_DIR_NAME_CN = "../data/数据集/train/"
FILE_NAME = "test.mindrecord"
FILE_NAME2 = "./训练集/test.mindrecord"

def add_and_remove_cv_file(mindrecord):
"""add/remove cv file"""
try:
@@ -92,6 +86,11 @@ def test_chinese_path_on_windows():
Description: None
Expectation: raise exception
"""
mindrecord_file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
cv_mindrecord_file = "../data/" + mindrecord_file_name
cv_dir_name_cn = "../data/数据集/train/"
file_name = mindrecord_file_name
file_name2 = "./训练集/" + mindrecord_file_name

if platform.system().lower() != "windows":
pass
@@ -100,7 +99,7 @@ def test_chinese_path_on_windows():

# current dir in english, mindrecord path in english
dir_path = "./"
mindrecord_path = CV_MINDRECORD_FILE
mindrecord_path = cv_mindrecord_file

add_and_remove_cv_file(dir_path + mindrecord_path)

@@ -112,7 +111,7 @@ def test_chinese_path_on_windows():

# current dir in english, mindrecord path in chinese
dir_path = "./"
mindrecord_path = CV_DIR_NAME_CN + "/" + FILE_NAME
mindrecord_path = cv_dir_name_cn + "/" + file_name

add_and_remove_cv_file(dir_path + mindrecord_path)

@@ -123,8 +122,8 @@ def test_chinese_path_on_windows():
add_and_remove_cv_file(dir_path + mindrecord_path)

# current dir in chinese, mindrecord path in english
dir_path = CV_DIR_NAME_CN
mindrecord_path = FILE_NAME
dir_path = cv_dir_name_cn
mindrecord_path = file_name

add_and_remove_cv_file(dir_path + mindrecord_path)

@@ -135,8 +134,8 @@ def test_chinese_path_on_windows():
add_and_remove_cv_file(dir_path + mindrecord_path)

# current dir in chinese, mindrecord path in chinese
dir_path = CV_DIR_NAME_CN
mindrecord_path = FILE_NAME2
dir_path = cv_dir_name_cn
mindrecord_path = file_name2

add_and_remove_cv_file(dir_path + mindrecord_path)



Loading…
Cancel
Save