Merge pull request !2833 from nhussain/engine_validators (tag: v0.6.0-beta)
| @@ -0,0 +1,342 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| General Validators. | |||
| """ | |||
| import inspect | |||
| from multiprocessing import cpu_count | |||
| import os | |||
| import numpy as np | |||
| from ..engine import samplers | |||
| # POS_INT_MIN is used to limit values from starting from 0 | |||
| POS_INT_MIN = 1 | |||
| UINT8_MAX = 255 | |||
| UINT8_MIN = 0 | |||
| UINT32_MAX = 4294967295 | |||
| UINT32_MIN = 0 | |||
| UINT64_MAX = 18446744073709551615 | |||
| UINT64_MIN = 0 | |||
| INT32_MAX = 2147483647 | |||
| INT32_MIN = -2147483648 | |||
| INT64_MAX = 9223372036854775807 | |||
| INT64_MIN = -9223372036854775808 | |||
| FLOAT_MAX_INTEGER = 16777216 | |||
| FLOAT_MIN_INTEGER = -16777216 | |||
| DOUBLE_MAX_INTEGER = 9007199254740992 | |||
| DOUBLE_MIN_INTEGER = -9007199254740992 | |||
| valid_detype = [ | |||
| "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", | |||
| "uint32", "uint64", "float16", "float32", "float64", "string" | |||
| ] | |||
def pad_arg_name(arg_name):
    """Append a trailing space to a non-empty argument name for message interpolation."""
    return arg_name + " " if arg_name != "" else arg_name
def check_value(value, valid_range, arg_name=""):
    """Raise ValueError unless `value` lies within the closed interval `valid_range`."""
    arg_name = pad_arg_name(arg_name)
    low, high = valid_range[0], valid_range[1]
    if not low <= value <= high:
        raise ValueError(
            "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, low, high))
def check_range(values, valid_range, arg_name=""):
    """Raise ValueError unless `values` is an ordered pair contained in `valid_range`."""
    arg_name = pad_arg_name(arg_name)
    low, high = valid_range[0], valid_range[1]
    in_order = low <= values[0] and values[0] <= values[1] and values[1] <= high
    if not in_order:
        raise ValueError(
            "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, low, high))
def check_positive(value, arg_name=""):
    """Raise ValueError unless `value` is strictly greater than zero."""
    padded = pad_arg_name(arg_name)
    if value <= 0:
        raise ValueError("Input {0}must be greater than 0.".format(padded))
def check_positive_float(value, arg_name=""):
    """Validate that `value` is a float and strictly positive."""
    padded = pad_arg_name(arg_name)
    # Type is checked first so a non-float produces a TypeError, not a ValueError.
    type_check(value, (float,), padded)
    check_positive(value, padded)
def check_2tuple(value, arg_name=""):
    """Raise ValueError unless `value` is a tuple of exactly two elements."""
    is_pair = isinstance(value, tuple) and len(value) == 2
    if not is_pair:
        raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name))
def check_uint8(value, arg_name=""):
    """Validate that `value` is an int within [0, 255]."""
    type_check(value, (int,), arg_name)
    # Fix: forward arg_name so range errors identify the offending argument.
    check_value(value, [UINT8_MIN, UINT8_MAX], arg_name)


def check_uint32(value, arg_name=""):
    """Validate that `value` is an int within [0, 2**32 - 1]."""
    type_check(value, (int,), arg_name)
    check_value(value, [UINT32_MIN, UINT32_MAX], arg_name)


def check_pos_int32(value, arg_name=""):
    """Validate that `value` is an int within [1, 2**31 - 1]."""
    type_check(value, (int,), arg_name)
    check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)


def check_uint64(value, arg_name=""):
    """Validate that `value` is an int within [0, 2**64 - 1]."""
    type_check(value, (int,), arg_name)
    check_value(value, [UINT64_MIN, UINT64_MAX], arg_name)


def check_pos_int64(value, arg_name=""):
    """Validate that `value` is an int within [0, 2**63 - 1].

    NOTE(review): despite the name, the lower bound is UINT64_MIN (0), not
    POS_INT_MIN (1), so 0 is accepted — confirm this is intentional.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [UINT64_MIN, INT64_MAX], arg_name)


def check_pos_float32(value, arg_name=""):
    """Validate that `value` is within [0, 2**24], the exactly-representable float32 integer range."""
    check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)


def check_pos_float64(value, arg_name=""):
    """Validate that `value` is within [0, 2**53], the exactly-representable float64 integer range."""
    check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name)
def check_valid_detype(type_):
    """Return True if `type_` names a supported column type; raise ValueError otherwise."""
    if type_ in valid_detype:
        return True
    raise ValueError("Unknown column type")
def check_columns(columns, name):
    """Validate that `columns` is a single string or a non-empty list of strings."""
    type_check(columns, (list, str), name)
    if not isinstance(columns, list):
        return
    if not columns:
        raise ValueError("Column names should not be empty")
    labels = ["col_{0}".format(i) for i in range(len(columns))]
    type_check_list(columns, (str,), labels)
def parse_user_args(method, *args, **kwargs):
    """
    Parse user arguments in a function.

    Args:
        method (method): a callable function
        *args: user passed args
        **kwargs: user passed kwargs

    Returns:
        user_filled_args (list): values of what the user passed in for the arguments,
        bound.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
    """
    sig = inspect.signature(method)
    has_receiver = 'self' in sig.parameters or 'cls' in sig.parameters
    if has_receiver:
        # `method` itself stands in for the self/cls slot so binding succeeds.
        bound = sig.bind(method, *args, **kwargs)
    else:
        bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    params = list(sig.parameters)
    if has_receiver:
        params = params[1:]
    user_filled_args = [bound.arguments.get(p) for p in params]
    return user_filled_args, bound.arguments
def type_check_list(args, types, arg_names):
    """
    Check the type of each element of `args` against `types`.

    Args:
        args (list, tuple): values to validate.
        types (tuple): tuple of all valid types for each element.
        arg_names (list, tuple of str): one name per element of `args`.

    Returns:
        Exception: when a type is not correct, otherwise nothing.
    """
    type_check(args, (list, tuple,), arg_names)
    if len(args) != len(arg_names):
        raise ValueError("List of arguments is not the same length as argument_names.")
    for value, label in zip(args, arg_names):
        type_check(value, types, label)
def type_check(arg, types, arg_name):
    """
    Check that `arg` is an instance of one of `types`.

    Args:
        arg : any variable
        types (tuple): tuple of all valid types for arg
        arg_name (str): the name of arg

    Returns:
        Exception: when the type is not correct, otherwise nothing.
    """
    print_value = '""' if repr(arg) == repr('') else arg
    # bool is a subclass of int, so explicitly reject booleans when only int is accepted.
    rejects_bool = int in types and bool not in types
    if (rejects_bool and isinstance(arg, bool)) or not isinstance(arg, types):
        raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types))
def check_filename(path):
    """
    Check that the basename of `path` contains no forbidden characters.

    Args:
        path (str): the path

    Returns:
        True on success; raises on error.
    """
    if not isinstance(path, str):
        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(path)
    forbidden_symbols = set(r'\/:*?"<>|`&\';')
    if forbidden_symbols.intersection(filename):
        raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'")
    # Empty slices make this safe for an empty basename.
    if filename[:1] == ' ' or filename[-1:] == ' ':
        raise ValueError("filename should not start/end with space")
    return True
def check_dir(dataset_dir):
    """Raise ValueError unless `dataset_dir` is an existing, readable directory."""
    readable_dir = os.path.isdir(dataset_dir) and os.access(dataset_dir, os.R_OK)
    if not readable_dir:
        raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir))
