Browse Source

Throw error when load config failed

tags/v0.5.0-beta
ms_yan 5 years ago
parent
commit
8d1dae46ac
4 changed files with 4 additions and 15 deletions
  1. +1
    -1
      mindspore/ccsrc/dataset/api/python_bindings.cc
  2. +3
    -6
      mindspore/ccsrc/dataset/core/config_manager.cc
  3. +0
    -3
      tests/ut/data/dataset/declient_filter.cfg
  4. +0
    -5
      tests/ut/python/dataset/test_filterop.py

+ 1
- 1
mindspore/ccsrc/dataset/api/python_bindings.cc View File

@@ -276,7 +276,7 @@ void bindTensor(py::module *m) {
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_seed", &ConfigManager::seed)
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });

(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
.def(py::init([](py::array arr) {


+ 3
- 6
mindspore/ccsrc/dataset/core/config_manager.cc View File

@@ -48,7 +48,7 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
Status ConfigManager::LoadFile(const std::string &settingsFile) {
Status rc;
if (!Path(settingsFile).Exists()) {
RETURN_STATUS_UNEXPECTED("File is not found");
RETURN_STATUS_UNEXPECTED("File is not found.");
}
// Some settings are mandatory, others are not (with default). If a setting
// is optional it will set a default value if the config is missing from the file.
@@ -59,14 +59,11 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) {
rc = FromJson(js);
} catch (const nlohmann::json::type_error &e) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << e.what();
ss << "Client file failed to load:\n" << e.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
} catch (const std::exception &err) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << err.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
RETURN_STATUS_UNEXPECTED("Client file failed to load.");
}
return rc;
}


+ 0
- 3
tests/ut/data/dataset/declient_filter.cfg View File

@@ -1,3 +0,0 @@
{
"rowsPerBuffer": 10,
}

+ 0
- 5
tests/ut/python/dataset/test_filterop.py View File

@@ -390,7 +390,6 @@ def filter_func_Partial_0(col1, col2, col3, col4):

# test with row_data_buffer > 1
def test_filter_by_generator_Partial0():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
@@ -404,7 +403,6 @@ def test_filter_by_generator_Partial0():

# test with row_data_buffer > 1
def test_filter_by_generator_Partial1():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
@@ -419,7 +417,6 @@ def test_filter_by_generator_Partial1():

# test with row_data_buffer > 1
def test_filter_by_generator_Partial2():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])

@@ -454,7 +451,6 @@ def generator_big(maxid=20):

# test with row_data_buffer > 1
def test_filter_by_generator_Partial():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"])
dataset_s = dataset.shuffle(4)
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
@@ -473,7 +469,6 @@ def filter_func_cifar(col1, col2):
# test with cifar10
def test_filte_case_dataset_cifar10():
DATA_DIR_10 = "../data/dataset/testCifar10Data"
ds.config.load('../data/dataset/declient_filter.cfg')
dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
for item in dataset_f1.create_dict_iterator():


Loading…
Cancel
Save