Browse Source

!5385 dataset fixes: Update OneHot API doc; fixup UTs

Merge pull request !5385 from cathwong/ckw_dataset_ut_cleanup8
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ab29dbf98b
12 changed files with 61 additions and 54 deletions
  1. +29
    -29
      mindspore/dataset/engine/validators.py
  2. +1
    -2
      mindspore/dataset/text/validators.py
  3. +2
    -2
      mindspore/dataset/transforms/c_transforms.py
  4. +3
    -2
      mindspore/dataset/transforms/py_transforms.py
  5. +8
    -8
      mindspore/dataset/transforms/vision/validators.py
  6. BIN
      tests/ut/data/dataset/golden/random_color_01_result.npz
  7. BIN
      tests/ut/data/dataset/golden/random_sharpness_cpp_01_result.npz
  8. BIN
      tests/ut/data/dataset/golden/random_sharpness_py_01_result.npz
  9. +7
    -2
      tests/ut/python/dataset/test_five_crop.py
  10. +3
    -3
      tests/ut/python/dataset/test_random_color.py
  11. +6
    -4
      tests/ut/python/dataset/test_random_sharpness.py
  12. +2
    -2
      tests/ut/python/dataset/test_repeat.py

+ 29
- 29
mindspore/dataset/engine/validators.py View File

@@ -36,7 +36,7 @@ from .. import callback


def check_imagefolderdatasetv2(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDatasetV2)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -62,7 +62,7 @@ def check_imagefolderdatasetv2(method):


def check_mnist_cifar_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -85,7 +85,7 @@ def check_mnist_cifar_dataset(method):


def check_manifestdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -112,7 +112,7 @@ def check_manifestdataset(method):


def check_tfrecorddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -138,7 +138,7 @@ def check_tfrecorddataset(method):


def check_vocdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(VOCDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -179,7 +179,7 @@ def check_vocdataset(method):


def check_cocodataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CocoDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -215,7 +215,7 @@ def check_cocodataset(method):


def check_celebadataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CelebADataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -251,7 +251,7 @@ def check_celebadataset(method):


def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
"""A wrapper that wraps a parameter checker around the saved operator."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -271,7 +271,7 @@ def check_save(method):


def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -303,7 +303,7 @@ def check_minddataset(method):


def check_generatordataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -369,7 +369,7 @@ def check_generatordataset(method):


def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -794,7 +794,7 @@ def check_add_column(method):


def check_cluedataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -824,7 +824,7 @@ def check_cluedataset(method):


def check_csvdataset(method):
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -871,7 +871,7 @@ def check_csvdataset(method):


def check_textfiledataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -964,7 +964,7 @@ def check_gnn_graphdata(method):


def check_gnn_get_all_nodes(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -977,7 +977,7 @@ def check_gnn_get_all_nodes(method):


def check_gnn_get_all_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_edges` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -990,7 +990,7 @@ def check_gnn_get_all_edges(method):


def check_gnn_get_nodes_from_edges(method):
"""A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1003,7 +1003,7 @@ def check_gnn_get_nodes_from_edges(method):


def check_gnn_get_all_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1018,7 +1018,7 @@ def check_gnn_get_all_neighbors(method):


def check_gnn_get_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1046,7 +1046,7 @@ def check_gnn_get_sampled_neighbors(method):


def check_gnn_get_neg_sampled_neighbors(method):
"""A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1062,7 +1062,7 @@ def check_gnn_get_neg_sampled_neighbors(method):


def check_gnn_random_walk(method):
"""A wrapper that wraps a parameter checker to the GNN `random_walk` function."""
"""A wrapper that wraps a parameter checker around the GNN `random_walk` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1110,7 +1110,7 @@ def check_aligned_list(param, param_name, member_type):


def check_gnn_get_node_feature(method):
"""A wrapper that wraps a parameter checker to the GNN `get_node_feature` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1132,7 +1132,7 @@ def check_gnn_get_node_feature(method):


def check_gnn_get_edge_feature(method):
"""A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function."""
"""A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1154,7 +1154,7 @@ def check_gnn_get_edge_feature(method):


def check_numpyslicesdataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -1195,17 +1195,17 @@ def check_numpyslicesdataset(method):


def check_paddeddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(PaddedDataset)."""
"""A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""

@wraps(method)
def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs)

