| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test dataset performance about mindspore.MindDataset, mindspore.TFRecordDataset, tf.data.TFRecordDataset""" | |||
| import tensorflow as tf | |||
| import time | |||
| import tensorflow as tf | |||
| import mindspore.dataset as ds | |||
| from mindspore.mindrecord import FileReader | |||
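The import reordering above follows the standard PEP 8 grouping: standard-library modules first, third-party packages after, each group separated by a blank line. A minimal sketch of the convention, using only modules these tests already import (assumes the packages are installed):

```python
# PEP 8 import grouping: standard library first, then third-party packages.
import os    # standard library
import time  # standard library

import numpy as np              # third-party
import mindspore.dataset as ds  # third-party
```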
| @@ -32,9 +32,9 @@ def test_apply_generator_case(): | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| data2 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| ds = ds.repeat(2) | |||
| return ds.batch(4) | |||
| def dataset_fn(ds_): | |||
| ds_ = ds_.repeat(2) | |||
| return ds_.batch(4) | |||
| data1 = data1.apply(dataset_fn) | |||
| data2 = data2.repeat(2) | |||
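Renaming the helper's parameter from `ds` to `ds_` matters because the file imports `mindspore.dataset as ds`: the old name shadowed the module alias inside the function body. A minimal sketch of the hazard the rename avoids:

```python
import mindspore.dataset as ds  # module alias used throughout the file

def dataset_fn(ds_):     # with the old name `ds`, any reference to the module
    ds_ = ds_.repeat(2)  # alias inside this body (e.g. ds.GeneratorDataset)
    return ds_.batch(4)  # would silently resolve to the parameter instead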
| @@ -52,11 +52,11 @@ def test_apply_imagefolder_case(): | |||
| decode_op = vision.Decode() | |||
| normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0]) | |||
| def dataset_fn(ds): | |||
| ds = ds.map(operations=decode_op) | |||
| ds = ds.map(operations=normalize_op) | |||
| ds = ds.repeat(2) | |||
| return ds | |||
| def dataset_fn(ds_): | |||
| ds_ = ds_.map(operations=decode_op) | |||
| ds_ = ds_.map(operations=normalize_op) | |||
| ds_ = ds_.repeat(2) | |||
| return ds_ | |||
| data1 = data1.apply(dataset_fn) | |||
| data2 = data2.map(operations=decode_op) | |||
| @@ -67,125 +67,125 @@ def test_apply_imagefolder_case(): | |||
| assert np.array_equal(item1["image"], item2["image"]) | |||
| def test_apply_flow_case_0(id=0): | |||
| def test_apply_flow_case_0(id_=0): | |||
| # apply control flow operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| if id == 0: | |||
| ds = ds.batch(4) | |||
| elif id == 1: | |||
| ds = ds.repeat(2) | |||
| elif id == 2: | |||
| ds = ds.batch(4) | |||
| ds = ds.repeat(2) | |||
| def dataset_fn(ds_): | |||
| if id_ == 0: | |||
| ds_ = ds_.batch(4) | |||
| elif id_ == 1: | |||
| ds_ = ds_.repeat(2) | |||
| elif id_ == 2: | |||
| ds_ = ds_.batch(4) | |||
| ds_ = ds_.repeat(2) | |||
| else: | |||
| ds = ds.shuffle(buffer_size=4) | |||
| return ds | |||
| ds_ = ds_.shuffle(buffer_size=4) | |||
| return ds_ | |||
| data1 = data1.apply(dataset_fn) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter = num_iter + 1 | |||
| if id == 0: | |||
| if id_ == 0: | |||
| assert num_iter == 16 | |||
| elif id == 1: | |||
| elif id_ == 1: | |||
| assert num_iter == 128 | |||
| elif id == 2: | |||
| elif id_ == 2: | |||
| assert num_iter == 32 | |||
| else: | |||
| assert num_iter == 64 | |||
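`Dataset.apply` simply routes the dataset through the given function, so each branch above is equivalent to chaining the operations directly. A sketch of the `id_ == 2` branch, assuming `generator_1d` yields 64 rows (which the `num_iter == 64` fallthrough assertion implies):

```python
import numpy as np
import mindspore.dataset as ds

def generator_1d():
    for i in range(64):  # 64 rows, matching num_iter == 64 for plain iteration
        yield (np.array([i]),)

def dataset_fn(ds_):
    ds_ = ds_.batch(4)
    return ds_.repeat(2)

data = ds.GeneratorDataset(generator_1d, ["data"]).apply(dataset_fn)
# equivalent to .batch(4).repeat(2): 64 rows -> 16 batches -> 32 after repeat
assert sum(1 for _ in data.create_dict_iterator()) == 32
```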
| def test_apply_flow_case_1(id=1): | |||
| def test_apply_flow_case_1(id_=1): | |||
| # apply control flow operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| if id == 0: | |||
| ds = ds.batch(4) | |||
| elif id == 1: | |||
| ds = ds.repeat(2) | |||
| elif id == 2: | |||
| ds = ds.batch(4) | |||
| ds = ds.repeat(2) | |||
| def dataset_fn(ds_): | |||
| if id_ == 0: | |||
| ds_ = ds_.batch(4) | |||
| elif id_ == 1: | |||
| ds_ = ds_.repeat(2) | |||
| elif id_ == 2: | |||
| ds_ = ds_.batch(4) | |||
| ds_ = ds_.repeat(2) | |||
| else: | |||
| ds = ds.shuffle(buffer_size=4) | |||
| return ds | |||
| ds_ = ds_.shuffle(buffer_size=4) | |||
| return ds_ | |||
| data1 = data1.apply(dataset_fn) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter = num_iter + 1 | |||
| if id == 0: | |||
| if id_ == 0: | |||
| assert num_iter == 16 | |||
| elif id == 1: | |||
| elif id_ == 1: | |||
| assert num_iter == 128 | |||
| elif id == 2: | |||
| elif id_ == 2: | |||
| assert num_iter == 32 | |||
| else: | |||
| assert num_iter == 64 | |||
| def test_apply_flow_case_2(id=2): | |||
| def test_apply_flow_case_2(id_=2): | |||
| # apply control flow operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| if id == 0: | |||
| ds = ds.batch(4) | |||
| elif id == 1: | |||
| ds = ds.repeat(2) | |||
| elif id == 2: | |||
| ds = ds.batch(4) | |||
| ds = ds.repeat(2) | |||
| def dataset_fn(ds_): | |||
| if id_ == 0: | |||
| ds_ = ds_.batch(4) | |||
| elif id_ == 1: | |||
| ds_ = ds_.repeat(2) | |||
| elif id_ == 2: | |||
| ds_ = ds_.batch(4) | |||
| ds_ = ds_.repeat(2) | |||
| else: | |||
| ds = ds.shuffle(buffer_size=4) | |||
| return ds | |||
| ds_ = ds_.shuffle(buffer_size=4) | |||
| return ds_ | |||
| data1 = data1.apply(dataset_fn) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter = num_iter + 1 | |||
| if id == 0: | |||
| if id_ == 0: | |||
| assert num_iter == 16 | |||
| elif id == 1: | |||
| elif id_ == 1: | |||
| assert num_iter == 128 | |||
| elif id == 2: | |||
| elif id_ == 2: | |||
| assert num_iter == 32 | |||
| else: | |||
| assert num_iter == 64 | |||
| def test_apply_flow_case_3(id=3): | |||
| def test_apply_flow_case_3(id_=3): | |||
| # apply control flow operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| if id == 0: | |||
| ds = ds.batch(4) | |||
| elif id == 1: | |||
| ds = ds.repeat(2) | |||
| elif id == 2: | |||
| ds = ds.batch(4) | |||
| ds = ds.repeat(2) | |||
| def dataset_fn(ds_): | |||
| if id_ == 0: | |||
| ds_ = ds_.batch(4) | |||
| elif id_ == 1: | |||
| ds_ = ds_.repeat(2) | |||
| elif id_ == 2: | |||
| ds_ = ds_.batch(4) | |||
| ds_ = ds_.repeat(2) | |||
| else: | |||
| ds = ds.shuffle(buffer_size=4) | |||
| return ds | |||
| ds_ = ds_.shuffle(buffer_size=4) | |||
| return ds_ | |||
| data1 = data1.apply(dataset_fn) | |||
| num_iter = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_iter = num_iter + 1 | |||
| if id == 0: | |||
| if id_ == 0: | |||
| assert num_iter == 16 | |||
| elif id == 1: | |||
| elif id_ == 1: | |||
| assert num_iter == 128 | |||
| elif id == 2: | |||
| elif id_ == 2: | |||
| assert num_iter == 32 | |||
| else: | |||
| assert num_iter == 64 | |||
| @@ -195,11 +195,11 @@ def test_apply_exception_case(): | |||
| # apply exception operations | |||
| data1 = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| def dataset_fn(ds): | |||
| ds = ds.repeat(2) | |||
| return ds.batch(4) | |||
| def dataset_fn(ds_): | |||
| ds_ = ds_.repeat(2) | |||
| return ds_.batch(4) | |||
| def exception_fn(ds): | |||
| def exception_fn(): | |||
| return np.array([[0], [1], [3], [4], [5]]) | |||
| try: | |||
| @@ -220,12 +220,12 @@ def test_apply_exception_case(): | |||
| try: | |||
| data2 = data1.apply(dataset_fn) | |||
| data3 = data1.apply(dataset_fn) | |||
| _ = data1.apply(dataset_fn) | |||
| for _, _ in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| pass | |||
| assert False | |||
| except ValueError: | |||
| pass | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| if __name__ == '__main__': | |||
| @@ -58,7 +58,7 @@ def test_auto_contrast(plot=False): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -79,7 +79,7 @@ def test_auto_contrast(plot=False): | |||
| ds_auto_contrast = ds_auto_contrast.batch(512) | |||
| for idx, (image, label) in enumerate(ds_auto_contrast): | |||
| for idx, (image, _) in enumerate(ds_auto_contrast): | |||
| if idx == 0: | |||
| images_auto_contrast = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -273,7 +273,7 @@ def test_batch_exception_01(): | |||
| data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_parallel_workers" in str(e) | |||
| @@ -290,7 +290,7 @@ def test_batch_exception_02(): | |||
| data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_parallel_workers" in str(e) | |||
| @@ -307,7 +307,7 @@ def test_batch_exception_03(): | |||
| data1 = data1.batch(batch_size=0) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "batch_size" in str(e) | |||
| @@ -324,7 +324,7 @@ def test_batch_exception_04(): | |||
| data1 = data1.batch(batch_size=-1) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "batch_size" in str(e) | |||
| @@ -341,7 +341,7 @@ def test_batch_exception_05(): | |||
| data1 = data1.batch(batch_size=False) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "batch_size" in str(e) | |||
| @@ -358,7 +358,7 @@ def test_batch_exception_07(): | |||
| data1 = data1.batch(3, drop_remainder=0) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "drop_remainder" in str(e) | |||
| @@ -375,7 +375,7 @@ def test_batch_exception_08(): | |||
| data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_parallel_workers" in str(e) | |||
| @@ -392,7 +392,7 @@ def test_batch_exception_09(): | |||
| data1 = data1.batch(drop_remainder=True, num_parallel_workers=4) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "batch_size" in str(e) | |||
| @@ -409,7 +409,7 @@ def test_batch_exception_10(): | |||
| data1 = data1.batch(batch_size=4, num_parallel_workers=8192) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_parallel_workers" in str(e) | |||
| @@ -429,7 +429,7 @@ def test_batch_exception_11(): | |||
| data1 = data1.batch(batch_size, num_parallel_workers) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "drop_remainder" in str(e) | |||
| @@ -450,7 +450,7 @@ def test_batch_exception_12(): | |||
| data1 = data1.batch(drop_remainder, batch_size=batch_size) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "batch_size" in str(e) | |||
| @@ -469,7 +469,7 @@ def test_batch_exception_13(): | |||
| data1 = data1.batch(batch_size, shard_id=1) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "shard_id" in str(e) | |||
| @@ -24,18 +24,18 @@ from mindspore import log as logger | |||
| # In generator dataset: Number of rows is 3; its values are 0, 1, 2 | |||
| def generator(): | |||
| for i in range(3): | |||
| yield np.array([i]), | |||
| yield (np.array([i]),) | |||
| # In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9 | |||
| def generator_10(): | |||
| for i in range(3, 10): | |||
| yield np.array([i]), | |||
| yield (np.array([i]),) | |||
| # In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19 | |||
| def generator_20(): | |||
| for i in range(10, 20): | |||
| yield np.array([i]), | |||
| yield (np.array([i]),) | |||
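Each of these generators already yielded a 1-tuple, since the trailing comma builds one, so the added parentheses change nothing at runtime; they only make the one-column row structure that `GeneratorDataset` consumes visible. A quick check:

```python
import numpy as np

def generator():
    for i in range(3):
        # `yield np.array([i]),` and `yield (np.array([i]),)` are identical;
        # the parentheses just make the single-column row explicit
        yield (np.array([i]),)

row = next(generator())
assert isinstance(row, tuple) and row[0].item() == 0
```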
| def test_concat_01(): | |||
| @@ -85,7 +85,7 @@ def test_concat_03(): | |||
| data3 = data1 + data2 | |||
| try: | |||
| for i, d in enumerate(data3): | |||
| for _, _ in enumerate(data3): | |||
| pass | |||
| assert False | |||
| except RuntimeError: | |||
| @@ -104,7 +104,7 @@ def test_concat_04(): | |||
| data3 = data1 + data2 | |||
| try: | |||
| for i, d in enumerate(data3): | |||
| for _, _ in enumerate(data3): | |||
| pass | |||
| assert False | |||
| except RuntimeError: | |||
| @@ -125,7 +125,7 @@ def test_concat_05(): | |||
| data3 = data1 + data2 | |||
| try: | |||
| for i, d in enumerate(data3): | |||
| for _, _ in enumerate(data3): | |||
| pass | |||
| assert False | |||
| except RuntimeError: | |||
| @@ -31,7 +31,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| def test_basic(): | |||
| """ | |||
| Test basic configuration functions | |||
| Test basic configuration functions | |||
| """ | |||
| # Save original configuration values | |||
| num_parallel_workers_original = ds.config.get_num_parallel_workers() | |||
| @@ -138,7 +138,7 @@ def test_deterministic_run_fail(): | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| np.testing.assert_equal(item1["image"], item2["image"]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| # the two datasets split the numbers from the generated sequence "a" between them | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Array" in str(e) | |||
| @@ -157,7 +157,7 @@ def test_deterministic_run_pass(): | |||
| # Save original configuration values | |||
| num_parallel_workers_original = ds.config.get_num_parallel_workers() | |||
| seed_original = ds.config.get_seed() | |||
| ds.config.set_seed(0) | |||
| ds.config.set_num_parallel_workers(1) | |||
| @@ -179,7 +179,7 @@ def test_deterministic_run_pass(): | |||
| try: | |||
| for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): | |||
| np.testing.assert_equal(item1["image"], item2["image"]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| # two datasets both use numbers from the generated sequence "a" | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Array" in str(e) | |||
| @@ -344,7 +344,7 @@ def test_deterministic_python_seed_multi_thread(): | |||
| try: | |||
| np.testing.assert_equal(data1_output, data2_output) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| # expect outputs not to match during multi-threaded execution | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Array" in str(e) | |||
| @@ -107,14 +107,20 @@ def test_tfrecord_shardings4(print_res=False): | |||
| assert len(result_list) == expect_length | |||
| assert set(result_list) == expect_set | |||
| check_result(sharding_config(2, 0, None, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, None, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(2, 0, None, 1), 20, | |||
| {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, None, 1), 20, | |||
| {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21}) | |||
| check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31}) | |||
| check_result(sharding_config(2, 0, 40, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, 40, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(2, 0, 55, 1), 20, {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, 55, 1), 20, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(2, 0, 40, 1), 20, | |||
| {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, 40, 1), 20, | |||
| {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(2, 0, 55, 1), 20, | |||
| {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) | |||
| check_result(sharding_config(2, 1, 55, 1), 20, | |||
| {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) | |||
| check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31}) | |||
| check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8}) | |||
| check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28}) | |||
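The wrapped `check_result` calls encode a simple invariant: with `num_shards` shards over these rows, each shard sees `total // num_shards` rows, and `num_samples` can only lower that cap. A tiny model of the expected lengths above (the 40-row total is inferred from the expected sets, so this is a sketch, not the reader's actual logic):

```python
def expected_rows(total, num_shards, num_samples):
    # each shard gets an equal share; num_samples can only cap it further
    per_shard = total // num_shards
    return min(num_samples, per_shard) if num_samples else per_shard

assert expected_rows(40, 2, None) == 20
assert expected_rows(40, 2, 3) == 3
assert expected_rows(40, 2, 55) == 20  # oversized num_samples is clamped
assert expected_rows(40, 3, 8) == 8
```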
| @@ -49,7 +49,7 @@ def test_textline_dataset_totext(): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert (count == 5) | |||
| assert count == 5 | |||
| # Restore configuration num_parallel_workers | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| @@ -24,10 +24,10 @@ def test_voc_segmentation(): | |||
| data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) | |||
| num = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| assert (item["image"].shape[0] == IMAGE_SHAPE[num]) | |||
| assert (item["target"].shape[0] == TARGET_SHAPE[num]) | |||
| assert item["image"].shape[0] == IMAGE_SHAPE[num] | |||
| assert item["target"].shape[0] == TARGET_SHAPE[num] | |||
| num += 1 | |||
| assert (num == 10) | |||
| assert num == 10 | |||
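Dropping the parentheses around `assert` conditions is not purely cosmetic; the parenthesized habit invites the classic `assert (cond, msg)` mistake, which asserts a non-empty tuple and can never fail. A demonstration:

```python
num = 9
# BUG pattern: this asserts the 2-tuple itself, which is always truthy,
# so it silently passes (recent Pythons emit a SyntaxWarning for it)
assert (num == 10, "num should be 10")
# idiomatic form: bare condition first, optional message after the comma
assert num == 9, "num should be 9"
```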
| def test_voc_detection(): | |||
| @@ -35,12 +35,12 @@ def test_voc_detection(): | |||
| num = 0 | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| assert (item["image"].shape[0] == IMAGE_SHAPE[num]) | |||
| assert item["image"].shape[0] == IMAGE_SHAPE[num] | |||
| for bbox in item["annotation"]: | |||
| count[bbox[0]] += 1 | |||
| num += 1 | |||
| assert (num == 9) | |||
| assert (count == [3, 2, 1, 2, 4, 3]) | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| def test_voc_class_index(): | |||
| @@ -58,8 +58,8 @@ def test_voc_class_index(): | |||
| assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5) | |||
| count[bbox[0]] += 1 | |||
| num += 1 | |||
| assert (num == 6) | |||
| assert (count == [3, 2, 0, 0, 0, 3]) | |||
| assert num == 6 | |||
| assert count == [3, 2, 0, 0, 0, 3] | |||
| def test_voc_get_class_indexing(): | |||
| @@ -76,8 +76,8 @@ def test_voc_get_class_indexing(): | |||
| assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5) | |||
| count[bbox[0]] += 1 | |||
| num += 1 | |||
| assert (num == 9) | |||
| assert (count == [3, 2, 1, 2, 4, 3]) | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| def test_case_0(): | |||
| @@ -93,9 +93,9 @@ def test_case_0(): | |||
| data1 = data1.batch(batch_size, drop_remainder=True) | |||
| num = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| for _ in data1.create_dict_iterator(): | |||
| num += 1 | |||
| assert (num == 20) | |||
| assert num == 20 | |||
| def test_case_1(): | |||
| @@ -110,9 +110,9 @@ def test_case_1(): | |||
| data1 = data1.batch(batch_size, drop_remainder=True, pad_info={}) | |||
| num = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| for _ in data1.create_dict_iterator(): | |||
| num += 1 | |||
| assert (num == 18) | |||
| assert num == 18 | |||
| def test_voc_exception(): | |||
| @@ -58,7 +58,7 @@ def test_equalize(plot=False): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -79,7 +79,7 @@ def test_equalize(plot=False): | |||
| ds_equalize = ds_equalize.batch(512) | |||
| for idx, (image, label) in enumerate(ds_equalize): | |||
| for idx, (image, _) in enumerate(ds_equalize): | |||
| if idx == 0: | |||
| images_equalize = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -15,9 +15,7 @@ | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| import mindspore.dataset.transforms.vision.c_transforms as cde | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| @@ -31,7 +29,6 @@ def test_diff_predicate_func(): | |||
| cde.Decode(), | |||
| cde.Resize([64, 64]) | |||
| ] | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False) | |||
| dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1) | |||
| dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4) | |||
| @@ -40,7 +37,6 @@ def test_diff_predicate_func(): | |||
| label_list = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| num_iter += 1 | |||
| ori_img = data["image"] | |||
| label = data["label"] | |||
| label_list.append(label) | |||
| assert num_iter == 1 | |||
| @@ -200,6 +196,7 @@ def generator_1d_zip2(): | |||
| def filter_func_zip(data1, data2): | |||
| _ = data2 | |||
| if data1 > 20: | |||
| return False | |||
| return True | |||
| @@ -249,6 +246,7 @@ def test_filter_by_generator_with_zip_after(): | |||
| def filter_func_map(col1, col2): | |||
| _ = col2 | |||
| if col1[0] > 8: | |||
| return True | |||
| return False | |||
| @@ -262,6 +260,7 @@ def filter_func_map_part(col1): | |||
| def filter_func_map_all(col1, col2): | |||
| _, _ = col1, col2 | |||
| return True | |||
| @@ -334,6 +333,7 @@ def test_filter_by_generator_with_rename(): | |||
| # test input_column | |||
| def filter_func_input_column1(col1, col2): | |||
| _ = col2 | |||
| if col1[0] < 8: | |||
| return True | |||
| return False | |||
| @@ -346,6 +346,7 @@ def filter_func_input_column2(col1): | |||
| def filter_func_input_column3(col1): | |||
| _ = col1 | |||
| return True | |||
| @@ -380,6 +381,7 @@ def generator_mc_p1(maxid=20): | |||
| def filter_func_Partial_0(col1, col2, col3, col4): | |||
| _, _, _ = col2, col3, col4 | |||
| filter_data = [0, 1, 2, 3, 4, 11] | |||
| if col1[0] in filter_data: | |||
| return False | |||
| @@ -439,6 +441,7 @@ def test_filter_by_generator_Partial2(): | |||
| def filter_func_Partial(col1, col2): | |||
| _ = col2 | |||
| if col1[0] % 3 == 0: | |||
| return True | |||
| return False | |||
| @@ -461,6 +464,7 @@ def test_filter_by_generator_Partial(): | |||
| def filter_func_cifar(col1, col2): | |||
| _ = col1 | |||
| if col2 % 3 == 0: | |||
| return True | |||
| return False | |||
| @@ -490,6 +494,7 @@ def generator_sort2(maxid=20): | |||
| def filter_func_part_sort(col1, col2, col3, col4, col5, col6): | |||
| _, _, _, _, _, _ = col1, col2, col3, col4, col5, col6 | |||
| return True | |||
| @@ -58,7 +58,7 @@ def test_invert(plot=False): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -79,7 +79,7 @@ def test_invert(plot=False): | |||
| ds_invert = ds_invert.batch(512) | |||
| for idx, (image, label) in enumerate(ds_invert): | |||
| for idx, (image, _) in enumerate(ds_invert): | |||
| if idx == 0: | |||
| images_invert = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -17,11 +17,11 @@ This is the test module for mindrecord | |||
| """ | |||
| import collections | |||
| import json | |||
| import numpy as np | |||
| import os | |||
| import pytest | |||
| import re | |||
| import string | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| @@ -46,9 +46,10 @@ def add_and_remove_cv_file(): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"id": {"type": "int32"}, | |||
| @@ -117,7 +118,9 @@ def add_and_remove_nlp_compress_file(): | |||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||
| 2147483647], dtype=np.int32), [-1]), | |||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| 256, -32768, 32767, -32769, 32768, | |||
| -2147483648, 2147483647, -2147483649, 2147483649, | |||
| -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| "array_c": str.encode("nlp data"), | |||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||
| }) | |||
| @@ -151,7 +154,9 @@ def test_nlp_compress_data(add_and_remove_nlp_compress_file): | |||
| 255, 256, -32768, 32767, -32769, 32768, -2147483648, | |||
| 2147483647], dtype=np.int32), [-1]), | |||
| "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255, | |||
| 256, -32768, 32767, -32769, 32768, -2147483648, 2147483647, -2147483649, 2147483649, -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| 256, -32768, 32767, -32769, 32768, | |||
| -2147483648, 2147483647, -2147483649, 2147483649, | |||
| -922337036854775808, 9223372036854775807]), [1, -1]), | |||
| "array_c": str.encode("nlp data"), | |||
| "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1]) | |||
| }) | |||
| @@ -194,9 +199,10 @@ def test_cv_minddataset_writer_tutorial(): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, | |||
| @@ -478,9 +484,10 @@ def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||
| paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| for x in paths: | |||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||
| os.remove("{}.db".format(x)) if os.path.exists( | |||
| "{}.db".format(x)) else None | |||
| if os.path.exists("{}".format(x)): | |||
| os.remove("{}".format(x)) | |||
| if os.path.exists("{}.db".format(x)): | |||
| os.remove("{}.db".format(x)) | |||
| writer = FileWriter(CV1_FILE_NAME, FILES_NUM) | |||
| data = get_data(CV_DIR_NAME) | |||
| cv_schema_json = {"id": {"type": "int32"}, | |||
| @@ -779,7 +786,7 @@ def get_nlp_data(dir_name, vocab_file, num): | |||
| """ | |||
| if not os.path.isdir(dir_name): | |||
| raise IOError("Directory {} does not exist".format(dir_name)) | |||
| for root, dirs, files in os.walk(dir_name): | |||
| for root, _, files in os.walk(dir_name): | |||
| for index, file_name_extension in enumerate(files): | |||
| if index < num: | |||
| file_path = os.path.join(root, file_name_extension) | |||
| @@ -851,7 +858,7 @@ def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): | |||
| if os.path.exists("{}".format(mindrecord_file_name)): | |||
| os.remove("{}".format(mindrecord_file_name)) | |||
| if os.path.exists("{}.db".format(mindrecord_file_name)): | |||
| os.remove("{}.db".format(x)) | |||
| os.remove("{}.db".format(mindrecord_file_name)) | |||
| data = [{"file_name": "001.jpg", "label": 4, | |||
| "image1": bytes("image1 bytes abc", encoding='UTF-8'), | |||
| "image2": bytes("image1 bytes def", encoding='UTF-8'), | |||
| @@ -26,8 +26,10 @@ CV1_FILE_NAME = "./imagenet1.mindrecord" | |||
| def create_cv_mindrecord(files_num): | |||
| """tutorial for cv dataset writer.""" | |||
| os.remove(CV_FILE_NAME) if os.path.exists(CV_FILE_NAME) else None | |||
| os.remove("{}.db".format(CV_FILE_NAME)) if os.path.exists("{}.db".format(CV_FILE_NAME)) else None | |||
| if os.path.exists(CV_FILE_NAME): | |||
| os.remove(CV_FILE_NAME) | |||
| if os.path.exists("{}.db".format(CV_FILE_NAME)): | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| writer = FileWriter(CV_FILE_NAME, files_num) | |||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} | |||
| data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] | |||
| @@ -39,8 +41,10 @@ def create_cv_mindrecord(files_num): | |||
| def create_diff_schema_cv_mindrecord(files_num): | |||
| """tutorial for cv dataset writer.""" | |||
| os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None | |||
| os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None | |||
| if os.path.exists(CV1_FILE_NAME): | |||
| os.remove(CV1_FILE_NAME) | |||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||
| writer = FileWriter(CV1_FILE_NAME, files_num) | |||
| cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} | |||
| data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] | |||
| @@ -52,8 +56,10 @@ def create_diff_schema_cv_mindrecord(files_num): | |||
| def create_diff_page_size_cv_mindrecord(files_num): | |||
| """tutorial for cv dataset writer.""" | |||
| os.remove(CV1_FILE_NAME) if os.path.exists(CV1_FILE_NAME) else None | |||
| os.remove("{}.db".format(CV1_FILE_NAME)) if os.path.exists("{}.db".format(CV1_FILE_NAME)) else None | |||
| if os.path.exists(CV1_FILE_NAME): | |||
| os.remove(CV1_FILE_NAME) | |||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||
| writer = FileWriter(CV1_FILE_NAME, files_num) | |||
| writer.set_page_size(1 << 26) # 64MB | |||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} | |||
| @@ -69,8 +75,8 @@ def test_cv_lack_json(): | |||
| create_cv_mindrecord(1) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception) as err: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers) | |||
| with pytest.raises(Exception): | |||
| ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -80,7 +86,7 @@ def test_cv_lack_mindrecord(): | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="does not exist or permission denied"): | |||
| data_set = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers) | |||
| _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers) | |||
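Dropping the unused `data_set =` and `err =` bindings inside `pytest.raises` blocks silences lint warnings without weakening the tests: the context manager already fails the test if no exception is raised, and `match=` verifies the message. A self-contained sketch of the pattern:

```python
import pytest

def test_open_missing_file():
    # no need to bind the result: pytest.raises fails the test by itself
    # if nothing is raised, and match= regex-searches the exception message
    with pytest.raises(FileNotFoundError, match="no_exist"):
        open("no_exist.mindrecord")
```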
| def test_invalid_mindrecord(): | |||
| @@ -134,7 +140,7 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle(): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, | |||
| sampler=sampler, shuffle=False) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -149,7 +155,7 @@ def test_cv_minddataset_reader_different_schema(): | |||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||
| num_readers) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -166,7 +172,7 @@ def test_cv_minddataset_reader_different_page_size(): | |||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||
| num_readers) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -181,7 +187,7 @@ def test_minddataset_invalidate_num_shards(): | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 0, 1) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -194,7 +200,7 @@ def test_minddataset_invalidate_shard_id(): | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -207,13 +213,13 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| os.remove(CV_FILE_NAME) | |||
| @@ -50,7 +50,7 @@ def test_cv_minddataset_reader_multi_image_and_ndarray_tutorial(): | |||
| assert os.path.exists(CV_FILE_NAME) | |||
| assert os.path.exists(CV_FILE_NAME + ".db") | |||
| """tutorial for minderdataset.""" | |||
| # tutorial for minddataset. | |||
| columns_list = ["id", "image_0", "image_2", "image_3", "image_4", "input_mask", "segments"] | |||
| num_readers = 1 | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) | |||
| @@ -20,7 +20,6 @@ import pytest | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.transforms.vision import Inter | |||
| from mindspore.dataset.text import to_str | |||
| from mindspore.mindrecord import FileWriter | |||
| @@ -39,7 +39,7 @@ def test_on_tokenized_line(): | |||
| res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], | |||
| [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) | |||
| for i, d in enumerate(data.create_dict_iterator()): | |||
| np.testing.assert_array_equal(d["text"], res[i]), i | |||
| _ = (np.testing.assert_array_equal(d["text"], res[i]), i) | |||
| if __name__ == '__main__': | |||
| @@ -199,7 +199,7 @@ def test_jieba_5(): | |||
| def gen(): | |||
| text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S') | |||
| yield text, | |||
| yield (text,) | |||
| def pytoken_op(input_data): | |||
| @@ -109,10 +109,9 @@ def test_decode_op(): | |||
| data1 = data1.map(input_columns=["image"], operations=decode_op) | |||
| num_iter = 0 | |||
| image = None | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||
| image = item["image"] | |||
| _ = item["image"] | |||
| # plt.subplot(131) | |||
| # plt.imshow(image) | |||
| # plt.title("DE image") | |||
| @@ -134,10 +133,9 @@ def test_decode_normalize_op(): | |||
| data1 = data1.map(input_columns=["image"], operations=[decode_op, normalize_op]) | |||
| num_iter = 0 | |||
| image = None | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("Looping inside iterator {}".format(num_iter)) | |||
| image = item["image"] | |||
| _ = item["image"] | |||
| # plt.subplot(131) | |||
| # plt.imshow(image) | |||
| # plt.title("DE image") | |||
| @@ -37,8 +37,7 @@ def test_case_0(): | |||
| data1 = data1.batch(2) | |||
| i = 0 | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| for _ in data1.create_dict_iterator(): # each data is a dictionary | |||
| pass | |||
| @@ -72,7 +72,7 @@ def test_pad_op(): | |||
| # pylint: disable=unnecessary-lambda | |||
| def test_pad_grayscale(): | |||
| """ | |||
| Tests that the pad works for grayscale images | |||
| Tests that the pad works for grayscale images | |||
| """ | |||
| def channel_swap(image): | |||
| @@ -92,7 +92,7 @@ def test_pad_grayscale(): | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| data1 = data1.map(input_columns=["image"], operations=transform()) | |||
| # if input is grayscale, the output dimensions should be single channel | |||
| # if input is grayscale, the output dimensions should be single channel | |||
| pad_gray = c_vision.Pad(100, fill_value=(20, 20, 20)) | |||
| data1 = data1.map(input_columns=["image"], operations=pad_gray) | |||
| dataset_shape_1 = [] | |||
| @@ -100,11 +100,11 @@ def test_pad_grayscale(): | |||
| c_image = item1["image"] | |||
| dataset_shape_1.append(c_image.shape) | |||
| # Dataset for comparison | |||
| # Dataset for comparison | |||
| data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| decode_op = c_vision.Decode() | |||
| # we use the same padding logic | |||
| # we use the same padding logic | |||
| ctrans = [decode_op, pad_gray] | |||
| dataset_shape_2 = [] | |||
| @@ -119,7 +119,7 @@ def batch_padding_performance_3d(): | |||
| num_batches = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_batches += 1 | |||
| res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| # print(res) | |||
| @@ -135,7 +135,7 @@ def batch_padding_performance_1d(): | |||
| num_batches = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_batches += 1 | |||
| res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| # print(res) | |||
| @@ -151,7 +151,7 @@ def batch_pyfunc_padding_3d(): | |||
| num_batches = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_batches += 1 | |||
| res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| # print(res) | |||
| @@ -166,7 +166,7 @@ def batch_pyfunc_padding_1d(): | |||
| num_batches = 0 | |||
| for _ in data1.create_dict_iterator(): | |||
| num_batches += 1 | |||
| res = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| _ = "total number of batch:" + str(num_batches) + " time elapsed:" + str(time.time() - start_time) | |||
| # print(res) | |||
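These performance helpers built a `res` summary string that was never used (its `print` is commented out), so the change binds it to `_`. If the timing is actually wanted, routing it through the logger these tests already use is the cleaner option; a sketch:

```python
import time
from mindspore import log as logger

start_time = time.time()
num_batches = 0
for _ in range(1000):  # stand-in for iterating data1.create_dict_iterator()
    num_batches += 1
logger.info("total number of batch: {} time elapsed: {}".format(
    num_batches, time.time() - start_time))
```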
| @@ -58,7 +58,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -79,7 +79,7 @@ def test_random_color(degrees=(0.1, 1.9), plot=False): | |||
| ds_random_color = ds_random_color.batch(512) | |||
| for idx, (image, label) in enumerate(ds_random_color): | |||
| for idx, (image, _) in enumerate(ds_random_color): | |||
| if idx == 0: | |||
| images_random_color = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -256,7 +256,7 @@ def test_random_color_adjust_op_hue(plot=False): | |||
| # pylint: disable=unnecessary-lambda | |||
| def test_random_color_adjust_grayscale(): | |||
| """ | |||
| Tests that the random color adjust works for grayscale images | |||
| Tests that the random color adjust works for grayscale images | |||
| """ | |||
| def channel_swap(image): | |||
| @@ -284,7 +284,7 @@ def test_random_color_adjust_grayscale(): | |||
| for item1 in data1.create_dict_iterator(): | |||
| c_image = item1["image"] | |||
| dataset_shape_1.append(c_image.shape) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -200,7 +200,7 @@ def test_random_crop_04_c(): | |||
| for item in data.create_dict_iterator(): | |||
| image = item["image"] | |||
| image_list.append(image.shape) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| def test_random_crop_04_py(): | |||
| @@ -227,7 +227,7 @@ def test_random_crop_04_py(): | |||
| for item in data.create_dict_iterator(): | |||
| image = (item["image"].transpose(1, 2, 0) * 255).astype(np.uint8) | |||
| image_list.append(image.shape) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| def test_random_crop_05_c(): | |||
| @@ -439,7 +439,7 @@ def test_random_crop_09(): | |||
| for item in data.create_dict_iterator(): | |||
| image = item["image"] | |||
| image_list.append(image.shape) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "should be PIL Image" in str(e) | |||
| @@ -60,7 +60,7 @@ def test_random_resize_op(): | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| image_de_resized = item["image"] | |||
| _ = item["image"] | |||
| # Uncomment below line if you want to visualize images | |||
| # visualize(image_de_resized, image_np_resized, mse) | |||
| num_iter += 1 | |||
| @@ -58,7 +58,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -79,7 +79,7 @@ def test_random_sharpness(degrees=(0.1, 1.9), plot=False): | |||
| ds_random_sharpness = ds_random_sharpness.batch(512) | |||
| for idx, (image, label) in enumerate(ds_random_sharpness): | |||
| for idx, (image, _) in enumerate(ds_random_sharpness): | |||
| if idx == 0: | |||
| images_random_sharpness = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -25,7 +25,7 @@ from mindspore import log as logger | |||
| def test_sequential_sampler(print_res=False): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_samples, num_repeats=None): | |||
| sampler = ds.SequentialSampler() | |||
| @@ -36,7 +36,7 @@ def test_sequential_sampler(print_res=False): | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("item[image].shape[0]: {}, item[label].item(): {}" | |||
| .format(item["image"].shape[0], item["label"].item())) | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| if print_res: | |||
| logger.info("image.shapes and labels: {}".format(res)) | |||
| return res | |||
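`map` is a Python builtin, so the sampler tests rename their lookup dictionaries to `map_`; with the old name, the builtin would be unreachable for the rest of the function. A minimal illustration:

```python
map_ = {(172876, 0): 0, (54214, 0): 1}  # renamed lookup table
# the builtin map() is still available because nothing shadows it
doubled = list(map(lambda v: v * 2, map_.values()))
assert doubled == [0, 2]
```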
| @@ -48,7 +48,7 @@ def test_sequential_sampler(print_res=False): | |||
| def test_random_sampler(print_res=False): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(replacement, num_samples, num_repeats): | |||
| sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples) | |||
| @@ -56,7 +56,7 @@ def test_random_sampler(print_res=False): | |||
| data1 = data1.repeat(num_repeats) | |||
| res = [] | |||
| for item in data1.create_dict_iterator(): | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| if print_res: | |||
| logger.info("image.shapes and labels: {}".format(res)) | |||
| return res | |||
| @@ -71,7 +71,7 @@ def test_random_sampler(print_res=False): | |||
| def test_random_sampler_multi_iter(print_res=False): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(replacement, num_samples, num_repeats, validate): | |||
| sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples) | |||
| @@ -79,7 +79,7 @@ def test_random_sampler_multi_iter(print_res=False): | |||
| while num_repeats > 0: | |||
| res = [] | |||
| for item in data1.create_dict_iterator(): | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| if print_res: | |||
| logger.info("image.shapes and labels: {}".format(res)) | |||
| if validate != sorted(res): | |||
| @@ -112,7 +112,7 @@ def test_sampler_py_api(): | |||
| def test_python_sampler(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| class Sp1(ds.Sampler): | |||
| def __iter__(self): | |||
| @@ -138,7 +138,7 @@ def test_python_sampler(): | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("item[image].shape[0]: {}, item[label].item(): {}" | |||
| .format(item["image"].shape[0], item["label"].item())) | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| # print(res) | |||
| return res | |||
| @@ -167,7 +167,7 @@ def test_python_sampler(): | |||
| def test_subset_sampler(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_samples, start_index, subset_size): | |||
| sampler = ds.SubsetSampler(start_index, subset_size) | |||
| @@ -175,7 +175,7 @@ def test_subset_sampler(): | |||
| res = [] | |||
| for item in d.create_dict_iterator(): | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| return res | |||
| @@ -196,7 +196,7 @@ def test_subset_sampler(): | |||
| def test_sampler_chain(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_shards, shard_id): | |||
| sampler = ds.DistributedSampler(num_shards, shard_id, False) | |||
| @@ -209,7 +209,7 @@ def test_sampler_chain(): | |||
| for item in data1.create_dict_iterator(): | |||
| logger.info("item[image].shape[0]: {}, item[label].item(): {}" | |||
| .format(item["image"].shape[0], item["label"].item())) | |||
| res.append(map[(item["image"].shape[0], item["label"].item())]) | |||
| res.append(map_[(item["image"].shape[0], item["label"].item())]) | |||
| return res | |||
| assert test_config(2, 0) == [0, 2, 4] | |||
| @@ -222,7 +222,7 @@ def test_sampler_chain(): | |||
| def test_add_sampler_invalid_input(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| _ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| data1 = ds.ManifestDataset(manifest_file) | |||
| with pytest.raises(TypeError) as info: | |||
| @@ -18,9 +18,8 @@ Testing dataset serialize and deserialize in DE | |||
| import filecmp | |||
| import glob | |||
| import json | |||
| import numpy as np | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as c | |||
| @@ -28,6 +27,8 @@ import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.transforms.vision import Inter | |||
| from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME | |||
| def test_imagefolder(remove_json_files=True): | |||
| """ | |||
| @@ -186,7 +187,7 @@ def test_random_crop(): | |||
| # Serializing into python dictionary | |||
| ds1_dict = ds.serialize(data1) | |||
| # Serializing into json object | |||
| ds1_json = json.dumps(ds1_dict, indent=2) | |||
| _ = json.dumps(ds1_dict, indent=2) | |||
| # Reconstruct dataset pipeline from its serialized form | |||
| data1_1 = ds.deserialize(input_dict=ds1_dict) | |||
| @@ -198,7 +199,7 @@ def test_random_crop(): | |||
| for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(), | |||
| data2.create_dict_iterator()): | |||
| assert np.array_equal(item1['image'], item1_1['image']) | |||
| image2 = item2["image"] | |||
| _ = item2["image"] | |||
| def validate_jsonfile(filepath): | |||
| @@ -221,10 +222,6 @@ def delete_json_files(): | |||
| # Test save load minddataset | |||
| from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \ | |||
| FileWriter, Inter | |||
| def test_minddataset(add_and_remove_cv_file): | |||
| """tutorial for cv minderdataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| @@ -247,7 +244,7 @@ def test_minddataset(add_and_remove_cv_file): | |||
| assert ds1_json == ds2_json | |||
| data = get_data(CV_DIR_NAME) | |||
| _ = get_data(CV_DIR_NAME) | |||
| assert data_set.get_dataset_size() == 5 | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| @@ -152,7 +152,7 @@ def test_shuffle_exception_01(): | |||
| data1 = data1.shuffle(buffer_size=-1) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -170,7 +170,7 @@ def test_shuffle_exception_02(): | |||
| data1 = data1.shuffle(buffer_size=0) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -188,7 +188,7 @@ def test_shuffle_exception_03(): | |||
| data1 = data1.shuffle(buffer_size=1) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -206,7 +206,7 @@ def test_shuffle_exception_05(): | |||
| data1 = data1.shuffle() | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -224,7 +224,7 @@ def test_shuffle_exception_06(): | |||
| data1 = data1.shuffle(buffer_size=False) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -242,7 +242,7 @@ def test_shuffle_exception_07(): | |||
| data1 = data1.shuffle(buffer_size=True) | |||
| sum([1 for _ in data1]) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| @@ -70,7 +70,6 @@ def test_skip_1(): | |||
| buf = [] | |||
| for data in ds1: | |||
| buf.append(data[0][0]) | |||
| assert len(buf) == 0 | |||
| assert buf == [] | |||
| @@ -29,47 +29,47 @@ text_file_data = ["This is a text file.", "Another file.", "Be happy every day." | |||
| def split_with_invalid_inputs(d): | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([]) | |||
| _, _ = d.split([]) | |||
| assert "sizes cannot be empty" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([5, 0.6]) | |||
| _, _ = d.split([5, 0.6]) | |||
| assert "sizes should be list of int or list of float" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([-1, 6]) | |||
| _, _ = d.split([-1, 6]) | |||
| assert "there should be no negative numbers" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([3, 1]) | |||
| _, _ = d.split([3, 1]) | |||
| assert "sum of split sizes 4 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([5, 1]) | |||
| _, _ = d.split([5, 1]) | |||
| assert "sum of split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25]) | |||
| _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25]) | |||
| assert "sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([-0.5, 0.5]) | |||
| _, _ = d.split([-0.5, 0.5]) | |||
| assert "there should be no numbers outside the range [0, 1]" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([1.5, 0.5]) | |||
| _, _ = d.split([1.5, 0.5]) | |||
| assert "there should be no numbers outside the range [0, 1]" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([0.5, 0.6]) | |||
| _, _ = d.split([0.5, 0.6]) | |||
| assert "percentages do not sum up to 1" in str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| s1, s2 = d.split([0.3, 0.6]) | |||
| _, _ = d.split([0.3, 0.6]) | |||
| assert "percentages do not sum up to 1" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([0.05, 0.95]) | |||
| _, _ = d.split([0.05, 0.95]) | |||
| assert "percentage 0.05 is too small" in str(info.value) | |||
| @@ -79,7 +79,7 @@ def test_unmappable_invalid_input(): | |||
| d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([4, 1]) | |||
| _, _ = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| @@ -273,7 +273,7 @@ def test_mappable_invalid_input(): | |||
| d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0) | |||
| with pytest.raises(RuntimeError) as info: | |||
| s1, s2 = d.split([4, 1]) | |||
| _, _ = d.split([4, 1]) | |||
| assert "dataset should not be sharded before split" in str(info.value) | |||
| @@ -28,8 +28,8 @@ class Augment: | |||
| def __init__(self, loss): | |||
| self.loss = loss | |||
| def preprocess(self, input): | |||
| return input | |||
| def preprocess(self, input_): | |||
| return input_ | |||
| def update(self, data): | |||
| self.loss = data["loss"] | |||
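The `preprocess` rename follows the PEP 8 convention of adding a single trailing underscore when a parameter would shadow a builtin; with the old spelling, the builtin `input()` was unreachable inside the method. A sketch of the hazard and the fix:

```python
def shadowed(input):
    # Inside this scope `input` is the argument, so the builtin input()
    # can no longer be called; pylint flags this as redefined-builtin.
    return input

def renamed(input_):    # trailing underscore sidesteps the collision
    return input_
```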
| @@ -143,7 +143,7 @@ def test_multiple_iterators(): | |||
| dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) | |||
| dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) | |||
| dataset = dataset.batch(batch_size, drop_remainder=True) | |||
| # 2nd dataset | |||
| # 2nd dataset | |||
| dataset2 = ds.GeneratorDataset(gen, column_names=["input"]) | |||
| aug = Augment(0) | |||
| @@ -175,7 +175,7 @@ def test_sync_exception_01(): | |||
| try: | |||
| dataset = dataset.shuffle(shuffle_size) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| assert "shuffle" in str(e) | |||
| dataset = dataset.batch(batch_size) | |||
| @@ -197,7 +197,7 @@ def test_sync_exception_02(): | |||
| try: | |||
| dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| assert "name" in str(e) | |||
| dataset = dataset.batch(batch_size) | |||
| @@ -46,7 +46,7 @@ def test_take_01(): | |||
| data1 = data1.take(1) | |||
| data1 = data1.repeat(2) | |||
| # Here i refers to index, d refers to data element | |||
| # Here i refers to index, d refers to data element | |||
| for _, d in enumerate(data1): | |||
| assert d[0][0] == 0 | |||
| @@ -63,7 +63,7 @@ def test_take_02(): | |||
| data1 = data1.take(2) | |||
| data1 = data1.repeat(2) | |||
| # Here i refers to index, d refers to data element | |||
| # Here i refers to index, d refers to data element | |||
| for i, d in enumerate(data1): | |||
| assert i % 2 == d[0][0] | |||
| @@ -80,7 +80,7 @@ def test_take_03(): | |||
| data1 = data1.take(3) | |||
| data1 = data1.repeat(2) | |||
| # Here i refers to index, d refers to data element | |||
| # Here i refers to index, d refers to data element | |||
| for i, d in enumerate(data1): | |||
| assert i % 3 == d[0][0] | |||
| @@ -12,15 +12,13 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import mindspore._c_dataengine as cde | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore.dataset.text import to_str, to_bytes | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore._c_dataengine as cde | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.dataset.text import to_str | |||
| # pylint: disable=comparison-with-itself | |||
| def test_basic(): | |||
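The header shuffle above follows PEP 8 import grouping (standard library first, then third-party packages, then project imports) and drops the unused `to_bytes`. A sketch of the target layout for a test module like this one:

```python
import os                                   # 1) standard library

import numpy as np                          # 2) third-party
import pytest

import mindspore._c_dataengine as cde       # 3) first-party / project
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.dataset.text import to_str
```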
| @@ -34,7 +32,7 @@ def compare(strings): | |||
| arr = np.array(strings, dtype='S') | |||
| def gen(): | |||
| yield arr, | |||
| yield (arr,) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| @@ -50,7 +48,7 @@ def test_generator(): | |||
| def test_batching_strings(): | |||
| def gen(): | |||
| yield np.array(["ab", "cde", "121"], dtype='S'), | |||
| yield (np.array(["ab", "cde", "121"], dtype='S'),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10) | |||
| @@ -62,7 +60,7 @@ def test_batching_strings(): | |||
| def test_map(): | |||
| def gen(): | |||
| yield np.array(["ab cde 121"], dtype='S'), | |||
| yield (np.array(["ab cde 121"], dtype='S'),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| @@ -79,7 +77,7 @@ def test_map(): | |||
| def test_map2(): | |||
| def gen(): | |||
| yield np.array(["ab cde 121"], dtype='S'), | |||
| yield (np.array(["ab cde 121"], dtype='S'),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
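These `gen()` fixes make the single-column row an explicit one-element tuple; `GeneratorDataset` consumes each yielded tuple as one row with one entry per column, and `yield (np.array(...),)` is harder to misread than a bare trailing comma. A numpy-only sketch:

```python
import numpy as np

def gen():
    # Explicit 1-tuple: one row, one column.
    yield (np.array(["ab", "cde", "121"], dtype='S'),)

row = next(gen())
assert isinstance(row, tuple) and len(row) == 1
```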
| @@ -215,7 +215,7 @@ def test_case_tf_file_no_schema_columns_list(): | |||
| assert row["col_sint16"] == [-32768] | |||
| with pytest.raises(KeyError) as info: | |||
| a = row["col_sint32"] | |||
| _ = row["col_sint32"] | |||
| assert "col_sint32" in str(info.value) | |||
| @@ -234,7 +234,7 @@ def test_tf_record_schema_columns_list(): | |||
| assert row["col_sint16"] == [-32768] | |||
| with pytest.raises(KeyError) as info: | |||
| a = row["col_sint32"] | |||
| _ = row["col_sint32"] | |||
| assert "col_sint32" in str(info.value) | |||
| @@ -246,7 +246,7 @@ def test_case_invalid_files(): | |||
| data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) | |||
| with pytest.raises(RuntimeError) as info: | |||
| row = data.create_dict_iterator().get_next() | |||
| _ = data.create_dict_iterator().get_next() | |||
| assert "cannot be opened" in str(info.value) | |||
| assert "not valid tfrecord files" in str(info.value) | |||
| assert valid_file not in str(info.value) | |||
| @@ -123,7 +123,7 @@ def test_to_type_03(): | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Numpy" in str(e) | |||
| @@ -145,7 +145,7 @@ def test_to_type_04(): | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "missing" in str(e) | |||
| @@ -167,7 +167,7 @@ def test_to_type_05(): | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "data type" in str(e) | |||
| @@ -59,7 +59,7 @@ def test_uniform_augment(plot=False, num_ops=2): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -87,7 +87,7 @@ def test_uniform_augment(plot=False, num_ops=2): | |||
| ds_ua = ds_ua.batch(512) | |||
| for idx, (image, label) in enumerate(ds_ua): | |||
| for idx, (image, _) in enumerate(ds_ua): | |||
| if idx == 0: | |||
| images_ua = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -122,7 +122,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2): | |||
| ds_original = ds_original.batch(512) | |||
| for idx, (image, label) in enumerate(ds_original): | |||
| for idx, (image, _) in enumerate(ds_original): | |||
| if idx == 0: | |||
| images_original = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
| @@ -149,7 +149,7 @@ def test_cpp_uniform_augment(plot=False, num_ops=2): | |||
| ds_ua = ds_ua.batch(512) | |||
| for idx, (image, label) in enumerate(ds_ua): | |||
| for idx, (image, _) in enumerate(ds_ua): | |||
| if idx == 0: | |||
| images_ua = np.transpose(image, (0, 2, 3, 1)) | |||
| else: | |||
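The four loops above stop binding the unused `label` half of each batch; `_` keeps the unpacking shape while marking the value as deliberately dropped. A sketch with hypothetical stand-in batches:

```python
import numpy as np

# Hypothetical (image, label) batches in NCHW layout.
batches = [(np.zeros((2, 3, 4, 4)), np.zeros(2))]

images = None
for idx, (image, _) in enumerate(batches):    # label discarded on purpose
    nhwc = np.transpose(image, (0, 2, 3, 1))  # NCHW -> NHWC, as in the tests
    images = nhwc if idx == 0 else np.append(images, nhwc, axis=0)
assert images.shape == (2, 4, 4, 3)
```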
| @@ -180,9 +180,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||
| F.Invert()] | |||
| try: | |||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "operations" in str(e) | |||
| @@ -200,9 +200,9 @@ def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | |||
| C.RandomRotation(degrees=45)] | |||
| try: | |||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_ops" in str(e) | |||
| @@ -220,9 +220,9 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): | |||
| C.RandomRotation(degrees=45)] | |||
| try: | |||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_ops" in str(e) | |||
| @@ -239,9 +239,9 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | |||
| C.RandomRotation(degrees=45)] | |||
| try: | |||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "integer" in str(e) | |||
| @@ -250,7 +250,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): | |||
| Test UniformAugment with greater crop size | |||
| """ | |||
| logger.info("Test CPP UniformAugment with random_crop bad input") | |||
| batch_size=2 | |||
| batch_size = 2 | |||
| cifar10_dir = "../data/dataset/testCifar10Data" | |||
| ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3] | |||
| @@ -266,9 +266,9 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): | |||
| ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1) | |||
| num_batches = 0 | |||
| try: | |||
| for data in ds1.create_dict_iterator(): | |||
| for _ in ds1.create_dict_iterator(): | |||
| num_batches += 1 | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| assert "Crop size" in str(e) | |||
| @@ -75,6 +75,7 @@ def test_variable_size_batch(): | |||
| return batchInfo.get_epoch_num() + 1 | |||
| def simple_copy(colList, batchInfo): | |||
| _ = batchInfo | |||
| return ([np.copy(arr) for arr in colList],) | |||
| def test_repeat_batch(gen_num, r, drop, func, res): | |||
| @@ -186,6 +187,7 @@ def test_batch_multi_col_map(): | |||
| yield (np.array([i]), np.array([i ** 2])) | |||
| def col1_col2_add_num(col1, col2, batchInfo): | |||
| _ = batchInfo | |||
| return ([[np.copy(arr + 100) for arr in col1], | |||
| [np.copy(arr + 300) for arr in col2]]) | |||
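The inserted `_ = batchInfo` lines acknowledge an argument the per-batch callback signature requires but the body never reads, silencing unused-argument warnings without changing the interface. A sketch with snake_case hypothetical names:

```python
import numpy as np

def simple_copy(col_list, batch_info):
    _ = batch_info   # required by the callback signature, unused here
    return ([np.copy(arr) for arr in col_list],)

out, = simple_copy([np.array([1, 2])], batch_info=None)
assert out[0].tolist() == [1, 2]
```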
| @@ -287,11 +289,11 @@ def test_exception(): | |||
| def bad_batch_size(batchInfo): | |||
| raise StopIteration | |||
| return batchInfo.get_batch_num() | |||
| # return batchInfo.get_batch_num() | |||
| def bad_map_func(col, batchInfo): | |||
| raise StopIteration | |||
| return (col,) | |||
| # return (col,) | |||
| data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size) | |||
| try: | |||
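Commenting out the `return` statements is the minimal dead-code fix: nothing after an unconditional `raise` can execute, and pylint flags it, while the comment keeps the original intent visible in these deliberately failing callbacks.

```python
# Sketch: the return survives only as a comment because it is unreachable.
def bad_batch_size(batch_info):
    raise StopIteration
    # return batch_info.get_batch_num()
```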
| @@ -143,7 +143,7 @@ def test_zip_exception_01(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -164,7 +164,7 @@ def test_zip_exception_02(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -185,7 +185,7 @@ def test_zip_exception_03(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -205,7 +205,7 @@ def test_zip_exception_04(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -226,7 +226,7 @@ def test_zip_exception_05(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -246,7 +246,7 @@ def test_zip_exception_06(): | |||
| num_iter += 1 | |||
| logger.info("Number of data in zipped dataz: {}".format(num_iter)) | |||
| except BaseException as e: | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| @@ -300,16 +300,16 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| with pytest.raises(ParamValueError) as err: | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_id(0, "0", 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_id(0, 0, "b") | |||
| with pytest.raises(ParamValueError) as err: | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_name("822", "e", 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_name("822", 0, "qwer") | |||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | |||
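`pytest.raises` is now used without `as err` wherever the bound exception was never inspected; binding is only worthwhile when an assertion on the exception follows. A sketch against a hypothetical `read_page` validator:

```python
import pytest

def read_page(page, size):
    # Hypothetical stand-in for the MindPage parameter validation above.
    if not isinstance(page, int) or not isinstance(size, int):
        raise ValueError("page and size must be integers")
    return []

with pytest.raises(ValueError):           # no `as err`: message not checked
    read_page(0, "b")

with pytest.raises(ValueError) as info:   # bind only when asserting on it
    read_page("0", 1)
assert "integers" in str(info.value)
```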
| @@ -330,14 +330,14 @@ def test_mindpage_filename_not_exist(fixture_cv_file): | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| with pytest.raises(MRMFetchDataError) as err: | |||
| with pytest.raises(MRMFetchDataError): | |||
| reader.read_at_page_by_id(9999, 0, 1) | |||
| with pytest.raises(MRMFetchDataError) as err: | |||
| with pytest.raises(MRMFetchDataError): | |||
| reader.read_at_page_by_name("abc.jpg", 0, 1) | |||
| with pytest.raises(ParamValueError) as err: | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_name(1, 0, 1) | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| _ = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| @@ -14,10 +14,9 @@ | |||
| """test mnist to mindrecord tool""" | |||
| import gzip | |||
| import os | |||
| import pytest | |||
| import numpy as np | |||
| import cv2 | |||
| import pytest | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import FileReader | |||
| @@ -14,12 +14,12 @@ | |||
| # ============================================================================ | |||
| """utils for test""" | |||
| import collections | |||
| import json | |||
| import numpy as np | |||
| import os | |||
| import re | |||
| import string | |||
| import collections | |||
| import json | |||
| import numpy as np | |||
| from mindspore import log as logger | |||
| @@ -185,7 +185,7 @@ def get_nlp_data(dir_name, vocab_file, num): | |||
| """ | |||
| if not os.path.isdir(dir_name): | |||
| raise IOError("Directory {} does not exist".format(dir_name)) | |||
| for root, dirs, files in os.walk(dir_name): | |||
| for root, _, files in os.walk(dir_name): | |||
| for index, file_name_extension in enumerate(files): | |||
| if index < num: | |||
| file_path = os.path.join(root, file_name_extension) | |||
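Swapping `dirs` for `_` keeps the three-part `os.walk` unpacking while recording that directory names are never consulted. A self-contained sketch of the same traversal shape:

```python
import os

def first_files(dir_name, num):
    """Sketch: collect up to num file paths under dir_name."""
    collected = []
    for root, _, files in os.walk(dir_name):   # dirnames intentionally unused
        for name in files:
            if len(collected) == num:
                return collected
            collected.append(os.path.join(root, name))
    return collected
```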