From: @lixiachen
Tag: v1.2.0-rc1
@@ -30,6 +30,9 @@
 #endif
 #include "minddata/dataset/kernels/ir/validators.h"
+#ifdef ENABLE_PYTHON
+#include "minddata/dataset/kernels/py_func_op.h"
+#endif

 namespace mindspore {
 namespace dataset {
@@ -78,7 +81,12 @@ Status OneHotOperation::ValidateParams() {
 std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }

 // PreBuiltOperation
-PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {}
+PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {
+#ifdef ENABLE_PYTHON
+  auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op);
+  if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true;
+#endif
+}

 Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
@@ -129,5 +129,12 @@ Status PyFuncOp::to_json(nlohmann::json *out_json) {
   *out_json = args;
   return Status::OK();
 }

+bool PyFuncOp::IsRandom() {
+  bool random = true;
+  if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false)
+    random = false;
+  return random;
+}
 }  // namespace dataset
 }  // namespace mindspore
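For reference, the contract that `PyFuncOp::IsRandom()` encodes is simple: a Python callable is treated as random unless it explicitly carries a `random` attribute equal to `False`. The sketch below mirrors that logic on the Python side; `is_random` and `double` are illustrative names, not part of the MindSpore API.

def is_random(py_func):
    # Treated as random unless the callable explicitly carries random == False,
    # matching the hasattr/attr check in PyFuncOp::IsRandom().
    return not (hasattr(py_func, "random") and py_func.random is False)

def double(x):
    return x + x

assert is_random(double)   # unmarked callables are assumed random
double.random = False      # what the not_random decorator (added below) does
assert not is_random(double)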
@@ -51,6 +51,10 @@ class PyFuncOp : public TensorOp {
   std::string Name() const override { return kPyFuncOp; }

   Status to_json(nlohmann::json *out_json) override;

+  /// \brief Check whether this pyfunc op is random (i.e. not deterministic)
+  /// \return True if this pyfunc op is random
+  bool IsRandom();
+
  private:
   py::function py_func_ptr_;
   DataType::Type output_type_;
@@ -552,6 +552,7 @@ class PythonTokenizer:
     @check_python_tokenizer
     def __init__(self, tokenizer):
         self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
+        self.random = False

     def __call__(self, in_array):
         in_array = to_str(in_array)
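With `self.random = False` set in its constructor, `PythonTokenizer` now advertises itself as deterministic, so a map op that uses it remains cache-eligible (assuming the user-supplied tokenizer itself is deterministic). A minimal usage sketch; `my_tokenizer` is illustrative:

import mindspore.dataset.text as text

def my_tokenizer(line):
    # deterministic: the same line always produces the same tokens
    return line.split()

tokenizer = text.PythonTokenizer(my_tokenizer)
assert tokenizer.random is False  # set by __init__ after this patch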
@@ -21,6 +21,11 @@ from .validators import check_one_hot_op, check_compose_list, check_random_apply
 from . import py_transforms_util as util

+def not_random(function):
+    function.random = False
+    return function

 class OneHotOp:
     """
     Apply one-hot encoding transformation to the input label, making the label smoother and more continuous.
@@ -42,6 +47,7 @@ class OneHotOp:
     def __init__(self, num_classes, smoothing_rate=0.0):
         self.num_classes = num_classes
         self.smoothing_rate = smoothing_rate
+        self.random = False

     def __call__(self, label):
         """
@@ -114,6 +120,8 @@ class Compose:
     @check_compose_list
     def __init__(self, transforms):
         self.transforms = transforms
+        if all(hasattr(transform, "random") and not transform.random for transform in self.transforms):
+            self.random = False

     @check_compose_call
     def __call__(self, *args):
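Note how the pieces fit: `not_random` merely tags a callable, and `Compose` propagates the tag only when every member transform is non-random; a single unmarked member leaves the whole composition treated as random. A short sketch under those assumptions (`add_one` and `unmarked` are illustrative):

from mindspore.dataset.transforms.py_transforms import Compose, not_random

@not_random
def add_one(x):
    return x + 1

def unmarked(x):
    return x * 2

assert getattr(Compose([add_one]), "random", True) is False   # all members marked
assert getattr(Compose([add_one, unmarked]), "random", True)  # one unmarked => random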
@@ -45,6 +45,11 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant',
                      Border.SYMMETRIC: 'symmetric'}

+def not_random(function):
+    function.random = False
+    return function

 class ToTensor:
     """
     Convert the input NumPy image array or PIL image of shape (H, W, C) to a NumPy ndarray of shape (C, H, W).
@@ -70,6 +75,7 @@ class ToTensor:
     def __init__(self, output_type=np.float32):
         self.output_type = output_type
+        self.random = False

     def __call__(self, img):
         """
@@ -105,6 +111,7 @@ class ToType:
     def __init__(self, output_type):
         self.output_type = output_type
+        self.random = False

     def __call__(self, img):
         """
@@ -132,6 +139,9 @@ class HWC2CHW:
         ... input_columns="image")
     """

+    def __init__(self):
+        self.random = False
+
     def __call__(self, img):
         """
         Call method.
@@ -160,6 +170,9 @@ class ToPIL:
         ... input_columns="image")
     """

+    def __init__(self):
+        self.random = False
+
     def __call__(self, img):
         """
         Call method.
@@ -187,6 +200,9 @@ class Decode:
         ... input_columns="image")
     """

+    def __init__(self):
+        self.random = False
+
     def __call__(self, img):
         """
         Call method.
@@ -227,6 +243,7 @@ class Normalize:
     def __init__(self, mean, std):
         self.mean = mean
         self.std = std
+        self.random = False

     def __call__(self, img):
         """
@@ -271,6 +288,7 @@ class NormalizePad:
         self.mean = mean
         self.std = std
         self.dtype = dtype
+        self.random = False

     def __call__(self, img):
         """
@@ -456,6 +474,7 @@ class Resize:
     def __init__(self, size, interpolation=Inter.BILINEAR):
         self.size = size
         self.interpolation = DE_PY_INTER_MODE[interpolation]
+        self.random = False

     def __call__(self, img):
         """
@@ -550,6 +569,7 @@ class CenterCrop:
     @check_crop
     def __init__(self, size):
         self.size = size
+        self.random = False

     def __call__(self, img):
         """
@@ -700,6 +720,7 @@ class FiveCrop:
     @check_crop
     def __init__(self, size):
         self.size = size
+        self.random = False

     def __call__(self, img):
         """
@@ -744,6 +765,7 @@ class TenCrop:
             size = (size, size)
         self.size = size
         self.use_vertical_flip = use_vertical_flip
+        self.random = False

     def __call__(self, img):
         """
@@ -781,6 +803,7 @@ class Grayscale:
     @check_num_channels
     def __init__(self, num_output_channels=1):
         self.num_output_channels = num_output_channels
+        self.random = False

     def __call__(self, img):
         """
@@ -884,6 +907,7 @@ class Pad:
         self.padding = padding
         self.fill_value = fill_value
         self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
+        self.random = False

     def __call__(self, img):
         """
@@ -1030,6 +1054,7 @@ class Cutout:
     def __init__(self, length, num_patches=1):
         self.length = length
         self.num_patches = num_patches
+        self.random = False

     def __call__(self, np_img):
         """
@@ -1087,6 +1112,7 @@ class LinearTransformation:
     def __init__(self, transformation_matrix, mean_vector):
         self.transformation_matrix = transformation_matrix
         self.mean_vector = mean_vector
+        self.random = False

     def __call__(self, np_img):
         """
@@ -1229,6 +1255,7 @@ class MixUp:
         self.batch_size = batch_size
         self.alpha = alpha
         self.is_single = is_single
+        self.random = False

     def __call__(self, image, label):
         """
@@ -1268,6 +1295,7 @@ class RgbToHsv:
     def __init__(self, is_hwc=False):
         self.is_hwc = is_hwc
+        self.random = False

     def __call__(self, rgb_imgs):
         """
@@ -1304,6 +1332,7 @@ class HsvToRgb:
     def __init__(self, is_hwc=False):
         self.is_hwc = is_hwc
+        self.random = False

     def __call__(self, hsv_imgs):
         """
@@ -1414,6 +1443,7 @@ class AutoContrast:
     def __init__(self, cutoff=0.0, ignore=None):
         self.cutoff = cutoff
         self.ignore = ignore
+        self.random = False

     def __call__(self, img):
         """
@@ -1443,6 +1473,9 @@ class Invert:
         ... input_columns="image")
     """

+    def __init__(self):
+        self.random = False
+
     def __call__(self, img):
         """
         Call method.
@@ -1472,6 +1505,9 @@ class Equalize:
     """

+    def __init__(self):
+        self.random = False
+
     def __call__(self, img):
         """
         Call method.
@@ -1516,6 +1552,7 @@ class UniformAugment:
     def __init__(self, transforms, num_ops=2):
         self.transforms = transforms
         self.num_ops = num_ops
+        self.random = False

     def __call__(self, img):
         """
@@ -318,6 +318,9 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1
 HandleRcExit $? 0 0

+PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
+HandleRcExit $? 0 0
+
 for i in $(seq 1 3)
 do
     test_name="test_cache_nomap_multiple_cache${i}"
@@ -20,6 +20,7 @@ import pytest
 import numpy as np
 import mindspore.dataset as ds
 import mindspore.dataset.vision.c_transforms as c_vision
+import mindspore.dataset.vision.py_transforms as py_vision
 from mindspore import log as logger
 from util import save_and_check_md5
@@ -481,7 +482,7 @@ def test_cache_map_failure7():
     some_cache = ds.DatasetCache(session_id=session_id, size=0)
     data = ds.GeneratorDataset(generator_1d, ["data"])
-    data = data.map((lambda x: x), ["data"], cache=some_cache)
+    data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
     data = data.repeat(4)

     with pytest.raises(RuntimeError) as e:
@@ -17,11 +17,13 @@ Testing cache operator with non-mappable datasets
 """
 import os
 import itertools
+import numpy as np
 import pytest
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 import mindspore.dataset.text as text
 import mindspore.dataset.vision.c_transforms as c_vision
+import mindspore.dataset.vision.py_transforms as py_vision
 from mindspore import log as logger

 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@@ -41,6 +43,9 @@ CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
 CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
 TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"

+PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
+PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
+
 GENERATE_GOLDEN = False
@@ -1633,7 +1638,7 @@ def test_cache_nomap_clue2():
     some_cache = ds.DatasetCache(session_id=session_id, size=0)
     ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
-    ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
+    ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache)

     num_epoch = 4
     iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@@ -1710,7 +1715,7 @@ def test_cache_nomap_csv2():
     ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                         column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
-    ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)
+    ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache)

     num_epoch = 4
     iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@@ -2124,6 +2129,139 @@ def test_cache_nomap_failure5():
     logger.info('test_cache_nomap_failure5 Ended.\n')

+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_pyfunc_lambda():
+    """
+    Test cache after a map op with a python lambda function.
+    Only allowed if the lambda function is wrapped by 'py_vision.not_random', otherwise an error will be raised.
+
+       Cache
+         |
+       Map(lambda function1, lambda function2)
+         |
+       TFRecord
+    """
+    logger.info("Test cache nomap pyfunc lambda")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # This dataset has 12 records in it
+    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
+    transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
+    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)
+
+    num_iter = 0
+    for _ in data1.create_dict_iterator(num_epochs=1):
+        num_iter += 1
+    assert num_iter == 12
+
+    other_cache = ds.DatasetCache(session_id=session_id, size=0)
+    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
+    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds2.create_dict_iterator():
+            num_iter += 1
+    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
+
+    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")
+
+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_pyfunc_builtin():
+    """
+    Test cache after a map op with a python builtin PyFunc.
+    An error will be raised if the builtin pyfunc contains a random operation.
+
+       Cache
+         |
+       Map([builtin pyfunc1, builtin pyfunc2])
+         |
+       TFRecord
+    """
+    logger.info("Test cache nomap pyfunc builtin")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # This dataset has 3 records in it only
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator(num_epochs=1):
+        num_iter += 1
+    assert num_iter == 3
+
+    other_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # This dataset has 3 records in it only
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
+                  input_columns=["image"], cache=other_cache)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds2.create_dict_iterator():
+            num_iter += 1
+    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
+
+    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")
+
+
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_pyfunc_function():
+    """
+    Test cache after a map op with a python customized function.
+    Only allowed if the function is decorated with 'py_vision.not_random', otherwise an error will be raised.
+
+       Cache
+         |
+       Map([function1, function2])
+         |
+       TFRecord
+    """
+    @py_vision.not_random
+    def not_random_func(x):
+        return np.ones(x.shape, dtype=x.dtype)
+
+    def normal_func(x):
+        return np.ones(x.shape, dtype=x.dtype)
+
+    logger.info("Test cache nomap pyfunc function")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # This dataset has 3 records in it only
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator(num_epochs=1):
+        num_iter += 1
+    assert num_iter == 3
+
+    other_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # This dataset has 3 records in it only
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)
+
+    with pytest.raises(RuntimeError) as e:
+        num_iter = 0
+        for _ in ds2.create_dict_iterator():
+            num_iter += 1
+    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
+
+    logger.info("test_cache_nomap_pyfunc_function Ended.\n")
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
     # since cache server is required to be brought up first
@@ -2180,3 +2318,6 @@ if __name__ == '__main__':
     test_cache_nomap_failure3()
     test_cache_nomap_failure4()
     test_cache_nomap_failure5()
+    test_cache_nomap_pyfunc_lambda()
+    test_cache_nomap_pyfunc_builtin()
+    test_cache_nomap_pyfunc_function()
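The eligibility rule these tests exercise can be stated compactly: a map op over Python callables may feed a cache only if every callable is explicitly non-random. Below is an illustrative mirror of that rule (not MindSpore source; `map_is_cacheable` is a hypothetical helper, and it covers only the python-callable side of the check):

def map_is_cacheable(operations):
    # Mirrors the check behind the "MapNode containing random operation is not
    # supported as a descendant of cache" error: every callable must carry
    # random == False; anything unmarked is assumed random.
    return all(getattr(op, "random", True) is False for op in operations)

assert not map_is_cacheable([lambda x: x])  # bare lambda: assumed random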