paddedSamples = param_dict.get("padded_samples")
if not paddedSamples:
padded_samples = param_dict.get("padded_samples")
if not padded_samples:
raise ValueError("Argument padded_samples cannot be empty")
type_check(paddedSamples, (list,), "padded_samples")
type_check(paddedSamples[0], (dict,), "padded_element")
type_check(padded_samples, (list,), "padded_samples")
type_check(padded_samples[0], (dict,), "padded_element")
return method(self, *args, **kwargs)

return new_method

+ 1
- 2
mindspore/dataset/text/validators.py View File

@@ -328,7 +328,7 @@ def check_from_dataset(method):
return new_method

def check_slidingwindow(method):
"""A wrapper that wrap a parameter checker to the original function(sliding window operation)."""
"""A wrapper that wraps a parameter checker to the original function(sliding window operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -496,4 +496,3 @@ def check_save_model(method):
return method(self, *args, **kwargs)

return new_method

+ 2
- 2
mindspore/dataset/transforms/c_transforms.py View File

@@ -31,8 +31,8 @@ class OneHot(cde.OneHotOp):
Tensor operation to apply one hot encoding.

Args:
num_classes (int): Number of classes of the label
it should be bigger than largest label number in dataset.
num_classes (int): Number of classes of the label.
It should be larger than the largest label number in the dataset.

Raises:
RuntimeError: feature size is bigger than num_classes.


+ 3
- 2
mindspore/dataset/transforms/py_transforms.py View File

@@ -27,8 +27,9 @@ class OneHotOp:
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.

Args:
num_classes (int): Num class of object in dataset, type is int and value over 0.
smoothing_rate (float): The adjustable Hyper parameter decides the label smoothing level , 0.0 means not do it.
num_classes (int): Number of classes of objects in dataset. Value must be larger than 0.
smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level.
(Default=0.0 means no smoothing is applied.)
"""

@check_one_hot_op


+ 8
- 8
mindspore/dataset/transforms/vision/validators.py View File

@@ -152,7 +152,7 @@ def check_erasing_value(value):


def check_crop(method):
"""A wrapper that wraps a parameter checker to the original function(crop operation)."""
"""A wrapper that wraps a parameter checker around the original function(crop operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -165,7 +165,7 @@ def check_crop(method):


def check_posterize(method):
""""A wrapper that wraps a parameter checker to the original function(posterize operation)."""
"""A wrapper that wraps a parameter checker around the original function(posterize operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -187,7 +187,7 @@ def check_posterize(method):


def check_resize_interpolation(method):
"""A wrapper that wraps a parameter checker to the original function(resize interpolation operation)."""
"""A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -202,7 +202,7 @@ def check_resize_interpolation(method):


def check_resize(method):
"""A wrapper that wraps a parameter checker to the original function(resize operation)."""
"""A wrapper that wraps a parameter checker around the original function(resize operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -235,7 +235,7 @@ def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)


def check_random_resize_crop(method):
"""A wrapper that wraps a parameter checker to the original function(random resize crop operation)."""
"""A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -250,7 +250,7 @@ def check_random_resize_crop(method):


def check_prob(method):
"""A wrapper that wraps a parameter checker(check the probability) to the original function."""
"""A wrapper that wraps a parameter checker (to confirm probability) around the original function."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -264,7 +264,7 @@ def check_prob(method):


def check_normalize_c(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in C++)."""
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""

@wraps(method)
def new_method(self, *args, **kwargs):
@@ -277,7 +277,7 @@ def check_normalize_c(method):