def check_file(dataset_file):
    """Validate that `dataset_file` has a legal name and is an existing, readable file."""
    check_filename(dataset_file)
    readable_file = os.path.isfile(dataset_file) and os.access(dataset_file, os.R_OK)
    if not readable_file:
        raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))
def check_sampler_shuffle_shard_options(param_dict):
    """
    Check for valid shuffle, sampler, num_shards, and shard_id inputs.

    Args:
        param_dict (dict): param_dict

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    shuffle = param_dict.get('shuffle')
    sampler = param_dict.get('sampler')
    num_shards = param_dict.get('num_shards')
    shard_id = param_dict.get('shard_id')

    type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")

    # A user-provided sampler already controls ordering; shuffle would conflict.
    if sampler is not None and shuffle is not None:
        raise RuntimeError("sampler and shuffle cannot be specified at the same time.")

    if num_shards is not None:
        check_pos_int32(num_shards)
        if shard_id is None:
            raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
        check_value(shard_id, [0, num_shards - 1], "shard_id")
    elif shard_id is not None:
        raise RuntimeError("shard_id is specified but num_shards is not.")
def check_padding_options(param_dict):
    """
    Check for valid padded_sample and num_padded of padded samples.

    Args:
        param_dict (dict): param_dict

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    columns_list = param_dict.get('columns_list')
    block_reader = param_dict.get('block_reader')
    padded_sample = param_dict.get('padded_sample')
    num_padded = param_dict.get('num_padded')

    if padded_sample is None:
        if num_padded is not None:
            raise RuntimeError("num_padded is specified but padded_sample is not.")
        return

    # padded_sample is present: the remaining options must be consistent with it.
    if num_padded is None:
        raise RuntimeError("padded_sample is specified and requires num_padded as well.")
    if num_padded < 0:
        raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
    if columns_list is None:
        raise RuntimeError("padded_sample is specified and requires columns_list as well.")
    if any(column not in padded_sample for column in columns_list):
        raise ValueError("padded_sample cannot match columns_list.")
    if block_reader:
        raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.")
def check_num_parallel_workers(value):
    """Validate that `value` is an int in [1, cpu_count()]."""
    type_check(value, (int,), "num_parallel_workers")
    cores = cpu_count()
    if not 1 <= value <= cores:
        raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cores))


def check_num_samples(value):
    """Validate that `value` is a non-negative int within the int32 range."""
    type_check(value, (int,), "num_samples")
    check_value(value, [0, INT32_MAX], "num_samples")
def validate_dataset_param_value(param_list, param_dict, param_type):
    """
    Validate a set of dataset constructor parameters.

    Args:
        param_list (list[str]): names of the parameters to validate.
        param_dict (dict): mapping of parameter name to user-supplied value.
        param_type (type): expected type for parameters without a dedicated check.

    Raises:
        TypeError, ValueError: if any present parameter fails its check.
    """
    for param_name in param_list:
        param_value = param_dict.get(param_name)
        if param_value is None:
            continue
        if param_name == 'num_parallel_workers':
            check_num_parallel_workers(param_value)
        elif param_name == 'num_samples':
            # Fix: this was a separate `if`, so its `else` also type-checked
            # num_parallel_workers against param_type; an elif chain gives
            # each parameter exactly one check.
            check_num_samples(param_value)
        else:
            type_check(param_value, (param_type,), param_name)
def check_gnn_list_or_ndarray(param, param_name):
    """
    Check if the input parameter is a list of ints or an int32 numpy.ndarray.

    Args:
        param (list, np.ndarray): param
        param_name (str): param_name

    Returns:
        Exception: TypeError if error.
    """
    type_check(param, (list, np.ndarray), param_name)
    if isinstance(param, np.ndarray):
        if param.dtype != np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))
        return
    labels = ["param_{0}".format(i) for i in range(len(param))]
    type_check_list(param, (int,), labels)
