From 674415f7be476ea76da7840afe44be41c1fbb415 Mon Sep 17 00:00:00 2001 From: hesham Date: Fri, 19 Jun 2020 12:28:21 -0400 Subject: [PATCH] Cleanup work for Concate, Mask, Slice, PadEnd and TruncatePair --- mindspore/dataset/text/validators.py | 2 +- mindspore/dataset/transforms/c_transforms.py | 19 +++++-- mindspore/dataset/transforms/validators.py | 55 +++++++++----------- tests/ut/python/dataset/test_mask_op.py | 16 +++--- 4 files changed, 47 insertions(+), 45 deletions(-) diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index 74ff31dd7a..96f568e523 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -403,7 +403,7 @@ def check_to_number(method): if not isinstance(data_type, typing.Type): raise TypeError("data_type is not a MindSpore data type.") - if not data_type in mstype.number_type: + if data_type not in mstype.number_type: raise TypeError("data_type is not numeric data type.") kwargs["data_type"] = data_type diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 8af67fab60..d320c722e1 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -79,12 +79,13 @@ class Slice(cde.SliceOp): (Currently only rank 1 Tensors are supported) Args: - *slices: Maximum n number of objects to slice a tensor of rank n. - One object in slices can be one of: + *slices(Variable length argument list): Maximum `n` number of arguments to slice a tensor of rank `n`. + One object in slices can be one of: 1. int: slice this index only. Negative index is supported. 2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`. 3. None: slice the whole dimension. Similar to `:` in python indexing. 4. Ellipses ...: slice all dimensions between the two slices. + Examples: >>> # Data before >>> # | col | @@ -134,11 +135,13 @@ class Mask(cde.MaskOp): """ Mask content of the input tensor with the given predicate. Any element of the tensor that matches the predicate will be evaluated to True, otherwise False. + Args: operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE constant (python types (str, int, float, or bool): constant to be compared to. Constant will be casted to the type of the input tensor dtype (optional, mindspore.dtype): type of the generated mask. Default to bool + Examples: >>> # Data before >>> # | col1 | @@ -163,11 +166,13 @@ class Mask(cde.MaskOp): class PadEnd(cde.PadEndOp): """ Pad input tensor according to `pad_shape`, need to have same rank. + Args: pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values. pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty string in case of Tensors of strings. + Examples: >>> # Data before >>> # | col | @@ -201,21 +206,25 @@ class Concatenate(cde.ConcatenateOp): @check_concat_type def __init__(self, axis=0, prepend=None, append=None): - # add some validations here later + if prepend is not None: + prepend = cde.Tensor(np.array(prepend)) + if append is not None: + append = cde.Tensor(np.array(append)) super().__init__(axis, prepend, append) class Duplicate(cde.DuplicateOp): """ Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list. - Examples: + + Examples: >>> # Data before >>> # | x | >>> # +---------+ >>> # | [1,2,3] | >>> # +---------+ >>> data = data.map(input_columns=["x"], operations=Duplicate(), - >>> output_columns=["x", "y"], output_order=["x", "y"]) + >>> output_columns=["x", "y"], columns_order=["x", "y"]) >>> # Data after >>> # | x | y | >>> # +---------+---------+ diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index d4148e00d1..6b5760e0c5 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -17,7 +17,6 @@ from functools import wraps import numpy as np -import mindspore._c_dataengine as cde from mindspore._c_expression import typing # POS_INT_MIN is used to limit values from starting from 0 @@ -243,12 +242,13 @@ def check_mask_op(method): if not isinstance(constant, (str, float, bool, int, bytes)): raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") - if not isinstance(dtype, typing.Type): - raise TypeError("dtype is not a MindSpore data type.") + if dtype is not None: + if not isinstance(dtype, typing.Type): + raise TypeError("dtype is not a MindSpore data type.") + kwargs["dtype"] = dtype kwargs["operator"] = operator kwargs["constant"] = constant - kwargs["dtype"] = dtype return method(self, **kwargs) @@ -269,8 +269,10 @@ def check_pad_end(method): if pad_shape is None: raise ValueError("pad_shape is not provided.") - if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)): - raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.") + if pad_value is not None: + if not isinstance(pad_value, (str, float, bool, int, bytes)): + raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") + kwargs["pad_value"] = pad_value if not isinstance(pad_shape, list): raise TypeError("pad_shape must be a list") @@ -283,7 +285,6 @@ def check_pad_end(method): raise TypeError("a value in the list is not an integer.") kwargs["pad_shape"] = pad_shape - kwargs["pad_value"] = pad_value return method(self, **kwargs) @@ -303,30 +304,22 @@ def check_concat_type(method): if "axis" in kwargs: axis = kwargs.get("axis") - if not isinstance(axis, (type(None), int)): - raise TypeError("axis type is not valid, must be None or an integer.") - - if isinstance(axis, type(None)): - axis = 0 - - if axis not in (None, 0, -1): - raise ValueError("only 1D concatenation supported.") - - if not isinstance(prepend, (type(None), np.ndarray)): - raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") - - if not isinstance(append, (type(None), np.ndarray)): - raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") - - if isinstance(prepend, np.ndarray): - prepend = cde.Tensor(prepend) - - if isinstance(append, np.ndarray): - append = cde.Tensor(append) - - kwargs["axis"] = axis - kwargs["prepend"] = prepend - kwargs["append"] = append + if axis is not None: + if not isinstance(axis, int): + raise TypeError("axis type is not valid, must be an integer.") + if axis not in (0, -1): + raise ValueError("only 1D concatenation supported.") + kwargs["axis"] = axis + + if prepend is not None: + if not isinstance(prepend, (type(None), np.ndarray)): + raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") + kwargs["prepend"] = prepend + + if append is not None: + if not isinstance(append, (type(None), np.ndarray)): + raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") + kwargs["append"] = append return method(self, **kwargs) diff --git a/tests/ut/python/dataset/test_mask_op.py b/tests/ut/python/dataset/test_mask_op.py index 878f786f97..54f2cc65be 100644 --- a/tests/ut/python/dataset/test_mask_op.py +++ b/tests/ut/python/dataset/test_mask_op.py @@ -62,7 +62,7 @@ def mask_compare(array, op, constant, dtype=mstype.bool_): np.testing.assert_array_equal(array, d[0]) -def test_int_comparison(): +def test_mask_int_comparison(): for k in mstype_to_np_type: if k == mstype.string: continue @@ -74,7 +74,7 @@ def test_int_comparison(): mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k) -def test_float_comparison(): +def test_mask_float_comparison(): for k in mstype_to_np_type: if k == mstype.string: continue @@ -86,7 +86,7 @@ def test_float_comparison(): mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k) -def test_float_comparison2(): +def test_mask_float_comparison2(): for k in mstype_to_np_type: if k == mstype.string: continue @@ -98,7 +98,7 @@ def test_float_comparison2(): mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k) -def test_string_comparison(): +def test_mask_string_comparison(): for k in mstype_to_np_type: if k == mstype.string: continue @@ -125,8 +125,8 @@ def test_mask_exceptions_str(): if __name__ == "__main__": - test_int_comparison() - test_float_comparison() - test_float_comparison2() - test_string_comparison() + test_mask_int_comparison() + test_mask_float_comparison() + test_mask_float_comparison2() + test_mask_string_comparison() test_mask_exceptions_str()