| @@ -15,10 +15,10 @@ | |||||
| """test_occlusion_sensitivity""" | """test_occlusion_sensitivity""" | ||||
| import pytest | import pytest | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import nn | |||||
| from mindspore import nn, context | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.nn.metrics import OcclusionSensitivity | from mindspore.nn.metrics import OcclusionSensitivity | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class DenseNet(nn.Cell): | class DenseNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -96,6 +96,7 @@ def test_on_momentum(): | |||||
| def test_data_parallel_with_cast(): | def test_data_parallel_with_cast(): | ||||
| """test_data_parallel_with_cast""" | """test_data_parallel_with_cast""" | ||||
| context.set_context(device_target='Ascend') | |||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8) | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8) | ||||
| predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | ||||
| @@ -281,6 +281,7 @@ def test_same_primal_used_by_multi_j(): | |||||
| def test_same_primal_used_by_multi_j_with_monad1(): | def test_same_primal_used_by_multi_j_with_monad1(): | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class AdamNet(nn.Cell): | class AdamNet(nn.Cell): | ||||
| def __init__(self, var, m, v): | def __init__(self, var, m, v): | ||||
| super(AdamNet, self).__init__() | super(AdamNet, self).__init__() | ||||
| @@ -322,6 +323,7 @@ def test_same_primal_used_by_multi_j_with_monad1(): | |||||
| def test_same_primal_used_by_multi_j_with_monad2(): | def test_same_primal_used_by_multi_j_with_monad2(): | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class AdamNet(nn.Cell): | class AdamNet(nn.Cell): | ||||
| def __init__(self, var, m, v): | def __init__(self, var, m, v): | ||||
| super(AdamNet, self).__init__() | super(AdamNet, self).__init__() | ||||
| @@ -194,6 +194,7 @@ def test_compile_f16_model_train(): | |||||
| def test_compile_f16_model_train_fixed(): | def test_compile_f16_model_train_fixed(): | ||||
| context.set_context(device_target='Ascend') | |||||
| dataset_types = (np.float32, np.float32) | dataset_types = (np.float32, np.float32) | ||||
| dataset_shapes = ((16, 16), (16, 16)) | dataset_shapes = ((16, 16), (16, 16)) | ||||
| @@ -54,6 +54,7 @@ def test_log_setlevel(): | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| # logger_instance = logger._get_logger() | # logger_instance = logger._get_logger() | ||||
| # del logger_instance | # del logger_instance | ||||
| _clear_logger(logger) | |||||
| loglevel = logger.get_level() | loglevel = logger.get_level() | ||||
| log_str = 'print debug informations' | log_str = 'print debug informations' | ||||
| logger.debug("5 test log message debug:%s", log_str) | logger.debug("5 test log message debug:%s", log_str) | ||||
| @@ -147,6 +147,7 @@ def test_compile_model_train_O2(): | |||||
| def test_compile_model_train_O2_parallel(): | def test_compile_model_train_O2_parallel(): | ||||
| dataset_types = (np.float32, np.float32) | dataset_types = (np.float32, np.float32) | ||||
| dataset_shapes = ((16, 16), (16, 16)) | dataset_shapes = ((16, 16), (16, 16)) | ||||
| context.set_context(device_target='Ascend') | |||||
| context.set_auto_parallel_context( | context.set_auto_parallel_context( | ||||
| global_rank=0, device_num=8, | global_rank=0, device_num=8, | ||||
| gradients_mean=True, parameter_broadcast=True, | gradients_mean=True, parameter_broadcast=True, | ||||
| @@ -89,10 +89,10 @@ def test_dataset_iter_ge(): | |||||
| @pytest.mark.skipif('context.get_context("enable_ge")') | @pytest.mark.skipif('context.get_context("enable_ge")') | ||||
| def test_dataset_iter_ms_loop_sink(): | def test_dataset_iter_ms_loop_sink(): | ||||
| context.set_context(device_target='Ascend', mode=context.GRAPH_MODE) | |||||
| GlobalComm.CHECK_ENVS = False | GlobalComm.CHECK_ENVS = False | ||||
| init("hccl") | init("hccl") | ||||
| GlobalComm.CHECK_ENVS = True | GlobalComm.CHECK_ENVS = True | ||||
| context.set_context(enable_loop_sink=True) | |||||
| dataset = get_dataset(32) | dataset = get_dataset(32) | ||||
| dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | ||||
| count = 0 | count = 0 | ||||
| @@ -105,9 +105,9 @@ def test_dataset_iter_ms_loop_sink(): | |||||
| @pytest.mark.skipif('context.get_context("enable_ge")') | @pytest.mark.skipif('context.get_context("enable_ge")') | ||||
| def test_dataset_iter_ms(): | def test_dataset_iter_ms(): | ||||
| context.set_context(device_target='Ascend', mode=context.GRAPH_MODE) | |||||
| GlobalComm.CHECK_ENVS = False | GlobalComm.CHECK_ENVS = False | ||||
| init("hccl") | init("hccl") | ||||
| GlobalComm.CHECK_ENVS = True | GlobalComm.CHECK_ENVS = True | ||||
| context.set_context(enable_loop_sink=False) | |||||
| dataset = get_dataset(32) | dataset = get_dataset(32) | ||||
| DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10) | ||||
| @@ -112,6 +112,7 @@ def test_multiple_argument(): | |||||
| def test_train_feed_mode(test_with_simu): | def test_train_feed_mode(test_with_simu): | ||||
| """ test_train_feed_mode """ | """ test_train_feed_mode """ | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| dataset = get_dataset() | dataset = get_dataset() | ||||
| model = get_model() | model = get_model() | ||||
| if test_with_simu: | if test_with_simu: | ||||
| @@ -162,6 +163,7 @@ class TestGraphMode: | |||||
| def test_train_minddata_graph_mode(self, test_with_simu): | def test_train_minddata_graph_mode(self, test_with_simu): | ||||
| """ test_train_minddata_graph_mode """ | """ test_train_minddata_graph_mode """ | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| # pylint: disable=unused-argument | # pylint: disable=unused-argument | ||||
| dataset_types = (np.float32, np.float32) | dataset_types = (np.float32, np.float32) | ||||
| dataset_shapes = ((32, 3, 224, 224), (32, 3)) | dataset_shapes = ((32, 3, 224, 224), (32, 3)) | ||||
| @@ -193,6 +195,7 @@ class CallbackTest(Callback): | |||||
| def test_train_callback(test_with_simu): | def test_train_callback(test_with_simu): | ||||
| """ test_train_callback """ | """ test_train_callback """ | ||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| dataset = get_dataset() | dataset = get_dataset() | ||||
| model = get_model() | model = get_model() | ||||
| callback = CallbackTest() | callback = CallbackTest() | ||||
| @@ -88,33 +88,6 @@ def test_model_checkpoint_prefix_invalid(): | |||||
| ModelCheckpoint(prefix="ckpt_2", directory="./test_files") | ModelCheckpoint(prefix="ckpt_2", directory="./test_files") | ||||
| def test_save_checkpoint(): | |||||
| """Test save checkpoint.""" | |||||
| train_config = CheckpointConfig( | |||||
| save_checkpoint_steps=16, | |||||
| save_checkpoint_seconds=0, | |||||
| keep_checkpoint_max=5, | |||||
| keep_checkpoint_per_n_minutes=0) | |||||
| cb_params = _InternalCallbackParam() | |||||
| net = Net() | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| network_ = WithLossCell(net, loss) | |||||
| _train_network = TrainOneStepCell(network_, optim) | |||||
| cb_params.train_network = _train_network | |||||
| cb_params.epoch_num = 10 | |||||
| cb_params.cur_epoch_num = 5 | |||||
| cb_params.cur_step_num = 0 | |||||
| cb_params.batch_num = 32 | |||||
| ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config) | |||||
| run_context = RunContext(cb_params) | |||||
| ckpoint_cb.begin(run_context) | |||||
| ckpoint_cb.step_end(run_context) | |||||
| if os.path.exists('./test_files/test_ckpt-model.pkl'): | |||||
| os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE) | |||||
| os.remove('./test_files/test_ckpt-model.pkl') | |||||
| def test_loss_monitor_sink_mode(): | def test_loss_monitor_sink_mode(): | ||||
| """Test loss monitor sink mode.""" | """Test loss monitor sink mode.""" | ||||
| cb_params = _InternalCallbackParam() | cb_params = _InternalCallbackParam() | ||||
| @@ -153,8 +126,35 @@ def test_loss_monitor_normal_mode(): | |||||
| loss_cb.end(run_context) | loss_cb.end(run_context) | ||||
| def test_chg_ckpt_file_name_if_same_exist(): | |||||
| """Test chg ckpt file name if same exist.""" | |||||
| def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): | |||||
| """ | |||||
| Feature: Save checkpoint and check if there is a file with the same name. | |||||
| Description: Save checkpoint and check if there is a file with the same name. | |||||
| Expectation: Checkpoint is saved and checking is successful. | |||||
| """ | |||||
| train_config = CheckpointConfig( | |||||
| save_checkpoint_steps=16, | |||||
| save_checkpoint_seconds=0, | |||||
| keep_checkpoint_max=5, | |||||
| keep_checkpoint_per_n_minutes=0) | |||||
| cb_params = _InternalCallbackParam() | |||||
| net = Net() | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| network_ = WithLossCell(net, loss) | |||||
| _train_network = TrainOneStepCell(network_, optim) | |||||
| cb_params.train_network = _train_network | |||||
| cb_params.epoch_num = 10 | |||||
| cb_params.cur_epoch_num = 5 | |||||
| cb_params.cur_step_num = 0 | |||||
| cb_params.batch_num = 32 | |||||
| ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config) | |||||
| run_context = RunContext(cb_params) | |||||
| ckpoint_cb.begin(run_context) | |||||
| ckpoint_cb.step_end(run_context) | |||||
| if os.path.exists('./test_files/test_ckpt-model.pkl'): | |||||
| os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE) | |||||
| os.remove('./test_files/test_ckpt-model.pkl') | |||||
| _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") | _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") | ||||
| @@ -122,8 +122,23 @@ def test_save_checkpoint_for_list(): | |||||
| save_checkpoint(parameter_list, ckpt_file_name) | save_checkpoint(parameter_list, ckpt_file_name) | ||||
| def test_save_checkpoint_for_list_append_info(): | |||||
| """ test save_checkpoint for list append info""" | |||||
| def test_load_checkpoint_error_filename(): | |||||
| """ | |||||
| Feature: Load checkpoint. | |||||
| Description: Load checkpoint with error filename. | |||||
| Expectation: Raise value error for error filename. | |||||
| """ | |||||
| ckpt_file_name = 1 | |||||
| with pytest.raises(ValueError): | |||||
| load_checkpoint(ckpt_file_name) | |||||
| def test_save_checkpoint_for_list_append_info_and_load_checkpoint(): | |||||
| """ | |||||
| Feature: Save checkpoint for list append info and load checkpoint. | |||||
| Description: Save checkpoint for list append info and load checkpoint with list append info. | |||||
| Expectation: Checkpoint for list append info can be saved and reloaded. | |||||
| """ | |||||
| parameter_list = [] | parameter_list = [] | ||||
| one_param = {} | one_param = {} | ||||
| param1 = {} | param1 = {} | ||||
| @@ -144,16 +159,6 @@ def test_save_checkpoint_for_list_append_info(): | |||||
| ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | ||||
| save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict) | save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict) | ||||
| def test_load_checkpoint_error_filename(): | |||||
| ckpt_file_name = 1 | |||||
| with pytest.raises(ValueError): | |||||
| load_checkpoint(ckpt_file_name) | |||||
| def test_load_checkpoint(): | |||||
| ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||||
| par_dict = load_checkpoint(ckpt_file_name) | par_dict = load_checkpoint(ckpt_file_name) | ||||
| assert len(par_dict) == 6 | assert len(par_dict) == 6 | ||||