def check_normalize_py(method):
"""A wrapper that wraps a parameter checker to the original function(normalize operation written in Python)."""
"""A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""

@wraps(method)
def new_method(self, *args, **kwargs):


BIN
tests/ut/data/dataset/golden/random_color_01_result.npz View File


BIN
tests/ut/data/dataset/golden/random_sharpness_cpp_01_result.npz View File


BIN
tests/ut/data/dataset/golden/random_sharpness_py_01_result.npz View File


+ 7
- 2
tests/ut/python/dataset/test_five_crop.py View File

@@ -86,8 +86,13 @@ def test_five_crop_error_msg():
transform = vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())

with pytest.raises(RuntimeError):
data.create_tuple_iterator().__next__()
with pytest.raises(RuntimeError) as info:
for _ in data:
pass
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"

# error msg comes from ToTensor()
assert error_msg in str(info.value)


def test_five_crop_md5():


+ 3
- 3
tests/ut/python/dataset/test_random_color.py View File

@@ -149,7 +149,7 @@ def test_random_color_py_md5():
data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)

transforms = F.ComposeOp([F.Decode(),
F.RandomColor((0.1, 1.9)),
F.RandomColor((2.0, 2.5)),
F.ToTensor()])

data = data.map(input_columns="image", operations=transforms())
@@ -244,12 +244,12 @@ def test_random_color_c_errors():
if __name__ == "__main__":
test_random_color_py()
test_random_color_py(plot=True)
test_random_color_py(degrees=(0.5, 1.5), plot=True)
test_random_color_py(degrees=(2.0, 2.5), plot=True) # Test with degree values that show more obvious transformation
test_random_color_py_md5()

test_random_color_c()
test_random_color_c(plot=True)
test_random_color_c(degrees=(0.5, 1.5), plot=True, run_golden=False)
test_random_color_c(degrees=(2.0, 2.5), plot=True, run_golden=False) # Test with degree values that show more obvious transformation
test_random_color_c(degrees=(0.1, 0.1), plot=True, run_golden=False)
test_compare_random_color_op(plot=True)
test_random_color_c_errors()

+ 6
- 4
tests/ut/python/dataset/test_random_sharpness.py View File

@@ -103,7 +103,7 @@ def test_random_sharpness_py_md5():
# define map operations
transforms = [
F.Decode(),
F.RandomSharpness((0.1, 1.9)),
F.RandomSharpness((20.0, 25.0)),
F.ToTensor()
]
transform = F.ComposeOp(transforms)
@@ -193,7 +193,7 @@ def test_random_sharpness_c_md5():
# define map operations
transforms = [
C.Decode(),
C.RandomSharpness((0.1, 1.9))
C.RandomSharpness((10.0, 15.0))
]

# Generate dataset
@@ -337,14 +337,16 @@ def test_random_sharpness_invalid_params():

if __name__ == "__main__":
test_random_sharpness_py(plot=True)
test_random_sharpness_py(None, plot=True) # test with default values
test_random_sharpness_py(None, plot=True) # Test with default values
test_random_sharpness_py(degrees=(20.0, 25.0), plot=True) # Test with degree values that show more obvious transformation
test_random_sharpness_py_md5()
test_random_sharpness_c(plot=True)
test_random_sharpness_c(None, plot=True) # test with default values
test_random_sharpness_c(degrees=[10, 15], plot=True) # Test with degrees values that show more obvious transformation
test_random_sharpness_c_md5()
test_random_sharpness_c_py(degrees=[1.5, 1.5], plot=True)
test_random_sharpness_c_py(degrees=[1, 1], plot=True)
test_random_sharpness_c_py(degrees=[10, 10], plot=True)
test_random_sharpness_one_channel_c(degrees=[1.7, 1.7], plot=True)
test_random_sharpness_one_channel_c(degrees=None, plot=True) # test with default values
test_random_sharpness_one_channel_c(degrees=None, plot=True) # Test with default values
test_random_sharpness_invalid_params()

+ 2
- 2
tests/ut/python/dataset/test_repeat.py View File

@@ -303,7 +303,7 @@ def test_repeat_count0():
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(0)
assert "count" in str(info)
assert "count" in str(info.value)

def test_repeat_countneg2():
"""
@@ -313,7 +313,7 @@ def test_repeat_countneg2():
with pytest.raises(ValueError) as info:
data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
data1.repeat(-2)
assert "count" in str(info)
assert "count" in str(info.value)

if __name__ == "__main__":
test_tf_repeat_01()


Loading…
Cancel
Save