|
|
|
@@ -14,11 +14,12 @@ |
|
|
|
# ============================================================================== |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore import log as logger |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
# just a basic test with parallel random data op |
|
|
|
def test_randomdataset_basic1(): |
|
|
|
print("Test randomdataset basic") |
|
|
|
logger.info("Test randomdataset basic") |
|
|
|
|
|
|
|
schema = ds.Schema() |
|
|
|
schema.add_column('image', de_type=mstype.uint8, shape=[2]) |
|
|
|
@@ -31,16 +32,16 @@ def test_randomdataset_basic1(): |
|
|
|
num_iter = 0 |
|
|
|
for data in ds1.create_dict_iterator(): # each data is a dictionary |
|
|
|
# in this example, each dictionary has keys "image" and "label" |
|
|
|
print("{} image: {}".format(num_iter, data["image"])) |
|
|
|
print("{} label: {}".format(num_iter, data["label"])) |
|
|
|
logger.info("{} image: {}".format(num_iter, data["image"])) |
|
|
|
logger.info("{} label: {}".format(num_iter, data["label"])) |
|
|
|
num_iter += 1 |
|
|
|
|
|
|
|
print("Number of data in ds1: ", num_iter) |
|
|
|
logger.info("Number of data in ds1: ", num_iter) |
|
|
|
assert(num_iter == 200) |
|
|
|
|
|
|
|
# Another simple test |
|
|
|
def test_randomdataset_basic2(): |
|
|
|
print("Test randomdataset basic 2") |
|
|
|
logger.info("Test randomdataset basic 2") |
|
|
|
|
|
|
|
schema = ds.Schema() |
|
|
|
schema.add_column('image', de_type=mstype.uint8, shape=[640,480,3]) # 921600 bytes (a bit less than 1 MB per image) |
|
|
|
@@ -55,16 +56,16 @@ def test_randomdataset_basic2(): |
|
|
|
num_iter = 0 |
|
|
|
for data in ds1.create_dict_iterator(): # each data is a dictionary |
|
|
|
# in this example, each dictionary has keys "image" and "label" |
|
|
|
#print(data["image"]) |
|
|
|
print("printing the label: {}".format(data["label"])) |
|
|
|
#logger.info(data["image"]) |
|
|
|
logger.info("printing the label: {}".format(data["label"])) |
|
|
|
num_iter += 1 |
|
|
|
|
|
|
|
print("Number of data in ds1: ", num_iter) |
|
|
|
logger.info("Number of data in ds1: ", num_iter) |
|
|
|
assert(num_iter == 40) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_randomdataset_basic1() |
|
|
|
test_randomdataset_basic2() |
|
|
|
print('test_randomdataset_basic Ended.\n') |
|
|
|
logger.info('test_randomdataset_basic Ended.\n') |
|
|
|
|