From: @lixiachen Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -30,6 +30,9 @@ | |||
| #endif | |||
| #include "minddata/dataset/kernels/ir/validators.h" | |||
| #ifdef ENABLE_PYTHON | |||
| #include "minddata/dataset/kernels/py_func_op.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -78,7 +81,12 @@ Status OneHotOperation::ValidateParams() { | |||
| std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | |||
| // PreBuiltOperation | |||
| PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {} | |||
| PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) { | |||
| #ifdef ENABLE_PYTHON | |||
| auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op); | |||
| if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true; | |||
| #endif | |||
| } | |||
| Status PreBuiltOperation::ValidateParams() { return Status::OK(); } | |||
| @@ -129,5 +129,12 @@ Status PyFuncOp::to_json(nlohmann::json *out_json) { | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| bool PyFuncOp::IsRandom() { | |||
| bool random = true; | |||
| if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false) | |||
| random = false; | |||
| return random; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -51,6 +51,10 @@ class PyFuncOp : public TensorOp { | |||
| std::string Name() const override { return kPyFuncOp; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Check whether this pyfunc op is random (i.e. not marked as deterministic) | |||
| /// \return True if this pyfunc op is random, false if its Python callable is marked with random == False | |||
| bool IsRandom(); | |||
| private: | |||
| py::function py_func_ptr_; | |||
| DataType::Type output_type_; | |||
| @@ -552,6 +552,7 @@ class PythonTokenizer: | |||
| @check_python_tokenizer | |||
| def __init__(self, tokenizer): | |||
| self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)') | |||
| self.random = False | |||
| def __call__(self, in_array): | |||
| in_array = to_str(in_array) | |||
| @@ -21,6 +21,11 @@ from .validators import check_one_hot_op, check_compose_list, check_random_apply | |||
| from . import py_transforms_util as util | |||
def not_random(function):
    """
    Mark a callable as not random by setting its ``random`` attribute to False.

    Args:
        function: The callable to tag. It is modified in place.

    Returns:
        The same callable that was passed in, so this works as a decorator or a plain call.
    """
    setattr(function, "random", False)
    return function
| class OneHotOp: | |||
| """ | |||
| Apply one hot encoding transformation to the input label, make label be more smoothing and continuous. | |||
| @@ -42,6 +47,7 @@ class OneHotOp: | |||
| def __init__(self, num_classes, smoothing_rate=0.0): | |||
| self.num_classes = num_classes | |||
| self.smoothing_rate = smoothing_rate | |||
| self.random = False | |||
| def __call__(self, label): | |||
| """ | |||
| @@ -114,6 +120,8 @@ class Compose: | |||
| @check_compose_list | |||
| def __init__(self, transforms): | |||
| self.transforms = transforms | |||
| if all(hasattr(transform, "random") and not transform.random for transform in self.transforms): | |||
| self.random = False | |||
| @check_compose_call | |||
| def __call__(self, *args): | |||
| @@ -45,6 +45,11 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant', | |||
| Border.SYMMETRIC: 'symmetric'} | |||
def not_random(function):
    """
    Mark a callable as not random by setting its ``random`` attribute to False.

    Args:
        function: The callable to tag. It is modified in place.

    Returns:
        The same callable that was passed in, so this works as a decorator or a plain call.
    """
    setattr(function, "random", False)
    return function
| class ToTensor: | |||
| """ | |||
| Convert the input NumPy image array or PIL image of shape (H, W, C) to a NumPy ndarray of shape (C, H, W). | |||
| @@ -70,6 +75,7 @@ class ToTensor: | |||
| def __init__(self, output_type=np.float32): | |||
| self.output_type = output_type | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -105,6 +111,7 @@ class ToType: | |||
| def __init__(self, output_type): | |||
| self.output_type = output_type | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -132,6 +139,9 @@ class HWC2CHW: | |||
| ... input_columns="image") | |||
| """ | |||
| def __init__(self): | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| @@ -160,6 +170,9 @@ class ToPIL: | |||
| ... input_columns="image") | |||
| """ | |||
| def __init__(self): | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| @@ -187,6 +200,9 @@ class Decode: | |||
| ... input_columns="image") | |||
| """ | |||
| def __init__(self): | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| @@ -227,6 +243,7 @@ class Normalize: | |||
| def __init__(self, mean, std): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -271,6 +288,7 @@ class NormalizePad: | |||
| self.mean = mean | |||
| self.std = std | |||
| self.dtype = dtype | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -456,6 +474,7 @@ class Resize: | |||
| def __init__(self, size, interpolation=Inter.BILINEAR): | |||
| self.size = size | |||
| self.interpolation = DE_PY_INTER_MODE[interpolation] | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -550,6 +569,7 @@ class CenterCrop: | |||
| @check_crop | |||
| def __init__(self, size): | |||
| self.size = size | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -700,6 +720,7 @@ class FiveCrop: | |||
| @check_crop | |||
| def __init__(self, size): | |||
| self.size = size | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -744,6 +765,7 @@ class TenCrop: | |||
| size = (size, size) | |||
| self.size = size | |||
| self.use_vertical_flip = use_vertical_flip | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -781,6 +803,7 @@ class Grayscale: | |||
| @check_num_channels | |||
| def __init__(self, num_output_channels=1): | |||
| self.num_output_channels = num_output_channels | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -884,6 +907,7 @@ class Pad: | |||
| self.padding = padding | |||
| self.fill_value = fill_value | |||
| self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -1030,6 +1054,7 @@ class Cutout: | |||
| def __init__(self, length, num_patches=1): | |||
| self.length = length | |||
| self.num_patches = num_patches | |||
| self.random = False | |||
| def __call__(self, np_img): | |||
| """ | |||
| @@ -1087,6 +1112,7 @@ class LinearTransformation: | |||
| def __init__(self, transformation_matrix, mean_vector): | |||
| self.transformation_matrix = transformation_matrix | |||
| self.mean_vector = mean_vector | |||
| self.random = False | |||
| def __call__(self, np_img): | |||
| """ | |||
| @@ -1229,6 +1255,7 @@ class MixUp: | |||
| self.batch_size = batch_size | |||
| self.alpha = alpha | |||
| self.is_single = is_single | |||
| self.random = False | |||
| def __call__(self, image, label): | |||
| """ | |||
| @@ -1268,6 +1295,7 @@ class RgbToHsv: | |||
| def __init__(self, is_hwc=False): | |||
| self.is_hwc = is_hwc | |||
| self.random = False | |||
| def __call__(self, rgb_imgs): | |||
| """ | |||
| @@ -1304,6 +1332,7 @@ class HsvToRgb: | |||
| def __init__(self, is_hwc=False): | |||
| self.is_hwc = is_hwc | |||
| self.random = False | |||
| def __call__(self, hsv_imgs): | |||
| """ | |||
| @@ -1414,6 +1443,7 @@ class AutoContrast: | |||
| def __init__(self, cutoff=0.0, ignore=None): | |||
| self.cutoff = cutoff | |||
| self.ignore = ignore | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -1443,6 +1473,9 @@ class Invert: | |||
| ... input_columns="image") | |||
| """ | |||
| def __init__(self): | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| @@ -1472,6 +1505,9 @@ class Equalize: | |||
| """ | |||
| def __init__(self): | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| Call method. | |||
| @@ -1516,6 +1552,7 @@ class UniformAugment: | |||
| def __init__(self, transforms, num_ops=2): | |||
| self.transforms = transforms | |||
| self.num_ops = num_ops | |||
| self.random = False | |||
| def __call__(self, img): | |||
| """ | |||
| @@ -318,6 +318,9 @@ HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1 | |||
| HandleRcExit $? 0 0 | |||
| PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1 | |||
| HandleRcExit $? 0 0 | |||
| for i in $(seq 1 3) | |||
| do | |||
| test_name="test_cache_nomap_multiple_cache${i}" | |||
| @@ -20,6 +20,7 @@ import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.vision.c_transforms as c_vision | |||
| import mindspore.dataset.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| from util import save_and_check_md5 | |||
| @@ -481,7 +482,7 @@ def test_cache_map_failure7(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| data = ds.GeneratorDataset(generator_1d, ["data"]) | |||
| data = data.map((lambda x: x), ["data"], cache=some_cache) | |||
| data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache) | |||
| data = data.repeat(4) | |||
| with pytest.raises(RuntimeError) as e: | |||
| @@ -17,11 +17,13 @@ Testing cache operator with non-mappable datasets | |||
| """ | |||
| import os | |||
| import itertools | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| import mindspore.dataset.vision.c_transforms as c_vision | |||
| import mindspore.dataset.vision.py_transforms as py_vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] | |||
| @@ -41,6 +43,9 @@ CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json' | |||
| CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' | |||
| TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" | |||
| PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"] | |||
| PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json" | |||
| GENERATE_GOLDEN = False | |||
| @@ -1633,7 +1638,7 @@ def test_cache_nomap_clue2(): | |||
| some_cache = ds.DatasetCache(session_id=session_id, size=0) | |||
| ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) | |||
| ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) | |||
| ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| @@ -1710,7 +1715,7 @@ def test_cache_nomap_csv2(): | |||
| ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) | |||
| ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache) | |||
| ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache) | |||
| num_epoch = 4 | |||
| iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) | |||
| @@ -2124,6 +2129,139 @@ def test_cache_nomap_failure5(): | |||
| logger.info('test_cache_nomap_failure5 Ended.\n') | |||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
    """
    Test cache after a map op that applies python lambda functions.

    This is only allowed when each lambda is wrapped by 'py_vision.not_random';
    otherwise an error is expected.

        Cache
          |
        Map(lambda function1, lambda function2)
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc lambda")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Positive case: every lambda is marked as not random, so caching is accepted.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 12 records in it
    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)

    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    assert num_iter == 12

    # Negative case: an unmarked lambda under a cache must be rejected.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
    """
    Test cache after a map op that applies built-in python transforms (PyFuncs).

    An error is expected if any of the built-in pyfuncs is a random operation.

        Cache
          |
        Map([builtin pyfunc1, builtin pyfunc2])
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc builtin")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Positive case: Decode and ToTensor only.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)

    num_iter = sum(1 for _ in ds1.create_dict_iterator(num_epochs=1))
    assert num_iter == 3

    # Negative case: RandomCrop in the op list must be rejected under a cache.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
                  input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
    """
    Test cache after a map op that applies user-defined python functions.

    This is only allowed when each function is decorated with 'py_vision.not_random';
    otherwise an error is expected.

        Cache
          |
        Map([function1, function2])
          |
        TFRecord
    """

    @py_vision.not_random
    def not_random_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    # Same body as above, but deliberately left without the not_random marker.
    def normal_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    logger.info("Test cache nomap pyfunc function")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Positive case: only decorated functions are used.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)

    num_iter = sum(1 for _ in ds1.create_dict_iterator(num_epochs=1))
    assert num_iter == 3

    # Negative case: one undecorated function must be rejected under a cache.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_function Ended.\n")
| if __name__ == '__main__': | |||
| # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' | |||
| # since cache server is required to be brought up first | |||
| @@ -2180,3 +2318,6 @@ if __name__ == '__main__': | |||
| test_cache_nomap_failure3() | |||
| test_cache_nomap_failure4() | |||
| test_cache_nomap_failure5() | |||
| test_cache_nomap_pyfunc_lambda() | |||
| test_cache_nomap_pyfunc_builtin() | |||
| test_cache_nomap_pyfunc_function() | |||