| @@ -98,7 +98,7 @@ class Ngram(cde.NgramOp): | |||
| """ | |||
| @check_ngram | |||
| def __init__(self, n, left_pad=None, right_pad=None, separator=None): | |||
| def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "): | |||
| super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0], | |||
| r_pad_token=right_pad[0], separator=separator) | |||
| @@ -28,6 +28,7 @@ __all__ = [ | |||
| "Vocab", "to_str", "to_bytes" | |||
| ] | |||
| class Vocab(cde.Vocab): | |||
| """ | |||
| Vocab object that is used to lookup a word. | |||
| @@ -38,7 +39,7 @@ class Vocab(cde.Vocab): | |||
| @classmethod | |||
| @check_from_dataset | |||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | |||
| special_first=None): | |||
| special_first=True): | |||
| """ | |||
| Build a vocab from a dataset. | |||
| @@ -62,13 +63,21 @@ class Vocab(cde.Vocab): | |||
| special_tokens(list, optional): a list of strings, each one is a special token. for example | |||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens | |||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||
| Returns: | |||
| Vocab, Vocab object built from dataset. | |||
| """ | |||
| vocab = Vocab() | |||
| if columns is None: | |||
| columns = [] | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| if freq_range is None: | |||
| freq_range = (None, None) | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | |||
| for d in root.create_dict_iterator(): | |||
| if d is not None: | |||
| @@ -77,7 +86,7 @@ class Vocab(cde.Vocab): | |||
| @classmethod | |||
| @check_from_list | |||
| def from_list(cls, word_list, special_tokens=None, special_first=None): | |||
| def from_list(cls, word_list, special_tokens=None, special_first=True): | |||
| """ | |||
| Build a vocab object from a list of word. | |||
| @@ -86,29 +95,33 @@ class Vocab(cde.Vocab): | |||
| special_tokens(list, optional): a list of strings, each one is a special token. for example | |||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens | |||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||
| """ | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| return super().from_list(word_list, special_tokens, special_first) | |||
| @classmethod | |||
| @check_from_file | |||
| def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None): | |||
| def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True): | |||
| """ | |||
| Build a vocab object from a list of word. | |||
| Args: | |||
| file_path (str): path to the file which contains the vocab list. | |||
| delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be | |||
| the word (default=None). | |||
| the word (default=""). | |||
| vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken). | |||
| special_tokens (list, optional): a list of strings, each one is a special token. for example | |||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | |||
| special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, | |||
| If special_tokens is specified and special_first is set to None, | |||
| special_tokens will be prepended (default=None). | |||
| If special_tokens is specified and special_first is set to True, | |||
| special_tokens will be prepended (default=True). | |||
| """ | |||
| if vocab_size is None: | |||
| vocab_size = -1 | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) | |||
| @classmethod | |||
| @@ -17,23 +17,22 @@ validators for text ops | |||
| """ | |||
| from functools import wraps | |||
| import mindspore._c_dataengine as cde | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore._c_dataengine as cde | |||
| from mindspore._c_expression import typing | |||
| from ..transforms.validators import check_uint32, check_pos_int64 | |||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \ | |||
| INT32_MAX, check_value | |||
| def check_unique_list_of_words(words, arg_name): | |||
| """Check that words is a list and each element is a str without any duplication""" | |||
| if not isinstance(words, list): | |||
| raise ValueError(arg_name + " needs to be a list of words of type string.") | |||
| type_check(words, (list,), arg_name) | |||
| words_set = set() | |||
| for word in words: | |||
| if not isinstance(word, str): | |||
| raise ValueError("each word in " + arg_name + " needs to be type str.") | |||
| type_check(word, (str,), arg_name) | |||
| if word in words_set: | |||
| raise ValueError(arg_name + " contains duplicate word: " + word + ".") | |||
| words_set.add(word) | |||
| @@ -45,21 +44,14 @@ def check_lookup(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| vocab, unknown = (list(args) + 2 * [None])[:2] | |||
| if "vocab" in kwargs: | |||
| vocab = kwargs.get("vocab") | |||
| if "unknown" in kwargs: | |||
| unknown = kwargs.get("unknown") | |||
| if unknown is not None: | |||
| if not (isinstance(unknown, int) and unknown >= 0): | |||
| raise ValueError("unknown needs to be a non-negative integer.") | |||
| [vocab, unknown], _ = parse_user_args(method, *args, **kwargs) | |||
| if not isinstance(vocab, cde.Vocab): | |||
| raise ValueError("vocab is not an instance of cde.Vocab.") | |||
| if unknown is not None: | |||
| type_check(unknown, (int,), "unknown") | |||
| check_positive(unknown) | |||
| type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") | |||
| kwargs["vocab"] = vocab | |||
| kwargs["unknown"] = unknown | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -69,50 +61,15 @@ def check_from_file(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] | |||
| if "file_path" in kwargs: | |||
| file_path = kwargs.get("file_path") | |||
| if "delimiter" in kwargs: | |||
| delimiter = kwargs.get("delimiter") | |||
| if "vocab_size" in kwargs: | |||
| vocab_size = kwargs.get("vocab_size") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if not isinstance(file_path, str): | |||
| raise ValueError("file_path needs to be str.") | |||
| if delimiter is not None: | |||
| if not isinstance(delimiter, str): | |||
| raise ValueError("delimiter needs to be str.") | |||
| else: | |||
| delimiter = "" | |||
| if vocab_size is not None: | |||
| if not (isinstance(vocab_size, int) and vocab_size > 0): | |||
| raise ValueError("vocab size needs to be a positive integer.") | |||
| else: | |||
| vocab_size = -1 | |||
| if special_first is None: | |||
| special_first = True | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value") | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, | |||
| **kwargs) | |||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||
| type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) | |||
| if vocab_size is not None: | |||
| check_value(vocab_size, (-1, INT32_MAX), "vocab_size") | |||
| type_check(special_first, (bool,), special_first) | |||
| kwargs["file_path"] = file_path | |||
| kwargs["delimiter"] = delimiter | |||
| kwargs["vocab_size"] = vocab_size | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -122,33 +79,20 @@ def check_from_list(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] | |||
| if "word_list" in kwargs: | |||
| word_list = kwargs.get("word_list") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| word_set = check_unique_list_of_words(word_list, "word_list") | |||
| token_set = check_unique_list_of_words(special_tokens, "special_tokens") | |||
| [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs) | |||
| intersect = word_set.intersection(token_set) | |||
| word_set = check_unique_list_of_words(word_list, "word_list") | |||
| if special_tokens is not None: | |||
| token_set = check_unique_list_of_words(special_tokens, "special_tokens") | |||
| if intersect != set(): | |||
| raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") | |||
| intersect = word_set.intersection(token_set) | |||
| if special_first is None: | |||
| special_first = True | |||
| if intersect != set(): | |||
| raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value.") | |||
| type_check(special_first, (bool,), "special_first") | |||
| kwargs["word_list"] = word_list | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -158,18 +102,15 @@ def check_from_dict(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word_dict, = (list(args) + [None])[:1] | |||
| if "word_dict" in kwargs: | |||
| word_dict = kwargs.get("word_dict") | |||
| if not isinstance(word_dict, dict): | |||
| raise ValueError("word_dict needs to be a list of word,id pairs.") | |||
| [word_dict], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(word_dict, (dict,), "word_dict") | |||
| for word, word_id in word_dict.items(): | |||
| if not isinstance(word, str): | |||
| raise ValueError("Each word in word_dict needs to be type string.") | |||
| if not (isinstance(word_id, int) and word_id >= 0): | |||
| raise ValueError("Each word id needs to be positive integer.") | |||
| kwargs["word_dict"] = word_dict | |||
| return method(self, **kwargs) | |||
| type_check(word, (str,), "word") | |||
| type_check(word_id, (int,), "word_id") | |||
| check_value(word_id, (-1, INT32_MAX), "word_id") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -179,23 +120,8 @@ def check_jieba_init(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] | |||
| if "hmm_path" in kwargs: | |||
| hmm_path = kwargs.get("hmm_path") | |||
| if "mp_path" in kwargs: | |||
| mp_path = kwargs.get("mp_path") | |||
| if hmm_path is None: | |||
| raise ValueError( | |||
| "The dict of HMMSegment in cppjieba is not provided.") | |||
| kwargs["hmm_path"] = hmm_path | |||
| if mp_path is None: | |||
| raise ValueError( | |||
| "The dict of MPSegment in cppjieba is not provided.") | |||
| kwargs["mp_path"] = mp_path | |||
| if model is not None: | |||
| kwargs["model"] = model | |||
| return method(self, **kwargs) | |||
| parse_user_args(method, *args, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -205,19 +131,12 @@ def check_jieba_add_word(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| word, freq = (list(args) + 2 * [None])[:2] | |||
| if "word" in kwargs: | |||
| word = kwargs.get("word") | |||
| if "freq" in kwargs: | |||
| freq = kwargs.get("freq") | |||
| [word, freq], _ = parse_user_args(method, *args, **kwargs) | |||
| if word is None: | |||
| raise ValueError("word is not provided.") | |||
| kwargs["word"] = word | |||
| if freq is not None: | |||
| check_uint32(freq) | |||
| kwargs["freq"] = freq | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -227,13 +146,8 @@ def check_jieba_add_dict(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| user_dict = (list(args) + [None])[0] | |||
| if "user_dict" in kwargs: | |||
| user_dict = kwargs.get("user_dict") | |||
| if user_dict is None: | |||
| raise ValueError("user_dict is not provided.") | |||
| kwargs["user_dict"] = user_dict | |||
| return method(self, **kwargs) | |||
| parse_user_args(method, *args, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -244,69 +158,39 @@ def check_from_dataset(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] | |||
| if "dataset" in kwargs: | |||
| dataset = kwargs.get("dataset") | |||
| if "columns" in kwargs: | |||
| columns = kwargs.get("columns") | |||
| if "freq_range" in kwargs: | |||
| freq_range = kwargs.get("freq_range") | |||
| if "top_k" in kwargs: | |||
| top_k = kwargs.get("top_k") | |||
| if "special_tokens" in kwargs: | |||
| special_tokens = kwargs.get("special_tokens") | |||
| if "special_first" in kwargs: | |||
| special_first = kwargs.get("special_first") | |||
| if columns is None: | |||
| columns = [] | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| for column in columns: | |||
| if not isinstance(column, str): | |||
| raise ValueError("columns need to be a list of strings.") | |||
| if freq_range is None: | |||
| freq_range = (None, None) | |||
| if not isinstance(freq_range, tuple) or len(freq_range) != 2: | |||
| raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||
| [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args, | |||
| **kwargs) | |||
| if columns is not None: | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| col_names = ["col_{0}".format(i) for i in range(len(columns))] | |||
| type_check_list(columns, (str,), col_names) | |||
| for num in freq_range: | |||
| if num is not None and (not isinstance(num, int)): | |||
| raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||
| if freq_range is not None: | |||
| type_check(freq_range, (tuple,), "freq_range") | |||
| if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): | |||
| if freq_range[0] > freq_range[1] or freq_range[0] < 0: | |||
| raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") | |||
| if len(freq_range) != 2: | |||
| raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.") | |||
| if top_k is not None and (not isinstance(top_k, int)): | |||
| raise ValueError("top_k needs to be a positive integer.") | |||
| for num in freq_range: | |||
| if num is not None and (not isinstance(num, int)): | |||
| raise ValueError( | |||
| "freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||
| if isinstance(top_k, int) and top_k <= 0: | |||
| raise ValueError("top_k needs to be a positive integer.") | |||
| if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): | |||
| if freq_range[0] > freq_range[1] or freq_range[0] < 0: | |||
| raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") | |||
| if special_first is None: | |||
| special_first = True | |||
| type_check(top_k, (int, type(None)), "top_k") | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| if isinstance(top_k, int): | |||
| check_value(top_k, (0, INT32_MAX), "top_k") | |||
| type_check(special_first, (bool,), "special_first") | |||
| if not isinstance(special_first, bool): | |||
| raise ValueError("special_first needs to be a boolean value.") | |||
| if special_tokens is not None: | |||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||
| kwargs["dataset"] = dataset | |||
| kwargs["columns"] = columns | |||
| kwargs["freq_range"] = freq_range | |||
| kwargs["top_k"] = top_k | |||
| kwargs["special_tokens"] = special_tokens | |||
| kwargs["special_first"] = special_first | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -316,15 +200,7 @@ def check_ngram(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4] | |||
| if "n" in kwargs: | |||
| n = kwargs.get("n") | |||
| if "left_pad" in kwargs: | |||
| left_pad = kwargs.get("left_pad") | |||
| if "right_pad" in kwargs: | |||
| right_pad = kwargs.get("right_pad") | |||
| if "separator" in kwargs: | |||
| separator = kwargs.get("separator") | |||
| [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs) | |||
| if isinstance(n, int): | |||
| n = [n] | |||
| @@ -332,15 +208,9 @@ def check_ngram(method): | |||
| if not (isinstance(n, list) and n != []): | |||
| raise ValueError("n needs to be a non-empty list of positive integers.") | |||
| for gram in n: | |||
| if not (isinstance(gram, int) and gram > 0): | |||
| raise ValueError("n in ngram needs to be a positive number.") | |||
| if left_pad is None: | |||
| left_pad = ("", 0) | |||
| if right_pad is None: | |||
| right_pad = ("", 0) | |||
| for i, gram in enumerate(n): | |||
| type_check(gram, (int,), "gram[{0}]".format(i)) | |||
| check_value(gram, (0, INT32_MAX), "gram_{}".format(i)) | |||
| if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( | |||
| left_pad[1], int)): | |||
| @@ -353,11 +223,7 @@ def check_ngram(method): | |||
| if not (left_pad[1] >= 0 and right_pad[1] >= 0): | |||
| raise ValueError("padding width need to be positive numbers.") | |||
| if separator is None: | |||
| separator = " " | |||
| if not isinstance(separator, str): | |||
| raise ValueError("separator needs to be a string.") | |||
| type_check(separator, (str,), "separator") | |||
| kwargs["n"] = n | |||
| kwargs["left_pad"] = left_pad | |||
| @@ -374,16 +240,8 @@ def check_pair_truncate(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| max_length = (list(args) + [None])[0] | |||
| if "max_length" in kwargs: | |||
| max_length = kwargs.get("max_length") | |||
| if max_length is None: | |||
| raise ValueError("max_length is not provided.") | |||
| check_pos_int64(max_length) | |||
| kwargs["max_length"] = max_length | |||
| return method(self, **kwargs) | |||
| parse_user_args(method, *args, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -393,22 +251,13 @@ def check_to_number(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| data_type = (list(args) + [None])[0] | |||
| if "data_type" in kwargs: | |||
| data_type = kwargs.get("data_type") | |||
| if data_type is None: | |||
| raise ValueError("data_type is a mandatory parameter but was not provided.") | |||
| if not isinstance(data_type, typing.Type): | |||
| raise TypeError("data_type is not a MindSpore data type.") | |||
| [data_type], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(data_type, (typing.Type,), "data_type") | |||
| if data_type not in mstype.number_type: | |||
| raise TypeError("data_type is not numeric data type.") | |||
| kwargs["data_type"] = data_type | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -418,18 +267,11 @@ def check_python_tokenizer(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| tokenizer = (list(args) + [None])[0] | |||
| if "tokenizer" in kwargs: | |||
| tokenizer = kwargs.get("tokenizer") | |||
| if tokenizer is None: | |||
| raise ValueError("tokenizer is a mandatory parameter.") | |||
| [tokenizer], _ = parse_user_args(method, *args, **kwargs) | |||
| if not callable(tokenizer): | |||
| raise TypeError("tokenizer is not a callable python function") | |||
| kwargs["tokenizer"] = tokenizer | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -18,6 +18,7 @@ from functools import wraps | |||
| import numpy as np | |||
| from mindspore._c_expression import typing | |||
| from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive | |||
| # POS_INT_MIN is used to limit values from starting from 0 | |||
| POS_INT_MIN = 1 | |||
| @@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992 | |||
| DOUBLE_MIN_INTEGER = -9007199254740992 | |||
| def check_type(value, valid_type): | |||
| if not isinstance(value, valid_type): | |||
| raise ValueError("Wrong input type") | |||
| def check_value(value, valid_range): | |||
| if value < valid_range[0] or value > valid_range[1]: | |||
| raise ValueError("Input is not within the required range") | |||
| def check_range(values, valid_range): | |||
| if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: | |||
| raise ValueError("Input range is not valid") | |||
| def check_positive(value): | |||
| if value <= 0: | |||
| raise ValueError("Input must greater than 0") | |||
| def check_positive_float(value, valid_max=None): | |||
| if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max): | |||
| raise ValueError("Input need to be a valid positive float.") | |||
| def check_bool(value): | |||
| if not isinstance(value, bool): | |||
| raise ValueError("Value needs to be a boolean.") | |||
| def check_2tuple(value): | |||
| if not (isinstance(value, tuple) and len(value) == 2): | |||
| raise ValueError("Value needs to be a 2-tuple.") | |||
| def check_list(value): | |||
| if not isinstance(value, list): | |||
| raise ValueError("The input needs to be a list.") | |||
| def check_uint8(value): | |||
| if not isinstance(value, int): | |||
| raise ValueError("The input needs to be a integer") | |||
| check_value(value, [UINT8_MIN, UINT8_MAX]) | |||
| def check_uint32(value): | |||
| if not isinstance(value, int): | |||
| raise ValueError("The input needs to be a integer") | |||
| check_value(value, [UINT32_MIN, UINT32_MAX]) | |||
| def check_pos_int32(value): | |||
| """Checks for int values starting from 1""" | |||
| if not isinstance(value, int): | |||
| raise ValueError("The input needs to be a integer") | |||
| check_value(value, [POS_INT_MIN, INT32_MAX]) | |||
| def check_uint64(value): | |||
| if not isinstance(value, int): | |||
| raise ValueError("The input needs to be a integer") | |||
| check_value(value, [UINT64_MIN, UINT64_MAX]) | |||
| def check_pos_int64(value): | |||
| if not isinstance(value, int): | |||
| raise ValueError("The input needs to be a integer") | |||
| check_value(value, [UINT64_MIN, INT64_MAX]) | |||
| def check_fill_value(method): | |||
| """Wrapper method to check the parameters of fill_value.""" | |||
| def check_pos_float32(value): | |||
| check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER]) | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [fill_value], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(fill_value, (str, float, bool, int, bytes), "fill_value") | |||
| return method(self, *args, **kwargs) | |||
| def check_pos_float64(value): | |||
| check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER]) | |||
| return new_method | |||
| def check_one_hot_op(method): | |||
| """Wrapper method to check the parameters of one hot op.""" | |||
| """Wrapper method to check the parameters of one_hot_op.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| args = (list(args) + 2 * [None])[:2] | |||
| num_classes, smoothing_rate = args | |||
| if "num_classes" in kwargs: | |||
| num_classes = kwargs.get("num_classes") | |||
| if "smoothing_rate" in kwargs: | |||
| smoothing_rate = kwargs.get("smoothing_rate") | |||
| if num_classes is None: | |||
| raise ValueError("num_classes") | |||
| check_pos_int32(num_classes) | |||
| kwargs["num_classes"] = num_classes | |||
| [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(num_classes, (int,), "num_classes") | |||
| check_positive(num_classes) | |||
| if smoothing_rate is not None: | |||
| check_value(smoothing_rate, [0., 1.]) | |||
| kwargs["smoothing_rate"] = smoothing_rate | |||
| check_value(smoothing_rate, [0., 1.], "smoothing_rate") | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -146,35 +74,12 @@ def check_num_classes(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| num_classes = (list(args) + [None])[0] | |||
| if "num_classes" in kwargs: | |||
| num_classes = kwargs.get("num_classes") | |||
| if num_classes is None: | |||
| raise ValueError("num_classes is not provided.") | |||
| check_pos_int32(num_classes) | |||
| kwargs["num_classes"] = num_classes | |||
| return method(self, **kwargs) | |||
| return new_method | |||
| [num_classes], _ = parse_user_args(method, *args, **kwargs) | |||
| def check_fill_value(method): | |||
| """Wrapper method to check the parameters of fill value.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| fill_value = (list(args) + [None])[0] | |||
| if "fill_value" in kwargs: | |||
| fill_value = kwargs.get("fill_value") | |||
| if fill_value is None: | |||
| raise ValueError("fill_value is not provided.") | |||
| if not isinstance(fill_value, (str, float, bool, int, bytes)): | |||
| raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int") | |||
| kwargs["fill_value"] = fill_value | |||
| type_check(num_classes, (int,), "num_classes") | |||
| check_positive(num_classes) | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -184,17 +89,11 @@ def check_de_type(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| data_type = (list(args) + [None])[0] | |||
| if "data_type" in kwargs: | |||
| data_type = kwargs.get("data_type") | |||
| [data_type], _ = parse_user_args(method, *args, **kwargs) | |||
| if data_type is None: | |||
| raise ValueError("data_type is not provided.") | |||
| if not isinstance(data_type, typing.Type): | |||
| raise TypeError("data_type is not a MindSpore data type.") | |||
| kwargs["data_type"] = data_type | |||
| type_check(data_type, (typing.Type,), "data_type") | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -204,13 +103,11 @@ def check_slice_op(method): | |||
| @wraps(method) | |||
| def new_method(self, *args): | |||
| for i, arg in enumerate(args): | |||
| if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): | |||
| raise TypeError("Indexing of dim " + str(i) + "is not of valid type") | |||
| for _, arg in enumerate(args): | |||
| type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg") | |||
| if isinstance(arg, list): | |||
| for a in arg: | |||
| if not isinstance(a, int): | |||
| raise TypeError("Index " + a + " is not an int") | |||
| type_check(a, (int,), "a") | |||
| return method(self, *args) | |||
| return new_method | |||
| @@ -221,36 +118,14 @@ def check_mask_op(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| operator, constant, dtype = (list(args) + 3 * [None])[:3] | |||
| if "operator" in kwargs: | |||
| operator = kwargs.get("operator") | |||
| if "constant" in kwargs: | |||
| constant = kwargs.get("constant") | |||
| if "dtype" in kwargs: | |||
| dtype = kwargs.get("dtype") | |||
| if operator is None: | |||
| raise ValueError("operator is not provided.") | |||
| if constant is None: | |||
| raise ValueError("constant is not provided.") | |||
| [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs) | |||
| from .c_transforms import Relational | |||
| if not isinstance(operator, Relational): | |||
| raise TypeError("operator is not a Relational operator enum.") | |||
| type_check(operator, (Relational,), "operator") | |||
| type_check(constant, (str, float, bool, int, bytes), "constant") | |||
| type_check(dtype, (typing.Type,), "dtype") | |||
| if not isinstance(constant, (str, float, bool, int, bytes)): | |||
| raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") | |||
| if dtype is not None: | |||
| if not isinstance(dtype, typing.Type): | |||
| raise TypeError("dtype is not a MindSpore data type.") | |||
| kwargs["dtype"] = dtype | |||
| kwargs["operator"] = operator | |||
| kwargs["constant"] = constant | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -260,22 +135,12 @@ def check_pad_end(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| pad_shape, pad_value = (list(args) + 2 * [None])[:2] | |||
| if "pad_shape" in kwargs: | |||
| pad_shape = kwargs.get("pad_shape") | |||
| if "pad_value" in kwargs: | |||
| pad_value = kwargs.get("pad_value") | |||
| if pad_shape is None: | |||
| raise ValueError("pad_shape is not provided.") | |||
| [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs) | |||
| if pad_value is not None: | |||
| if not isinstance(pad_value, (str, float, bool, int, bytes)): | |||
| raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") | |||
| kwargs["pad_value"] = pad_value | |||
| if not isinstance(pad_shape, list): | |||
| raise TypeError("pad_shape must be a list") | |||
| type_check(pad_value, (str, float, bool, int, bytes), "pad_value") | |||
| type_check(pad_shape, (list,), "pad_end") | |||
| for dim in pad_shape: | |||
| if dim is not None: | |||
| @@ -284,9 +149,7 @@ def check_pad_end(method): | |||
| else: | |||
| raise TypeError("a value in the list is not an integer.") | |||
| kwargs["pad_shape"] = pad_shape | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -296,31 +159,24 @@ def check_concat_type(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| axis, prepend, append = (list(args) + 3 * [None])[:3] | |||
| if "prepend" in kwargs: | |||
| prepend = kwargs.get("prepend") | |||
| if "append" in kwargs: | |||
| append = kwargs.get("append") | |||
| if "axis" in kwargs: | |||
| axis = kwargs.get("axis") | |||
| [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs) | |||
| if axis is not None: | |||
| if not isinstance(axis, int): | |||
| raise TypeError("axis type is not valid, must be an integer.") | |||
| type_check(axis, (int,), "axis") | |||
| if axis not in (0, -1): | |||
| raise ValueError("only 1D concatenation supported.") | |||
| kwargs["axis"] = axis | |||
| if prepend is not None: | |||
| if not isinstance(prepend, (type(None), np.ndarray)): | |||
| raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") | |||
| kwargs["prepend"] = prepend | |||
| type_check(prepend, (np.ndarray,), "prepend") | |||
| if len(prepend.shape) != 1: | |||
| raise ValueError("can only prepend 1D arrays.") | |||
| if append is not None: | |||
| if not isinstance(append, (type(None), np.ndarray)): | |||
| raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") | |||
| kwargs["append"] = append | |||
| type_check(append, (np.ndarray,), "append") | |||
| if len(append.shape) != 1: | |||
| raise ValueError("can only append 1D arrays.") | |||
| return method(self, **kwargs) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -40,12 +40,14 @@ Examples: | |||
| >>> dataset = dataset.map(input_columns="image", operations=transforms_list) | |||
| >>> dataset = dataset.map(input_columns="label", operations=onehot_op) | |||
| """ | |||
| import numbers | |||
| import mindspore._c_dataengine as cde | |||
| from .utils import Inter, Border | |||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | |||
| check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp | |||
| check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ | |||
| check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ | |||
| FLOAT_MAX_INTEGER | |||
| DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | |||
| Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | |||
| @@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, | |||
| Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} | |||
| def parse_padding(padding): | |||
| if isinstance(padding, numbers.Number): | |||
| padding = [padding] * 4 | |||
| if len(padding) == 2: | |||
| left = right = padding[0] | |||
| top = bottom = padding[1] | |||
| padding = (left, top, right, bottom,) | |||
| if isinstance(padding, list): | |||
| padding = tuple(padding) | |||
| return padding | |||
| class Decode(cde.DecodeOp): | |||
| """ | |||
| Decode the input image in RGB mode. | |||
| @@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp): | |||
| @check_random_crop | |||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | |||
| self.size = size | |||
| self.padding = padding | |||
| self.pad_if_needed = pad_if_needed | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode.value | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| if padding is None: | |||
| padding = (0, 0, 0, 0) | |||
| else: | |||
| padding = parse_padding(padding) | |||
| if isinstance(fill_value, int): # temporary fix | |||
| fill_value = tuple([fill_value] * 3) | |||
| border_type = DE_C_BORDER_TYPE[padding_mode] | |||
| self.size = size | |||
| self.padding = padding | |||
| self.pad_if_needed = pad_if_needed | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode.value | |||
| super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | |||
| @@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): | |||
| @check_random_crop | |||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | |||
| self.size = size | |||
| self.padding = padding | |||
| self.pad_if_needed = pad_if_needed | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode.value | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| if padding is None: | |||
| padding = (0, 0, 0, 0) | |||
| else: | |||
| padding = parse_padding(padding) | |||
| if isinstance(fill_value, int): # temporary fix | |||
| fill_value = tuple([fill_value] * 3) | |||
| border_type = DE_C_BORDER_TYPE[padding_mode] | |||
| self.size = size | |||
| self.padding = padding | |||
| self.pad_if_needed = pad_if_needed | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode.value | |||
| super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | |||
| @@ -292,6 +319,8 @@ class Resize(cde.ResizeOp): | |||
| @check_resize_interpolation | |||
| def __init__(self, size, interpolation=Inter.LINEAR): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| self.interpolation = interpolation | |||
| interpoltn = DE_C_INTER_MODE[interpolation] | |||
| @@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): | |||
| @check_random_resize_crop | |||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | |||
| interpolation=Inter.BILINEAR, max_attempts=10): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| self.scale = scale | |||
| self.ratio = ratio | |||
| @@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp): | |||
| @check_random_resize_crop | |||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | |||
| interpolation=Inter.BILINEAR, max_attempts=10): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| self.scale = scale | |||
| self.ratio = ratio | |||
| @@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp): | |||
| @check_crop | |||
| def __init__(self, size): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| super().__init__(*size) | |||
| @@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp): | |||
| @check_random_color_adjust | |||
| def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)): | |||
| brightness = self.expand_values(brightness) | |||
| contrast = self.expand_values(contrast) | |||
| saturation = self.expand_values(saturation) | |||
| hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False) | |||
| self.brightness = brightness | |||
| self.contrast = contrast | |||
| self.saturation = saturation | |||
| self.hue = hue | |||
| super().__init__(*brightness, *contrast, *saturation, *hue) | |||
| def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): | |||
| if isinstance(value, numbers.Number): | |||
| value = [center - value, center + value] | |||
| if non_negative: | |||
| value[0] = max(0, value[0]) | |||
| check_range(value, bound) | |||
| return (value[0], value[1]) | |||
| class RandomRotation(cde.RandomRotationOp): | |||
| """ | |||
| @@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp): | |||
| self.expand = expand | |||
| self.center = center | |||
| self.fill_value = fill_value | |||
| if isinstance(degrees, numbers.Number): | |||
| degrees = (-degrees, degrees) | |||
| if center is None: | |||
| center = (-1, -1) | |||
| if isinstance(fill_value, int): # temporary fix | |||
| @@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): | |||
| @check_random_resize_crop | |||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | |||
| interpolation=Inter.BILINEAR, max_attempts=10): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| self.scale = scale | |||
| self.ratio = ratio | |||
| @@ -623,12 +676,14 @@ class Pad(cde.PadOp): | |||
| @check_pad | |||
| def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | |||
| self.padding = padding | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode | |||
| padding = parse_padding(padding) | |||
| if isinstance(fill_value, int): # temporary fix | |||
| fill_value = tuple([fill_value] * 3) | |||
| padding_mode = DE_C_BORDER_TYPE[padding_mode] | |||
| self.padding = padding | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode | |||
| super().__init__(*padding, padding_mode, *fill_value) | |||
| @@ -28,6 +28,7 @@ import numpy as np | |||
| from PIL import Image | |||
| from . import py_transforms_util as util | |||
| from .c_transforms import parse_padding | |||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | |||
| check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||
| check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ | |||
| @@ -295,6 +296,10 @@ class RandomCrop: | |||
| @check_random_crop | |||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | |||
| if padding is None: | |||
| padding = (0, 0, 0, 0) | |||
| else: | |||
| padding = parse_padding(padding) | |||
| self.size = size | |||
| self.padding = padding | |||
| self.pad_if_needed = pad_if_needed | |||
| @@ -753,6 +758,8 @@ class TenCrop: | |||
| @check_ten_crop | |||
| def __init__(self, size, use_vertical_flip=False): | |||
| if isinstance(size, int): | |||
| size = (size, size) | |||
| self.size = size | |||
| self.use_vertical_flip = use_vertical_flip | |||
| @@ -877,6 +884,8 @@ class Pad: | |||
| @check_pad | |||
| def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | |||
| parse_padding(padding) | |||
| self.padding = padding | |||
| self.fill_value = fill_value | |||
| self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] | |||
| @@ -1129,56 +1138,23 @@ class RandomAffine: | |||
| def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0): | |||
| # Parameter checking | |||
| # rotation | |||
| if isinstance(degrees, numbers.Number): | |||
| if degrees < 0: | |||
| raise ValueError("If degrees is a single number, it must be positive.") | |||
| self.degrees = (-degrees, degrees) | |||
| elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: | |||
| self.degrees = degrees | |||
| else: | |||
| raise TypeError("If degrees is a list or tuple, it must be of length 2.") | |||
| # translation | |||
| if translate is not None: | |||
| if isinstance(translate, (tuple, list)) and len(translate) == 2: | |||
| for t in translate: | |||
| if t < 0.0 or t > 1.0: | |||
| raise ValueError("translation values should be between 0 and 1") | |||
| else: | |||
| raise TypeError("translate should be a list or tuple of length 2.") | |||
| self.translate = translate | |||
| # scale | |||
| if scale is not None: | |||
| if isinstance(scale, (tuple, list)) and len(scale) == 2: | |||
| for s in scale: | |||
| if s <= 0: | |||
| raise ValueError("scale values should be positive") | |||
| else: | |||
| raise TypeError("scale should be a list or tuple of length 2.") | |||
| self.scale_ranges = scale | |||
| # shear | |||
| if shear is not None: | |||
| if isinstance(shear, numbers.Number): | |||
| if shear < 0: | |||
| raise ValueError("If shear is a single number, it must be positive.") | |||
| self.shear = (-1 * shear, shear) | |||
| elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): | |||
| # X-Axis shear with [min, max] | |||
| shear = (-1 * shear, shear) | |||
| else: | |||
| if len(shear) == 2: | |||
| self.shear = [shear[0], shear[1], 0., 0.] | |||
| shear = [shear[0], shear[1], 0., 0.] | |||
| elif len(shear) == 4: | |||
| self.shear = [s for s in shear] | |||
| else: | |||
| raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") | |||
| else: | |||
| self.shear = shear | |||
| shear = [s for s in shear] | |||
| # resample | |||
| self.resample = DE_PY_INTER_MODE[resample] | |||
| if isinstance(degrees, numbers.Number): | |||
| degrees = (-degrees, degrees) | |||
| # fill_value | |||
| self.degrees = degrees | |||
| self.translate = translate | |||
| self.scale_ranges = scale | |||
| self.shear = shear | |||
| self.resample = DE_PY_INTER_MODE[resample] | |||
| self.fill_value = fill_value | |||
| def __call__(self, img): | |||
| @@ -15,13 +15,15 @@ | |||
| """ | |||
| Testing the bounding box augment op in DE | |||
| """ | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| import numpy as np | |||
| import mindspore.log as logger | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| # updated VOC dataset with correct annotations | |||
| @@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c(): | |||
| operations=[test_op]) # Add column for "annotation" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input is not" in str(error) | |||
| assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error) | |||
| def test_bounding_box_augment_invalid_bounds_c(): | |||
| @@ -17,6 +17,7 @@ import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| # generates 1 column [0], [0, 1], ..., [0, ..., n-1] | |||
| def generate_sequential(n): | |||
| for i in range(n): | |||
| @@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input(): | |||
| with pytest.raises(TypeError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | |||
| None, None, invalid_type_pad_to_bucket_boundary) | |||
| assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value) | |||
| assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | |||
| None, None, False, invalid_type_drop_remainder) | |||
| assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value) | |||
| assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value) | |||
| def test_bucket_batch_multi_bucket_no_padding(): | |||
| @@ -272,7 +273,6 @@ def test_bucket_batch_default_pad(): | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | |||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] | |||
| output = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| output.append(data["col1"].tolist()) | |||
| @@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis(): | |||
| def test_concatenate_op_incorrect_input_dim(): | |||
| def gen(): | |||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||
| prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| for _ in data: | |||
| pass | |||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||
| with pytest.raises(ValueError) as error_info: | |||
| data_trans.Concatenate(0, prepend_tensor) | |||
| assert "can only prepend 1D arrays." in repr(error_info.value) | |||
| if __name__ == "__main__": | |||
| @@ -28,9 +28,9 @@ def test_exception_01(): | |||
| """ | |||
| logger.info("test_exception_01") | |||
| data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) | |||
| with pytest.raises(ValueError) as info: | |||
| data = data.map(input_columns=["image"], operations=vision.Resize(100, 100)) | |||
| assert "Invalid interpolation mode." in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| data.map(input_columns=["image"], operations=vision.Resize(100, 100)) | |||
| assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value) | |||
| def test_exception_02(): | |||
| @@ -40,8 +40,8 @@ def test_exception_02(): | |||
| logger.info("test_exception_02") | |||
| num_samples = -1 | |||
| with pytest.raises(ValueError) as info: | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| assert "num_samples cannot be less than 0" in str(info.value) | |||
| ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value) | |||
| num_samples = 1 | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| @@ -23,7 +23,8 @@ import mindspore.dataset.text as text | |||
| def test_demo_basic_from_dataset(): | |||
| """ this is a tutorial on how from_dataset should be used in a normal use case""" | |||
| data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | |||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"], | |||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, | |||
| special_tokens=["<pad>", "<unk>"], | |||
| special_first=True) | |||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | |||
| res = [] | |||
| @@ -127,15 +128,16 @@ def test_from_dataset_exceptions(): | |||
| data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | |||
| vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) | |||
| assert isinstance(vocab.text.Vocab) | |||
| except ValueError as e: | |||
| except (TypeError, ValueError, RuntimeError) as e: | |||
| assert s in str(e), str(e) | |||
| test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers") | |||
| test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer") | |||
| test_config(23, (2, 3), 1.2345, "columns need to be a list of strings") | |||
| test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b") | |||
| test_config("text", (2, 3), 0, "top_k needs to be a positive integer") | |||
| test_config([123], (2, 3), 0, "columns need to be a list of strings") | |||
| test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") | |||
| test_config("text", (2, 3), 1.2345, | |||
| "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)") | |||
| test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)") | |||
| test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") | |||
| test_config("text", (2, 3), 0, "top_k needs to be positive number") | |||
| test_config([123], (2, 3), 0, "top_k needs to be positive number") | |||
| if __name__ == '__main__': | |||
| @@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False): | |||
| if plot: | |||
| visualize_list(image, image_transformed) | |||
| def test_linear_transformation_md5(): | |||
| """ | |||
| Test LinearTransformation op: valid params (transformation_matrix, mean_vector) | |||
| @@ -102,6 +103,7 @@ def test_linear_transformation_md5(): | |||
| filename = "linear_transformation_01_result.npz" | |||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | |||
| def test_linear_transformation_exception_01(): | |||
| """ | |||
| Test LinearTransformation op: transformation_matrix is not provided | |||
| @@ -126,9 +128,10 @@ def test_linear_transformation_exception_01(): | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| data1 = data1.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "not provided" in str(e) | |||
| assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e) | |||
| def test_linear_transformation_exception_02(): | |||
| """ | |||
| @@ -154,9 +157,10 @@ def test_linear_transformation_exception_02(): | |||
| ] | |||
| transform = py_vision.ComposeOp(transforms) | |||
| data1 = data1.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "not provided" in str(e) | |||
| assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e) | |||
| def test_linear_transformation_exception_03(): | |||
| """ | |||
| @@ -187,6 +191,7 @@ def test_linear_transformation_exception_03(): | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "square matrix" in str(e) | |||
| def test_linear_transformation_exception_04(): | |||
| """ | |||
| Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix | |||
| @@ -199,7 +204,7 @@ def test_linear_transformation_exception_04(): | |||
| weight = 50 | |||
| dim = 3 * height * weight | |||
| transformation_matrix = np.ones([dim, dim]) | |||
| mean_vector = np.zeros(dim-1) | |||
| mean_vector = np.zeros(dim - 1) | |||
| # Generate dataset | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| @@ -216,6 +221,7 @@ def test_linear_transformation_exception_04(): | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "should match" in str(e) | |||
| if __name__ == '__main__': | |||
| test_linear_transformation_op(plot=True) | |||
| test_linear_transformation_md5() | |||
| @@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards(): | |||
| create_cv_mindrecord(1) | |||
| columns_list = ["data", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| with pytest.raises(Exception) as error_info: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| def test_minddataset_invalidate_shard_id(): | |||
| create_cv_mindrecord(1) | |||
| columns_list = ["data", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| with pytest.raises(Exception) as error_info: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||
| create_cv_mindrecord(1) | |||
| columns_list = ["data", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| with pytest.raises(Exception) as error_info: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||
| with pytest.raises(Exception) as error_info: | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(): | |||
| num_iter += 1 | |||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| @@ -15,9 +15,9 @@ | |||
| """ | |||
| Testing Ngram in mindspore.dataset | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.text as text | |||
| import numpy as np | |||
| def test_multiple_ngrams(): | |||
| @@ -61,7 +61,7 @@ def test_simple_ngram(): | |||
| yield (np.array(line.split(" "), dtype='S'),) | |||
| dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) | |||
| dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None)) | |||
| dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" ")) | |||
| i = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| @@ -72,7 +72,7 @@ def test_simple_ngram(): | |||
| def test_corner_cases(): | |||
| """ testing various corner cases and exceptions""" | |||
| def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None): | |||
| def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): | |||
| def gen(texts): | |||
| yield (np.array(texts.split(" "), dtype='S'),) | |||
| @@ -93,7 +93,7 @@ def test_corner_cases(): | |||
| try: | |||
| test_config("Yours to Discover", "", [0, [1]]) | |||
| except Exception as e: | |||
| assert "ngram needs to be a positive number" in str(e) | |||
| assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e) | |||
| # test empty n | |||
| try: | |||
| test_config("Yours to Discover", "", []) | |||
| @@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py(): | |||
| _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not within the required range" in str(e) | |||
| assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) | |||
| def test_normalize_grayscale_md5_01(): | |||
| @@ -61,6 +61,10 @@ def test_pad_end_exceptions(): | |||
| pad_compare([3, 4, 5], ["2"], 1, []) | |||
| assert "a value in the list is not an integer." in str(info.value) | |||
| with pytest.raises(TypeError) as info: | |||
| pad_compare([1, 2], 3, -1, [1, 2, -1]) | |||
| assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value) | |||
| if __name__ == "__main__": | |||
| test_pad_end_basics() | |||
| @@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees(): | |||
| _ = py_vision.RandomAffine(degrees=-15) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "If degrees is a single number, it cannot be negative." | |||
| assert str(e) == "Input degrees is not within the required interval of (0 to inf)." | |||
| def test_random_affine_exception_translation_range(): | |||
| @@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range(): | |||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "translation values should be between 0 and 1" | |||
| assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)." | |||
| def test_random_affine_exception_scale_value(): | |||
| @@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value(): | |||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "scale values should be positive" | |||
| assert str(e) == "Input scale[0] must be greater than 0." | |||
| def test_random_affine_exception_shear_value(): | |||
| @@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value(): | |||
| _ = py_vision.RandomAffine(degrees=15, shear=-5) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "If shear is a single number, it must be positive." | |||
| assert str(e) == "Input shear must be greater than 0." | |||
| def test_random_affine_exception_degrees_size(): | |||
| @@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size(): | |||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "translate should be a list or tuple of length 2." | |||
| assert str( | |||
| e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \ | |||
| " <class 'tuple'>)." | |||
| def test_random_affine_exception_scale_size(): | |||
| @@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size(): | |||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "scale should be a list or tuple of length 2." | |||
| assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \ | |||
| " <class 'list'>)." | |||
| def test_random_affine_exception_shear_size(): | |||
| @@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size(): | |||
| _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) | |||
| except TypeError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4." | |||
| assert str(e) == "shear must be of length 2 or 4." | |||
| if __name__ == "__main__": | |||
| @@ -97,7 +97,7 @@ def test_random_color_md5(): | |||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | |||
| transforms = F.ComposeOp([F.Decode(), | |||
| F.RandomColor((0.5, 1.5)), | |||
| F.RandomColor((0.1, 1.9)), | |||
| F.ToTensor()]) | |||
| data = data.map(input_columns="image", operations=transforms()) | |||
| @@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c(): | |||
| data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input range is not valid" in str(e) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||
| def test_random_crop_and_resize_04_py(): | |||
| @@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input range is not valid" in str(e) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||
| def test_random_crop_and_resize_05_c(): | |||
| @@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c(): | |||
| data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input range is not valid" in str(e) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||
| def test_random_crop_and_resize_05_py(): | |||
| @@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input range is not valid" in str(e) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||
| def test_random_crop_and_resize_comp(plot=False): | |||
| @@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input range is not valid" in str(err) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(err) | |||
| def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||
| @@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input range is not valid" in str(err) | |||
| assert "Input is not within the required interval of (0 to 16777216)." in str(err) | |||
| def test_random_resized_crop_with_bbox_op_bad_c(): | |||
| @@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not within the required range" in str(e) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||
| if __name__ == "__main__": | |||
| test_random_grayscale_valid_prob(True) | |||
| @@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c(): | |||
| data = data.map(input_columns=["image"], operations=random_horizontal_op) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||
| def test_random_horizontal_invalid_prob_py(): | |||
| @@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||
| def test_random_horizontal_comp(plot=False): | |||
| @@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): | |||
| operations=[test_op]) # Add column for "annotation" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input is not" in str(error) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) | |||
| def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): | |||
| @@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range(): | |||
| _ = py_vision.RandomPerspective(distortion_scale=1.5) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Input is not within the required range" | |||
| assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)." | |||
| def test_random_perspective_exception_prob_range(): | |||
| @@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range(): | |||
| _ = py_vision.RandomPerspective(prob=1.2) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert str(e) == "Input is not within the required range" | |||
| assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)." | |||
| if __name__ == "__main__": | |||
| @@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| assert "Input is not within the required interval of (1 to 16777216)." in str(err) | |||
| try: | |||
| # one of the size values is zero | |||
| @@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err) | |||
| try: | |||
| # negative value for resize | |||
| @@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| assert "Input is not within the required interval of (1 to 16777216)." in str(err) | |||
| try: | |||
| # invalid input shape | |||
| @@ -97,7 +97,7 @@ def test_random_sharpness_md5(): | |||
| # define map operations | |||
| transforms = [ | |||
| F.Decode(), | |||
| F.RandomSharpness((0.5, 1.5)), | |||
| F.RandomSharpness((0.1, 1.9)), | |||
| F.ToTensor() | |||
| ] | |||
| transform = F.ComposeOp(transforms) | |||
| @@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c(): | |||
| data = data.map(input_columns=["image"], operations=random_horizontal_op) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) | |||
| def test_random_vertical_invalid_prob_py(): | |||
| @@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py(): | |||
| data = data.map(input_columns=["image"], operations=transform()) | |||
| except ValueError as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Input is not" in str(e) | |||
| assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) | |||
| def test_random_vertical_comp(plot=False): | |||
| @@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err) | |||
| def test_random_vertical_flip_with_bbox_op_bad_c(): | |||
| @@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c(): | |||
| # invalid interpolation value | |||
| c_vision.ResizeWithBBox(400, interpolation="invalid") | |||
| except ValueError as err: | |||
| except TypeError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "interpolation" in str(err) | |||
| @@ -154,7 +154,7 @@ def test_shuffle_exception_01(): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||
| def test_shuffle_exception_02(): | |||
| @@ -172,7 +172,7 @@ def test_shuffle_exception_02(): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||
| def test_shuffle_exception_03(): | |||
| @@ -190,7 +190,7 @@ def test_shuffle_exception_03(): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "buffer_size" in str(e) | |||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||
| def test_shuffle_exception_05(): | |||
| @@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False): | |||
| logger.info("dtype of image_2: {}".format(image_2.dtype)) | |||
| if plot: | |||
| visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) | |||
| visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) | |||
| # The output data should be of a 4D tensor shape, a stack of 10 images. | |||
| assert len(image_2.shape) == 4 | |||
| @@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg(): | |||
| vision.TenCrop(0), | |||
| lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images | |||
| ] | |||
| error_msg = "Input is not within the required range" | |||
| error_msg = "Input is not within the required interval of (1 to 16777216)." | |||
| assert error_msg == str(info.value) | |||
| with pytest.raises(ValueError) as info: | |||
| @@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "operations" in str(e) | |||
| assert "Argument tensor_op_5 with value" \ | |||
| " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e) | |||
| assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e) | |||
| def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | |||
| @@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "num_ops" in str(e) | |||
| assert "Input num_ops must be greater than 0" in str(e) | |||
| def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | |||
| @@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | |||
| except Exception as e: | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "integer" in str(e) | |||
| assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e) | |||
| def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): | |||
| @@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= | |||
| if len(orig) != len(aug) or not orig: | |||
| return | |||
| batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together | |||
| batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together | |||
| split_point = batch_size * plot_rows | |||
| orig, aug = np.array(orig), np.array(aug) | |||
| if len(orig) > plot_rows: | |||
| # Create batches of required size and add remainder to last batch | |||
| orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added | |||
| orig = np.split(orig[:split_point], batch_size) + ( | |||
| [orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added | |||
| aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) | |||
| else: | |||
| orig = [orig] | |||
| @@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= | |||
| for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | |||
| cur_ix = base_ix + x | |||
| (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row | |||
| # select plotting axes based on number of image rows on plot - else case when 1 row | |||
| (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) | |||
| axA.imshow(dataA["image"]) | |||
| add_bounding_boxes(axA, dataA[annot_name]) | |||