| @@ -25,7 +25,7 @@ from mindspore._c_expression import typing | |||||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | ||||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | ||||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | ||||
| check_columns, check_positive, check_pos_int32 | |||||
| check_columns, check_pos_int32 | |||||
| from . import datasets | from . import datasets | ||||
| from . import samplers | from . import samplers | ||||
| @@ -319,10 +319,9 @@ def check_generatordataset(method): | |||||
| # These two parameters appear together. | # These two parameters appear together. | ||||
| raise ValueError("num_shards and shard_id need to be passed in together") | raise ValueError("num_shards and shard_id need to be passed in together") | ||||
| if num_shards is not None: | if num_shards is not None: | ||||
| type_check(num_shards, (int,), "num_shards") | |||||
| check_positive(num_shards, "num_shards") | |||||
| check_pos_int32(num_shards, "num_shards") | |||||
| if shard_id >= num_shards: | if shard_id >= num_shards: | ||||
| raise ValueError("shard_id should be less than num_shards") | |||||
| raise ValueError("shard_id should be less than num_shards.") | |||||
| sampler = param_dict.get("sampler") | sampler = param_dict.get("sampler") | ||||
| if sampler is not None: | if sampler is not None: | ||||
| @@ -417,7 +416,7 @@ def check_bucket_batch_by_length(method): | |||||
| all_non_negative = all(item > 0 for item in bucket_boundaries) | all_non_negative = all(item > 0 for item in bucket_boundaries) | ||||
| if not all_non_negative: | if not all_non_negative: | ||||
| raise ValueError("bucket_boundaries cannot contain any negative numbers.") | |||||
| raise ValueError("bucket_boundaries must only contain positive numbers.") | |||||
| for i in range(len(bucket_boundaries) - 1): | for i in range(len(bucket_boundaries) - 1): | ||||
| if not bucket_boundaries[i + 1] > bucket_boundaries[i]: | if not bucket_boundaries[i + 1] > bucket_boundaries[i]: | ||||
| @@ -1044,7 +1043,8 @@ def check_numpyslicesdataset(method): | |||||
| data = param_dict.get("data") | data = param_dict.get("data") | ||||
| column_names = param_dict.get("column_names") | column_names = param_dict.get("column_names") | ||||
| if not data: | |||||
| raise ValueError("Argument data cannot be empty") | |||||
| type_check(data, (list, tuple, dict, np.ndarray), "data") | type_check(data, (list, tuple, dict, np.ndarray), "data") | ||||
| if isinstance(data, tuple): | if isinstance(data, tuple): | ||||
| type_check(data[0], (list, np.ndarray), "data[0]") | type_check(data[0], (list, np.ndarray), "data[0]") | ||||
| @@ -62,7 +62,8 @@ def check_from_file(method): | |||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, | [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, | ||||
| **kwargs) | **kwargs) | ||||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| if special_tokens is not None: | |||||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) | type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) | ||||
| if vocab_size is not None: | if vocab_size is not None: | ||||
| check_value(vocab_size, (-1, INT32_MAX), "vocab_size") | check_value(vocab_size, (-1, INT32_MAX), "vocab_size") | ||||
| @@ -45,6 +45,7 @@ def test_bucket_batch_invalid_input(): | |||||
| bucket_boundaries = [1, 2, 3] | bucket_boundaries = [1, 2, 3] | ||||
| empty_bucket_boundaries = [] | empty_bucket_boundaries = [] | ||||
| invalid_bucket_boundaries = ["1", "2", "3"] | invalid_bucket_boundaries = ["1", "2", "3"] | ||||
| zero_start_bucket_boundaries = [0, 2, 3] | |||||
| negative_bucket_boundaries = [1, 2, -3] | negative_bucket_boundaries = [1, 2, -3] | ||||
| decreasing_bucket_boundaries = [3, 2, 1] | decreasing_bucket_boundaries = [3, 2, 1] | ||||
| non_increasing_bucket_boundaries = [1, 2, 2] | non_increasing_bucket_boundaries = [1, 2, 2] | ||||
| @@ -69,9 +70,13 @@ def test_bucket_batch_invalid_input(): | |||||
| _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) | _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) | ||||
| assert "bucket_boundaries should be a list of int" in str(info.value) | assert "bucket_boundaries should be a list of int" in str(info.value) | ||||
| with pytest.raises(ValueError) as info: | |||||
| _ = dataset.bucket_batch_by_length(column_names, zero_start_bucket_boundaries, bucket_batch_sizes) | |||||
| assert "bucket_boundaries must only contain positive numbers." in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | with pytest.raises(ValueError) as info: | ||||
| _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) | _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) | ||||
| assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) | |||||
| assert "bucket_boundaries must only contain positive numbers." in str(info.value) | |||||
| with pytest.raises(ValueError) as info: | with pytest.raises(ValueError) as info: | ||||
| _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) | _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) | ||||
| @@ -108,7 +108,7 @@ def test_concatenate_op_type_mismatch(): | |||||
| with pytest.raises(RuntimeError) as error_info: | with pytest.raises(RuntimeError) as error_info: | ||||
| for _ in data: | for _ in data: | ||||
| pass | pass | ||||
| assert "Tensor types do not match" in repr(error_info.value) | |||||
| assert "Tensor types do not match" in str(error_info.value) | |||||
| def test_concatenate_op_type_mismatch2(): | def test_concatenate_op_type_mismatch2(): | ||||
| @@ -123,7 +123,7 @@ def test_concatenate_op_type_mismatch2(): | |||||
| with pytest.raises(RuntimeError) as error_info: | with pytest.raises(RuntimeError) as error_info: | ||||
| for _ in data: | for _ in data: | ||||
| pass | pass | ||||
| assert "Tensor types do not match" in repr(error_info.value) | |||||
| assert "Tensor types do not match" in str(error_info.value) | |||||
| def test_concatenate_op_incorrect_dim(): | def test_concatenate_op_incorrect_dim(): | ||||
| @@ -138,13 +138,13 @@ def test_concatenate_op_incorrect_dim(): | |||||
| with pytest.raises(RuntimeError) as error_info: | with pytest.raises(RuntimeError) as error_info: | ||||
| for _ in data: | for _ in data: | ||||
| pass | pass | ||||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||||
| assert "Only 1D tensors supported" in str(error_info.value) | |||||
| def test_concatenate_op_wrong_axis(): | def test_concatenate_op_wrong_axis(): | ||||
| with pytest.raises(ValueError) as error_info: | with pytest.raises(ValueError) as error_info: | ||||
| data_trans.Concatenate(2) | data_trans.Concatenate(2) | ||||
| assert "only 1D concatenation supported." in repr(error_info.value) | |||||
| assert "only 1D concatenation supported." in str(error_info.value) | |||||
| def test_concatenate_op_negative_axis(): | def test_concatenate_op_negative_axis(): | ||||
| @@ -167,7 +167,7 @@ def test_concatenate_op_incorrect_input_dim(): | |||||
| with pytest.raises(ValueError) as error_info: | with pytest.raises(ValueError) as error_info: | ||||
| data_trans.Concatenate(0, prepend_tensor) | data_trans.Concatenate(0, prepend_tensor) | ||||
| assert "can only prepend 1D arrays." in repr(error_info.value) | |||||
| assert "can only prepend 1D arrays." in str(error_info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -12,12 +12,13 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| import numpy as np | |||||
| import sys | |||||
| import pytest | import pytest | ||||
| import numpy as np | |||||
| import pandas as pd | |||||
| import mindspore.dataset as de | import mindspore.dataset as de | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| import mindspore.dataset.transforms.vision.c_transforms as vision | import mindspore.dataset.transforms.vision.c_transforms as vision | ||||
| import pandas as pd | |||||
| def test_numpy_slices_list_1(): | def test_numpy_slices_list_1(): | ||||
| @@ -173,6 +174,25 @@ def test_numpy_slices_distributed_sampler(): | |||||
| assert sum([1 for _ in ds]) == 2 | assert sum([1 for _ in ds]) == 2 | ||||
| def test_numpy_slices_distributed_shard_limit(): | |||||
| logger.info("Test Slicing a 1D list.") | |||||
| np_data = [1, 2, 3] | |||||
| num = sys.maxsize | |||||
| with pytest.raises(ValueError) as err: | |||||
| de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False) | |||||
| assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) | |||||
| def test_numpy_slices_distributed_zero_shard(): | |||||
| logger.info("Test Slicing a 1D list.") | |||||
| np_data = [1, 2, 3] | |||||
| with pytest.raises(ValueError) as err: | |||||
| de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False) | |||||
| assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) | |||||
| def test_numpy_slices_sequential_sampler(): | def test_numpy_slices_sequential_sampler(): | ||||
| logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") | logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") | ||||
| @@ -210,6 +230,15 @@ def test_numpy_slices_invalid_empty_column_names(): | |||||
| assert "column_names should not be empty" in str(err.value) | assert "column_names should not be empty" in str(err.value) | ||||
| def test_numpy_slices_invalid_empty_data_column(): | |||||
| logger.info("Test incorrect column_names input") | |||||
| np_data = [] | |||||
| with pytest.raises(ValueError) as err: | |||||
| de.NumpySlicesDataset(np_data, shuffle=False) | |||||
| assert "Argument data cannot be empty" in str(err.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_numpy_slices_list_1() | test_numpy_slices_list_1() | ||||
| test_numpy_slices_list_2() | test_numpy_slices_list_2() | ||||
| @@ -223,7 +252,10 @@ if __name__ == "__main__": | |||||
| test_numpy_slices_csv_dict() | test_numpy_slices_csv_dict() | ||||
| test_numpy_slices_num_samplers() | test_numpy_slices_num_samplers() | ||||
| test_numpy_slices_distributed_sampler() | test_numpy_slices_distributed_sampler() | ||||
| test_numpy_slices_distributed_shard_limit() | |||||
| test_numpy_slices_distributed_zero_shard() | |||||
| test_numpy_slices_sequential_sampler() | test_numpy_slices_sequential_sampler() | ||||
| test_numpy_slices_invalid_column_names_type() | test_numpy_slices_invalid_column_names_type() | ||||
| test_numpy_slices_invalid_column_names_string() | test_numpy_slices_invalid_column_names_string() | ||||
| test_numpy_slices_invalid_empty_column_names() | test_numpy_slices_invalid_empty_column_names() | ||||
| test_numpy_slices_invalid_empty_data_column() | |||||
| @@ -82,9 +82,9 @@ def test_fillop_error_handling(): | |||||
| data = data.map(input_columns=["col"], operations=fill_op) | data = data.map(input_columns=["col"], operations=fill_op) | ||||
| with pytest.raises(RuntimeError) as error_info: | with pytest.raises(RuntimeError) as error_info: | ||||
| for data_row in data: | |||||
| print(data_row) | |||||
| assert "Types do not match" in repr(error_info.value) | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Types do not match" in str(error_info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||||
| with pytest.raises(Exception) as error_info: | with pytest.raises(Exception) as error_info: | ||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -39,8 +39,27 @@ def test_on_tokenized_line(): | |||||
| res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], | res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], | ||||
| [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) | [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) | ||||
| for i, d in enumerate(data.create_dict_iterator()): | for i, d in enumerate(data.create_dict_iterator()): | ||||
| _ = (np.testing.assert_array_equal(d["text"], res[i]), i) | |||||
| np.testing.assert_array_equal(d["text"], res[i]) | |||||
| def test_on_tokenized_line_with_no_special_tokens(): | |||||
| data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False) | |||||
| jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP) | |||||
| with open(VOCAB_FILE, 'r') as f: | |||||
| for line in f: | |||||
| word = line.split(',')[0] | |||||
| jieba_op.add_word(word) | |||||
| data = data.map(input_columns=["text"], operations=jieba_op) | |||||
| vocab = text.Vocab.from_file(VOCAB_FILE, ",") | |||||
| lookup = text.Lookup(vocab, "not") | |||||
| data = data.map(input_columns=["text"], operations=lookup) | |||||
| res = np.array([[8, 0, 9, 0, 10, 0, 13, 0, 11, 0, 12], | |||||
| [9, 0, 10, 0, 8, 0, 12, 0, 11, 0, 13]], dtype=np.int32) | |||||
| for i, d in enumerate(data.create_dict_iterator()): | |||||
| np.testing.assert_array_equal(d["text"], res[i]) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_on_tokenized_line() | test_on_tokenized_line() | ||||
| test_on_tokenized_line_with_no_special_tokens() | |||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -163,7 +163,6 @@ def test_sync_exception_01(): | |||||
| """ | """ | ||||
| logger.info("test_sync_exception_01") | logger.info("test_sync_exception_01") | ||||
| shuffle_size = 4 | shuffle_size = 4 | ||||
| batch_size = 10 | |||||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | dataset = ds.GeneratorDataset(gen, column_names=["input"]) | ||||
| @@ -171,11 +170,9 @@ def test_sync_exception_01(): | |||||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | ||||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | ||||
| try: | |||||
| dataset = dataset.shuffle(shuffle_size) | |||||
| except Exception as e: | |||||
| assert "shuffle" in str(e) | |||||
| dataset = dataset.batch(batch_size) | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| dataset.shuffle(shuffle_size) | |||||
| assert "No shuffle after sync operators" in str(e.value) | |||||
| def test_sync_exception_02(): | def test_sync_exception_02(): | ||||
| @@ -183,7 +180,6 @@ def test_sync_exception_02(): | |||||
| Test sync: with duplicated condition name | Test sync: with duplicated condition name | ||||
| """ | """ | ||||
| logger.info("test_sync_exception_02") | logger.info("test_sync_exception_02") | ||||
| batch_size = 6 | |||||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | dataset = ds.GeneratorDataset(gen, column_names=["input"]) | ||||
| @@ -192,11 +188,9 @@ def test_sync_exception_02(): | |||||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | ||||
| try: | |||||
| dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") | |||||
| except Exception as e: | |||||
| assert "name" in str(e) | |||||
| dataset = dataset.batch(batch_size) | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| dataset.sync_wait(num_batch=2, condition_name="every batch") | |||||
| assert "Condition name is already in use" in str(e.value) | |||||
| def test_sync_exception_03(): | def test_sync_exception_03(): | ||||
| @@ -209,12 +203,9 @@ def test_sync_exception_03(): | |||||
| aug = Augment(0) | aug = Augment(0) | ||||
| # try to create dataset with batch_size < 0 | # try to create dataset with batch_size < 0 | ||||
| try: | |||||
| dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) | |||||
| except Exception as e: | |||||
| assert "num_batch" in str(e) | |||||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||||
| with pytest.raises(ValueError) as e: | |||||
| dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) | |||||
| assert "num_batch need to be greater than 0." in str(e.value) | |||||
| def test_sync_exception_04(): | def test_sync_exception_04(): | ||||
| @@ -230,14 +221,13 @@ def test_sync_exception_04(): | |||||
| dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | ||||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | ||||
| count = 0 | count = 0 | ||||
| try: | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| for _ in dataset.create_dict_iterator(): | for _ in dataset.create_dict_iterator(): | ||||
| count += 1 | count += 1 | ||||
| data = {"loss": count} | data = {"loss": count} | ||||
| # dataset.disable_sync() | |||||
| dataset.sync_update(condition_name="every batch", num_batch=-1, data=data) | dataset.sync_update(condition_name="every batch", num_batch=-1, data=data) | ||||
| except Exception as e: | |||||
| assert "batch" in str(e) | |||||
| assert "Sync_update batch size can only be positive" in str(e.value) | |||||
| def test_sync_exception_05(): | def test_sync_exception_05(): | ||||
| """ | """ | ||||
| @@ -251,15 +241,15 @@ def test_sync_exception_05(): | |||||
| # try to create dataset with batch_size < 0 | # try to create dataset with batch_size < 0 | ||||
| dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) | ||||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | ||||
| try: | |||||
| with pytest.raises(RuntimeError) as e: | |||||
| for _ in dataset.create_dict_iterator(): | for _ in dataset.create_dict_iterator(): | ||||
| dataset.disable_sync() | dataset.disable_sync() | ||||
| count += 1 | count += 1 | ||||
| data = {"loss": count} | data = {"loss": count} | ||||
| dataset.disable_sync() | dataset.disable_sync() | ||||
| dataset.sync_update(condition_name="every", data=data) | dataset.sync_update(condition_name="every", data=data) | ||||
| except Exception as e: | |||||
| assert "name" in str(e) | |||||
| assert "Condition name not found" in str(e.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_simple_sync_wait() | test_simple_sync_wait() | ||||
| @@ -16,6 +16,7 @@ | |||||
| Testing UniformAugment in DE | Testing UniformAugment in DE | ||||
| """ | """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| import mindspore.dataset.transforms.vision.c_transforms as C | import mindspore.dataset.transforms.vision.c_transforms as C | ||||
| @@ -164,14 +165,13 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||||
| C.RandomRotation(degrees=45), | C.RandomRotation(degrees=45), | ||||
| F.Invert()] | F.Invert()] | ||||
| try: | |||||
| with pytest.raises(TypeError) as e: | |||||
| _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | ||||
| except Exception as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "Argument tensor_op_5 with value" \ | |||||
| " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e) | |||||
| assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e) | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "Argument tensor_op_5 with value" \ | |||||
| " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e.value) | |||||
| assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e.value) | |||||
| def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | ||||