| @@ -1137,8 +1137,8 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||
| InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { | |||
| try { | |||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | |||
| if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) { | |||
| RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3."); | |||
| if (input_cv->Rank() == 1 || input_cv->Rank() > 3) { | |||
| RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or <H,W>."); | |||
| } | |||
| cv::Mat affine_mat(mat); | |||
| @@ -13,9 +13,9 @@ | |||
| # limitations under the License. | |||
| """ | |||
| This module provides APIs to load and process various common datasets such as MNIST, | |||
| CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA, etc. It also supports datasets in standard | |||
| format, including MindRecord, TFRecord, Manifest, etc. Users can also define their own | |||
| datasets with this module. | |||
| CIFAR-10, CIFAR-100, VOC, COCO, ImageNet, CelebA, CLUE, etc. It also supports datasets | |||
| in standard format, including MindRecord, TFRecord, Manifest, etc. Users can also define | |||
| their own datasets with this module. | |||
| Besides, this module provides APIs to sample data while loading. | |||
| @@ -74,6 +74,14 @@ def check_value(value, valid_range, arg_name=""): | |||
| valid_range[1])) | |||
def check_value_cutoff(value, valid_range, arg_name=""):
    """
    Ensure `value` falls inside the half-open interval [valid_range[0], valid_range[1]).

    Args:
        value: Number to validate.
        valid_range: Indexable pair; element 0 is the inclusive lower bound,
            element 1 is the exclusive upper bound.
        arg_name (str): Argument name placed in the error message after being
            run through `pad_arg_name`.

    Raises:
        ValueError: If `value` is below the lower bound, or equal to or above
            the upper bound.
    """
    label = pad_arg_name(arg_name)
    lower, upper = valid_range[0], valid_range[1]
    # Lower bound is accepted, upper bound is rejected.
    if value < lower or value >= upper:
        template = "Input {0}is not within the required interval of [{1}, {2})."
        raise ValueError(template.format(label, lower, upper))
| def check_value_normalize_std(value, valid_range, arg_name=""): | |||
| arg_name = pad_arg_name(arg_name) | |||
| if value <= valid_range[0] or value > valid_range[1]: | |||
| @@ -404,7 +412,7 @@ def check_tensor_op(param, param_name): | |||
| def check_c_tensor_op(param, param_name): | |||
| """check whether param is a tensor op or a callable Python function but not a py_transform""" | |||
| if callable(param) and getattr(param, 'parse', True): | |||
| if callable(param) and str(param).find("py_transform") >= 0: | |||
| raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name)) | |||
| if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None): | |||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | |||
| @@ -531,7 +531,8 @@ class PythonTokenizer: | |||
| self.random = False | |||
| def __call__(self, in_array): | |||
| in_array = to_str(in_array) | |||
| if not isinstance(in_array, str): | |||
| in_array = to_str(in_array) | |||
| tokens = self.tokenizer(in_array) | |||
| return tokens | |||
| @@ -104,7 +104,8 @@ class AutoContrast(ImageTensorOperation): | |||
| Apply automatic contrast on input image. | |||
| Args: | |||
| cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). | |||
| cutoff (float, optional): Percent of pixels to cut off from the histogram, | |||
| the value must be in the range [0.0, 50.0) (default=0.0). | |||
| ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). | |||
| Examples: | |||
| @@ -770,7 +771,7 @@ class RandomCropDecodeResize(ImageTensorOperation): | |||
| if img.ndim != 1 or img.dtype.type is not np.uint8: | |||
| raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " + | |||
| "got format:{}, dtype:{}.".format(type(img), img.dtype.type)) | |||
| super().__call__(img=img) | |||
| return super().__call__(img) | |||
| class RandomCropWithBBox(ImageTensorOperation): | |||
| @@ -1031,7 +1031,7 @@ class RandomErasing: | |||
| class Cutout: | |||
| """ | |||
| Randomly cut (mask) out a given number of square patches from the input NumPy image array. | |||
| Randomly cut (mask) out a given number of square patches from the input NumPy image array of shape (C, H, W). | |||
| Terrance DeVries and Graham W. Taylor 'Improved Regularization of Convolutional Neural Networks with Cutout' 2017 | |||
| See https://arxiv.org/pdf/1708.04552.pdf | |||
| @@ -1068,6 +1068,9 @@ class Cutout: | |||
| """ | |||
| if not isinstance(np_img, np.ndarray): | |||
| raise TypeError("img should be NumPy array. Got {}.".format(type(np_img))) | |||
| if np_img.ndim != 3: | |||
| raise TypeError('img dimension should be 3. Got {}.'.format(np_img.ndim)) | |||
| _, image_h, image_w = np_img.shape | |||
| scale = (self.length * self.length) / (image_h * image_w) | |||
| bounded = False | |||
| @@ -1426,7 +1429,8 @@ class AutoContrast: | |||
| Automatically maximize the contrast of the input PIL image. | |||
| Args: | |||
| cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). | |||
| cutoff (float, optional): Percent of pixels to cut off from the histogram, | |||
| the value must be in the range [0.0, 50.0) (default=0.0). | |||
| ignore (Union[int, sequence], optional): Pixel values to ignore (default=None). | |||
| Examples: | |||
| @@ -56,13 +56,16 @@ def normalize(img, mean, std, pad_channel=False, dtype="float32"): | |||
| Returns: | |||
| img (numpy.ndarray), Normalized image. | |||
| """ | |||
| if not is_numpy(img): | |||
| raise TypeError("img should be NumPy image. Got {}.".format(type(img))) | |||
| if img.ndim != 3: | |||
| raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim)) | |||
| if np.issubdtype(img.dtype, np.integer): | |||
| raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]." | |||
| .format(img.dtype)) | |||
| if not is_numpy(img): | |||
| raise TypeError("img should be NumPy image. Got {}.".format(type(img))) | |||
| num_channels = img.shape[0] # shape is (C, H, W) | |||
| if len(mean) != len(std): | |||
| @@ -119,9 +122,11 @@ def hwc_to_chw(img): | |||
| Returns: | |||
| img (numpy.ndarray), Converted image. | |||
| """ | |||
| if is_numpy(img): | |||
| return img.transpose(2, 0, 1).copy() | |||
| raise TypeError('img should be NumPy array. Got {}.'.format(type(img))) | |||
| if not is_numpy(img): | |||
| raise TypeError('img should be NumPy array. Got {}.'.format(type(img))) | |||
| if img.ndim != 3: | |||
| raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim)) | |||
| return img.transpose(2, 0, 1).copy() | |||
| def to_tensor(img, output_type): | |||
| @@ -140,7 +145,7 @@ def to_tensor(img, output_type): | |||
| img = np.asarray(img) | |||
| if img.ndim not in (2, 3): | |||
| raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim)) | |||
| raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim)) | |||
| if img.ndim == 2: | |||
| img = img[:, :, None] | |||
| @@ -856,8 +861,8 @@ def pad(img, padding, fill_value, padding_mode): | |||
| elif isinstance(padding, (tuple, list)): | |||
| if len(padding) == 2: | |||
| left = right = padding[0] | |||
| top = bottom = padding[1] | |||
| left = top = padding[0] | |||
| right = bottom = padding[1] | |||
| elif len(padding) == 4: | |||
| left = padding[0] | |||
| top = padding[1] | |||
| @@ -877,10 +882,10 @@ def pad(img, padding, fill_value, padding_mode): | |||
| if padding_mode == 'constant': | |||
| if img.mode == 'P': | |||
| palette = img.getpalette() | |||
| image = ImageOps.expand(img, border=padding, fill=fill_value) | |||
| image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value) | |||
| image.putpalette(palette) | |||
| return image | |||
| return ImageOps.expand(img, border=padding, fill=fill_value) | |||
| return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value) | |||
| if img.mode == 'P': | |||
| palette = img.getpalette() | |||
| @@ -1254,6 +1259,9 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc): | |||
| if not is_numpy(np_rgb_imgs): | |||
| raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs))) | |||
| if not isinstance(is_hwc, bool): | |||
| raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc))) | |||
| shape_size = len(np_rgb_imgs.shape) | |||
| if not shape_size in (3, 4): | |||
| @@ -1322,6 +1330,9 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): | |||
| if not is_numpy(np_hsv_imgs): | |||
| raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs))) | |||
| if not isinstance(is_hwc, bool): | |||
| raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc))) | |||
| shape_size = len(np_hsv_imgs.shape) | |||
| if not shape_size in (3, 4): | |||
| @@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation | |||
| from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ | |||
| check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \ | |||
| check_c_tensor_op, UINT8_MAX, check_value_normalize_std | |||
| check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff | |||
| from .utils import Inter, Border, ImageBatchFormat | |||
| @@ -650,7 +650,7 @@ def check_auto_contrast(method): | |||
| def new_method(self, *args, **kwargs): | |||
| [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(cutoff, (int, float), "cutoff") | |||
| check_value(cutoff, [0, 100], "cutoff") | |||
| check_value_cutoff(cutoff, [0, 50], "cutoff") | |||
| if ignore is not None: | |||
| type_check(ignore, (list, tuple, int), "ignore") | |||
| if isinstance(ignore, int): | |||
| @@ -270,7 +270,7 @@ def test_auto_contrast_invalid_cutoff_param_c(): | |||
| data_set = data_set.map(operations=C.AutoContrast(cutoff=-10.0), input_columns="image") | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) | |||
| assert "Input cutoff is not within the required interval of [0, 50)." in str(error) | |||
| try: | |||
| data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) | |||
| data_set = data_set.map(operations=[C.Decode(), | |||
| @@ -280,7 +280,7 @@ def test_auto_contrast_invalid_cutoff_param_c(): | |||
| data_set = data_set.map(operations=C.AutoContrast(cutoff=120.0), input_columns="image") | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) | |||
| assert "Input cutoff is not within the required interval of [0, 50)." in str(error) | |||
| def test_auto_contrast_invalid_ignore_param_py(): | |||
| @@ -327,7 +327,7 @@ def test_auto_contrast_invalid_cutoff_param_py(): | |||
| input_columns=["image"]) | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) | |||
| assert "Input cutoff is not within the required interval of [0, 50)." in str(error) | |||
| try: | |||
| data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False) | |||
| data_set = data_set.map( | |||
| @@ -338,7 +338,7 @@ def test_auto_contrast_invalid_cutoff_param_py(): | |||
| input_columns=["image"]) | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input cutoff is not within the required interval of (0 to 100)." in str(error) | |||
| assert "Input cutoff is not within the required interval of [0, 50)." in str(error) | |||
| if __name__ == "__main__": | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import mindspore.dataset.text.transforms as T | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import log as logger | |||
def test_sliding_window():
    """Call T.SlidingWindow(width=2) directly on four tokens and compare with the adjacent pairs."""
    tokens = ["Welcome", "to", "Beijing", "!"]
    windows = T.SlidingWindow(width=2)(tokens)
    logger.info("Result: {}".format(windows))
    np.testing.assert_equal(
        windows,
        [['Welcome', 'to'], ['to', 'Beijing'], ['Beijing', '!']],
    )
def test_to_number():
    """Call T.ToNumber(mstype.int32) directly on a one-element string list and compare with 123456."""
    digits = ["123456"]
    converted = T.ToNumber(mstype.int32)(digits)
    logger.info("Result: {}, type: {}".format(converted, type(converted[0])))
    assert converted == 123456
def test_whitespace_tokenizer():
    """Call T.WhitespaceTokenizer() directly on a sentence and compare with its four tokens."""
    sentence = "Welcome to Beijing !"
    tokenize = T.WhitespaceTokenizer()
    tokens = tokenize(sentence)
    logger.info("Tokenize result: {}".format(tokens))
    np.testing.assert_equal(tokens, ['Welcome', 'to', 'Beijing', '!'])
def test_python_tokenizer():
    """Wrap a user-defined whitespace splitter in T.PythonTokenizer and call it directly on a string."""

    def my_tokenizer(line):
        # Split on whitespace; an all-blank line yields one empty token.
        return line.split() or [""]

    sentence = "Welcome to Beijing !"
    tokens = T.PythonTokenizer(my_tokenizer)(sentence)
    logger.info("Tokenize result: {}".format(tokens))
    np.testing.assert_equal(tokens, ['Welcome', 'to', 'Beijing', '!'])
if __name__ == '__main__':
    # Run every test in this file, in declaration order, when executed as a script.
    for run_case in (
            test_sliding_window,
            test_to_number,
            test_whitespace_tokenizer,
            test_python_tokenizer,
    ):
        run_case()