Merge pull request !2833 from nhussain/engine_validatorstags/v0.6.0-beta
| @@ -0,0 +1,342 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """ | |||||
| General Validators. | |||||
| """ | |||||
| import inspect | |||||
| from multiprocessing import cpu_count | |||||
| import os | |||||
| import numpy as np | |||||
| from ..engine import samplers | |||||
| # POS_INT_MIN is used to limit values from starting from 0 | |||||
| POS_INT_MIN = 1 | |||||
| UINT8_MAX = 255 | |||||
| UINT8_MIN = 0 | |||||
| UINT32_MAX = 4294967295 | |||||
| UINT32_MIN = 0 | |||||
| UINT64_MAX = 18446744073709551615 | |||||
| UINT64_MIN = 0 | |||||
| INT32_MAX = 2147483647 | |||||
| INT32_MIN = -2147483648 | |||||
| INT64_MAX = 9223372036854775807 | |||||
| INT64_MIN = -9223372036854775808 | |||||
| FLOAT_MAX_INTEGER = 16777216 | |||||
| FLOAT_MIN_INTEGER = -16777216 | |||||
| DOUBLE_MAX_INTEGER = 9007199254740992 | |||||
| DOUBLE_MIN_INTEGER = -9007199254740992 | |||||
| valid_detype = [ | |||||
| "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", | |||||
| "uint32", "uint64", "float16", "float32", "float64", "string" | |||||
| ] | |||||
| def pad_arg_name(arg_name): | |||||
| if arg_name != "": | |||||
| arg_name = arg_name + " " | |||||
| return arg_name | |||||
| def check_value(value, valid_range, arg_name=""): | |||||
| arg_name = pad_arg_name(arg_name) | |||||
| if value < valid_range[0] or value > valid_range[1]: | |||||
| raise ValueError( | |||||
| "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], | |||||
| valid_range[1])) | |||||
| def check_range(values, valid_range, arg_name=""): | |||||
| arg_name = pad_arg_name(arg_name) | |||||
| if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: | |||||
| raise ValueError( | |||||
| "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], | |||||
| valid_range[1])) | |||||
| def check_positive(value, arg_name=""): | |||||
| arg_name = pad_arg_name(arg_name) | |||||
| if value <= 0: | |||||
| raise ValueError("Input {0}must be greater than 0.".format(arg_name)) | |||||
| def check_positive_float(value, arg_name=""): | |||||
| arg_name = pad_arg_name(arg_name) | |||||
| type_check(value, (float,), arg_name) | |||||
| check_positive(value, arg_name) | |||||
| def check_2tuple(value, arg_name=""): | |||||
| if not (isinstance(value, tuple) and len(value) == 2): | |||||
| raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name)) | |||||
| def check_uint8(value, arg_name=""): | |||||
| type_check(value, (int,), arg_name) | |||||
| check_value(value, [UINT8_MIN, UINT8_MAX]) | |||||
| def check_uint32(value, arg_name=""): | |||||
| type_check(value, (int,), arg_name) | |||||
| check_value(value, [UINT32_MIN, UINT32_MAX]) | |||||
| def check_pos_int32(value, arg_name=""): | |||||
| type_check(value, (int,), arg_name) | |||||
| check_value(value, [POS_INT_MIN, INT32_MAX]) | |||||
| def check_uint64(value, arg_name=""): | |||||
| type_check(value, (int,), arg_name) | |||||
| check_value(value, [UINT64_MIN, UINT64_MAX]) | |||||
| def check_pos_int64(value, arg_name=""): | |||||
| type_check(value, (int,), arg_name) | |||||
| check_value(value, [UINT64_MIN, INT64_MAX]) | |||||
| def check_pos_float32(value, arg_name=""): | |||||
| check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name) | |||||
| def check_pos_float64(value, arg_name=""): | |||||
| check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name) | |||||
| def check_valid_detype(type_): | |||||
| if type_ not in valid_detype: | |||||
| raise ValueError("Unknown column type") | |||||
| return True | |||||
| def check_columns(columns, name): | |||||
| type_check(columns, (list, str), name) | |||||
| if isinstance(columns, list): | |||||
| if not columns: | |||||
| raise ValueError("Column names should not be empty") | |||||
| col_names = ["col_{0}".format(i) for i in range(len(columns))] | |||||
| type_check_list(columns, (str,), col_names) | |||||
| def parse_user_args(method, *args, **kwargs): | |||||
| """ | |||||
| Parse user arguments in a function | |||||
| Args: | |||||
| method (method): a callable function | |||||
| *args: user passed args | |||||
| **kwargs: user passed kwargs | |||||
| Returns: | |||||
| user_filled_args (list): values of what the user passed in for the arguments, | |||||
| ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. | |||||
| """ | |||||
| sig = inspect.signature(method) | |||||
| if 'self' in sig.parameters or 'cls' in sig.parameters: | |||||
| ba = sig.bind(method, *args, **kwargs) | |||||
| ba.apply_defaults() | |||||
| params = list(sig.parameters.keys())[1:] | |||||
| else: | |||||
| ba = sig.bind(*args, **kwargs) | |||||
| ba.apply_defaults() | |||||
| params = list(sig.parameters.keys()) | |||||
| user_filled_args = [ba.arguments.get(arg_value) for arg_value in params] | |||||
| return user_filled_args, ba.arguments | |||||
| def type_check_list(args, types, arg_names): | |||||
| """ | |||||
| Check the type of each parameter in the list | |||||
| Args: | |||||
| args (list, tuple): a list or tuple of any variable | |||||
| types (tuple): tuple of all valid types for arg | |||||
| arg_names (list, tuple of str): the names of args | |||||
| Returns: | |||||
| Exception: when the type is not correct, otherwise nothing | |||||
| """ | |||||
| type_check(args, (list, tuple,), arg_names) | |||||
| if len(args) != len(arg_names): | |||||
| raise ValueError("List of arguments is not the same length as argument_names.") | |||||
| for arg, arg_name in zip(args, arg_names): | |||||
| type_check(arg, types, arg_name) | |||||
| def type_check(arg, types, arg_name): | |||||
| """ | |||||
| Check the type of the parameter | |||||
| Args: | |||||
| arg : any variable | |||||
| types (tuple): tuple of all valid types for arg | |||||
| arg_name (str): the name of arg | |||||
| Returns: | |||||
| Exception: when the type is not correct, otherwise nothing | |||||
| """ | |||||
| # handle special case of booleans being a subclass of ints | |||||
| print_value = '\"\"' if repr(arg) == repr('') else arg | |||||
| if int in types and bool not in types: | |||||
| if isinstance(arg, bool): | |||||
| raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) | |||||
| if not isinstance(arg, types): | |||||
| raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) | |||||
| def check_filename(path): | |||||
| """ | |||||
| check the filename in the path | |||||
| Args: | |||||
| path (str): the path | |||||
| Returns: | |||||
| Exception: when error | |||||
| """ | |||||
| if not isinstance(path, str): | |||||
| raise TypeError("path: {} is not string".format(path)) | |||||
| filename = os.path.basename(path) | |||||
| # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', | |||||
| # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>', | |||||
| # '*', '(', '%', ')', '-', '=', '{', '?', '$' | |||||
| forbidden_symbols = set(r'\/:*?"<>|`&\';') | |||||
| if set(filename) & forbidden_symbols: | |||||
| raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'") | |||||
| if filename.startswith(' ') or filename.endswith(' '): | |||||
| raise ValueError("filename should not start/end with space") | |||||
| return True | |||||
| def check_dir(dataset_dir): | |||||
| if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK): | |||||
| raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir)) | |||||
| def check_file(dataset_file): | |||||
| check_filename(dataset_file) | |||||
| if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK): | |||||
| raise ValueError("The file {} does not exist or permission denied!".format(dataset_file)) | |||||
| def check_sampler_shuffle_shard_options(param_dict): | |||||
| """ | |||||
| Check for valid shuffle, sampler, num_shards, and shard_id inputs. | |||||
| Args: | |||||
| param_dict (dict): param_dict | |||||
| Returns: | |||||
| Exception: ValueError or RuntimeError if error | |||||
| """ | |||||
| shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') | |||||
| num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') | |||||
| type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler") | |||||
| if sampler is not None: | |||||
| if shuffle is not None: | |||||
| raise RuntimeError("sampler and shuffle cannot be specified at the same time.") | |||||
| if num_shards is not None: | |||||
| check_pos_int32(num_shards) | |||||
| if shard_id is None: | |||||
| raise RuntimeError("num_shards is specified and currently requires shard_id as well.") | |||||
| check_value(shard_id, [0, num_shards - 1], "shard_id") | |||||
| if num_shards is None and shard_id is not None: | |||||
| raise RuntimeError("shard_id is specified but num_shards is not.") | |||||
| def check_padding_options(param_dict): | |||||
| """ | |||||
| Check for valid padded_sample and num_padded of padded samples | |||||
| Args: | |||||
| param_dict (dict): param_dict | |||||
| Returns: | |||||
| Exception: ValueError or RuntimeError if error | |||||
| """ | |||||
| columns_list = param_dict.get('columns_list') | |||||
| block_reader = param_dict.get('block_reader') | |||||
| padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') | |||||
| if padded_sample is not None: | |||||
| if num_padded is None: | |||||
| raise RuntimeError("padded_sample is specified and requires num_padded as well.") | |||||
| if num_padded < 0: | |||||
| raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) | |||||
| if columns_list is None: | |||||
| raise RuntimeError("padded_sample is specified and requires columns_list as well.") | |||||
| for column in columns_list: | |||||
| if column not in padded_sample: | |||||
| raise ValueError("padded_sample cannot match columns_list.") | |||||
| if block_reader: | |||||
| raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") | |||||
| if padded_sample is None and num_padded is not None: | |||||
| raise RuntimeError("num_padded is specified but padded_sample is not.") | |||||
| def check_num_parallel_workers(value): | |||||
| type_check(value, (int,), "num_parallel_workers") | |||||
| if value < 1 or value > cpu_count(): | |||||
| raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) | |||||
| def check_num_samples(value): | |||||
| type_check(value, (int,), "num_samples") | |||||
| check_value(value, [0, INT32_MAX], "num_samples") | |||||
| def validate_dataset_param_value(param_list, param_dict, param_type): | |||||
| for param_name in param_list: | |||||
| if param_dict.get(param_name) is not None: | |||||
| if param_name == 'num_parallel_workers': | |||||
| check_num_parallel_workers(param_dict.get(param_name)) | |||||
| if param_name == 'num_samples': | |||||
| check_num_samples(param_dict.get(param_name)) | |||||
| else: | |||||
| type_check(param_dict.get(param_name), (param_type,), param_name) | |||||
| def check_gnn_list_or_ndarray(param, param_name): | |||||
| """ | |||||
| Check if the input parameter is list or numpy.ndarray. | |||||
| Args: | |||||
| param (list, nd.ndarray): param | |||||
| param_name (str): param_name | |||||
| Returns: | |||||
| Exception: TypeError if error | |||||
| """ | |||||
| type_check(param, (list, np.ndarray), param_name) | |||||
| if isinstance(param, list): | |||||
| param_names = ["param_{0}".format(i) for i in range(len(param))] | |||||
| type_check_list(param, (int,), param_names) | |||||
| elif isinstance(param, np.ndarray): | |||||
| if not param.dtype == np.int32: | |||||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||||
| param_name, param.dtype)) | |||||
| @@ -98,7 +98,7 @@ class Ngram(cde.NgramOp): | |||||
| """ | """ | ||||
| @check_ngram | @check_ngram | ||||
| def __init__(self, n, left_pad=None, right_pad=None, separator=None): | |||||
| def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "): | |||||
| super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0], | super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0], | ||||
| r_pad_token=right_pad[0], separator=separator) | r_pad_token=right_pad[0], separator=separator) | ||||
| @@ -28,6 +28,7 @@ __all__ = [ | |||||
| "Vocab", "to_str", "to_bytes" | "Vocab", "to_str", "to_bytes" | ||||
| ] | ] | ||||
| class Vocab(cde.Vocab): | class Vocab(cde.Vocab): | ||||
| """ | """ | ||||
| Vocab object that is used to lookup a word. | Vocab object that is used to lookup a word. | ||||
| @@ -38,7 +39,7 @@ class Vocab(cde.Vocab): | |||||
| @classmethod | @classmethod | ||||
| @check_from_dataset | @check_from_dataset | ||||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | ||||
| special_first=None): | |||||
| special_first=True): | |||||
| """ | """ | ||||
| Build a vocab from a dataset. | Build a vocab from a dataset. | ||||
| @@ -62,13 +63,21 @@ class Vocab(cde.Vocab): | |||||
| special_tokens(list, optional): a list of strings, each one is a special token. for example | special_tokens(list, optional): a list of strings, each one is a special token. for example | ||||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | ||||
| special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens | special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens | ||||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||||
| Returns: | Returns: | ||||
| Vocab, Vocab object built from dataset. | Vocab, Vocab object built from dataset. | ||||
| """ | """ | ||||
| vocab = Vocab() | vocab = Vocab() | ||||
| if columns is None: | |||||
| columns = [] | |||||
| if not isinstance(columns, list): | |||||
| columns = [columns] | |||||
| if freq_range is None: | |||||
| freq_range = (None, None) | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | ||||
| for d in root.create_dict_iterator(): | for d in root.create_dict_iterator(): | ||||
| if d is not None: | if d is not None: | ||||
| @@ -77,7 +86,7 @@ class Vocab(cde.Vocab): | |||||
| @classmethod | @classmethod | ||||
| @check_from_list | @check_from_list | ||||
| def from_list(cls, word_list, special_tokens=None, special_first=None): | |||||
| def from_list(cls, word_list, special_tokens=None, special_first=True): | |||||
| """ | """ | ||||
| Build a vocab object from a list of word. | Build a vocab object from a list of word. | ||||
| @@ -86,29 +95,33 @@ class Vocab(cde.Vocab): | |||||
| special_tokens(list, optional): a list of strings, each one is a special token. for example | special_tokens(list, optional): a list of strings, each one is a special token. for example | ||||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | ||||
| special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens | special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens | ||||
| is specified and special_first is set to None, special_tokens will be prepended (default=None). | |||||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||||
| """ | """ | ||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| return super().from_list(word_list, special_tokens, special_first) | return super().from_list(word_list, special_tokens, special_first) | ||||
| @classmethod | @classmethod | ||||
| @check_from_file | @check_from_file | ||||
| def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None): | |||||
| def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True): | |||||
| """ | """ | ||||
| Build a vocab object from a list of word. | Build a vocab object from a list of word. | ||||
| Args: | Args: | ||||
| file_path (str): path to the file which contains the vocab list. | file_path (str): path to the file which contains the vocab list. | ||||
| delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be | delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be | ||||
| the word (default=None). | |||||
| the word (default=""). | |||||
| vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken). | vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken). | ||||
| special_tokens (list, optional): a list of strings, each one is a special token. for example | special_tokens (list, optional): a list of strings, each one is a special token. for example | ||||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | ||||
| special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, | special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, | ||||
| If special_tokens is specified and special_first is set to None, | |||||
| special_tokens will be prepended (default=None). | |||||
| If special_tokens is specified and special_first is set to True, | |||||
| special_tokens will be prepended (default=True). | |||||
| """ | """ | ||||
| if vocab_size is None: | |||||
| vocab_size = -1 | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) | return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) | ||||
| @classmethod | @classmethod | ||||
| @@ -17,23 +17,22 @@ validators for text ops | |||||
| """ | """ | ||||
| from functools import wraps | from functools import wraps | ||||
| import mindspore._c_dataengine as cde | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore._c_dataengine as cde | |||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| from ..transforms.validators import check_uint32, check_pos_int64 | |||||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, check_positive, \ | |||||
| INT32_MAX, check_value | |||||
| def check_unique_list_of_words(words, arg_name): | def check_unique_list_of_words(words, arg_name): | ||||
| """Check that words is a list and each element is a str without any duplication""" | """Check that words is a list and each element is a str without any duplication""" | ||||
| if not isinstance(words, list): | |||||
| raise ValueError(arg_name + " needs to be a list of words of type string.") | |||||
| type_check(words, (list,), arg_name) | |||||
| words_set = set() | words_set = set() | ||||
| for word in words: | for word in words: | ||||
| if not isinstance(word, str): | |||||
| raise ValueError("each word in " + arg_name + " needs to be type str.") | |||||
| type_check(word, (str,), arg_name) | |||||
| if word in words_set: | if word in words_set: | ||||
| raise ValueError(arg_name + " contains duplicate word: " + word + ".") | raise ValueError(arg_name + " contains duplicate word: " + word + ".") | ||||
| words_set.add(word) | words_set.add(word) | ||||
| @@ -45,21 +44,14 @@ def check_lookup(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| vocab, unknown = (list(args) + 2 * [None])[:2] | |||||
| if "vocab" in kwargs: | |||||
| vocab = kwargs.get("vocab") | |||||
| if "unknown" in kwargs: | |||||
| unknown = kwargs.get("unknown") | |||||
| if unknown is not None: | |||||
| if not (isinstance(unknown, int) and unknown >= 0): | |||||
| raise ValueError("unknown needs to be a non-negative integer.") | |||||
| [vocab, unknown], _ = parse_user_args(method, *args, **kwargs) | |||||
| if not isinstance(vocab, cde.Vocab): | |||||
| raise ValueError("vocab is not an instance of cde.Vocab.") | |||||
| if unknown is not None: | |||||
| type_check(unknown, (int,), "unknown") | |||||
| check_positive(unknown) | |||||
| type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") | |||||
| kwargs["vocab"] = vocab | |||||
| kwargs["unknown"] = unknown | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -69,50 +61,15 @@ def check_from_file(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] | |||||
| if "file_path" in kwargs: | |||||
| file_path = kwargs.get("file_path") | |||||
| if "delimiter" in kwargs: | |||||
| delimiter = kwargs.get("delimiter") | |||||
| if "vocab_size" in kwargs: | |||||
| vocab_size = kwargs.get("vocab_size") | |||||
| if "special_tokens" in kwargs: | |||||
| special_tokens = kwargs.get("special_tokens") | |||||
| if "special_first" in kwargs: | |||||
| special_first = kwargs.get("special_first") | |||||
| if not isinstance(file_path, str): | |||||
| raise ValueError("file_path needs to be str.") | |||||
| if delimiter is not None: | |||||
| if not isinstance(delimiter, str): | |||||
| raise ValueError("delimiter needs to be str.") | |||||
| else: | |||||
| delimiter = "" | |||||
| if vocab_size is not None: | |||||
| if not (isinstance(vocab_size, int) and vocab_size > 0): | |||||
| raise ValueError("vocab size needs to be a positive integer.") | |||||
| else: | |||||
| vocab_size = -1 | |||||
| if special_first is None: | |||||
| special_first = True | |||||
| if not isinstance(special_first, bool): | |||||
| raise ValueError("special_first needs to be a boolean value") | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, | |||||
| **kwargs) | |||||
| check_unique_list_of_words(special_tokens, "special_tokens") | check_unique_list_of_words(special_tokens, "special_tokens") | ||||
| type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) | |||||
| if vocab_size is not None: | |||||
| check_value(vocab_size, (-1, INT32_MAX), "vocab_size") | |||||
| type_check(special_first, (bool,), special_first) | |||||
| kwargs["file_path"] = file_path | |||||
| kwargs["delimiter"] = delimiter | |||||
| kwargs["vocab_size"] = vocab_size | |||||
| kwargs["special_tokens"] = special_tokens | |||||
| kwargs["special_first"] = special_first | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -122,33 +79,20 @@ def check_from_list(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] | |||||
| if "word_list" in kwargs: | |||||
| word_list = kwargs.get("word_list") | |||||
| if "special_tokens" in kwargs: | |||||
| special_tokens = kwargs.get("special_tokens") | |||||
| if "special_first" in kwargs: | |||||
| special_first = kwargs.get("special_first") | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| word_set = check_unique_list_of_words(word_list, "word_list") | |||||
| token_set = check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs) | |||||
| intersect = word_set.intersection(token_set) | |||||
| word_set = check_unique_list_of_words(word_list, "word_list") | |||||
| if special_tokens is not None: | |||||
| token_set = check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| if intersect != set(): | |||||
| raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") | |||||
| intersect = word_set.intersection(token_set) | |||||
| if special_first is None: | |||||
| special_first = True | |||||
| if intersect != set(): | |||||
| raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") | |||||
| if not isinstance(special_first, bool): | |||||
| raise ValueError("special_first needs to be a boolean value.") | |||||
| type_check(special_first, (bool,), "special_first") | |||||
| kwargs["word_list"] = word_list | |||||
| kwargs["special_tokens"] = special_tokens | |||||
| kwargs["special_first"] = special_first | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -158,18 +102,15 @@ def check_from_dict(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| word_dict, = (list(args) + [None])[:1] | |||||
| if "word_dict" in kwargs: | |||||
| word_dict = kwargs.get("word_dict") | |||||
| if not isinstance(word_dict, dict): | |||||
| raise ValueError("word_dict needs to be a list of word,id pairs.") | |||||
| [word_dict], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(word_dict, (dict,), "word_dict") | |||||
| for word, word_id in word_dict.items(): | for word, word_id in word_dict.items(): | ||||
| if not isinstance(word, str): | |||||
| raise ValueError("Each word in word_dict needs to be type string.") | |||||
| if not (isinstance(word_id, int) and word_id >= 0): | |||||
| raise ValueError("Each word id needs to be positive integer.") | |||||
| kwargs["word_dict"] = word_dict | |||||
| return method(self, **kwargs) | |||||
| type_check(word, (str,), "word") | |||||
| type_check(word_id, (int,), "word_id") | |||||
| check_value(word_id, (-1, INT32_MAX), "word_id") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -179,23 +120,8 @@ def check_jieba_init(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] | |||||
| if "hmm_path" in kwargs: | |||||
| hmm_path = kwargs.get("hmm_path") | |||||
| if "mp_path" in kwargs: | |||||
| mp_path = kwargs.get("mp_path") | |||||
| if hmm_path is None: | |||||
| raise ValueError( | |||||
| "The dict of HMMSegment in cppjieba is not provided.") | |||||
| kwargs["hmm_path"] = hmm_path | |||||
| if mp_path is None: | |||||
| raise ValueError( | |||||
| "The dict of MPSegment in cppjieba is not provided.") | |||||
| kwargs["mp_path"] = mp_path | |||||
| if model is not None: | |||||
| kwargs["model"] = model | |||||
| return method(self, **kwargs) | |||||
| parse_user_args(method, *args, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -205,19 +131,12 @@ def check_jieba_add_word(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| word, freq = (list(args) + 2 * [None])[:2] | |||||
| if "word" in kwargs: | |||||
| word = kwargs.get("word") | |||||
| if "freq" in kwargs: | |||||
| freq = kwargs.get("freq") | |||||
| [word, freq], _ = parse_user_args(method, *args, **kwargs) | |||||
| if word is None: | if word is None: | ||||
| raise ValueError("word is not provided.") | raise ValueError("word is not provided.") | ||||
| kwargs["word"] = word | |||||
| if freq is not None: | if freq is not None: | ||||
| check_uint32(freq) | check_uint32(freq) | ||||
| kwargs["freq"] = freq | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -227,13 +146,8 @@ def check_jieba_add_dict(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| user_dict = (list(args) + [None])[0] | |||||
| if "user_dict" in kwargs: | |||||
| user_dict = kwargs.get("user_dict") | |||||
| if user_dict is None: | |||||
| raise ValueError("user_dict is not provided.") | |||||
| kwargs["user_dict"] = user_dict | |||||
| return method(self, **kwargs) | |||||
| parse_user_args(method, *args, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -244,69 +158,39 @@ def check_from_dataset(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] | |||||
| if "dataset" in kwargs: | |||||
| dataset = kwargs.get("dataset") | |||||
| if "columns" in kwargs: | |||||
| columns = kwargs.get("columns") | |||||
| if "freq_range" in kwargs: | |||||
| freq_range = kwargs.get("freq_range") | |||||
| if "top_k" in kwargs: | |||||
| top_k = kwargs.get("top_k") | |||||
| if "special_tokens" in kwargs: | |||||
| special_tokens = kwargs.get("special_tokens") | |||||
| if "special_first" in kwargs: | |||||
| special_first = kwargs.get("special_first") | |||||
| if columns is None: | |||||
| columns = [] | |||||
| if not isinstance(columns, list): | |||||
| columns = [columns] | |||||
| for column in columns: | |||||
| if not isinstance(column, str): | |||||
| raise ValueError("columns need to be a list of strings.") | |||||
| if freq_range is None: | |||||
| freq_range = (None, None) | |||||
| if not isinstance(freq_range, tuple) or len(freq_range) != 2: | |||||
| raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||||
| [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args, | |||||
| **kwargs) | |||||
| if columns is not None: | |||||
| if not isinstance(columns, list): | |||||
| columns = [columns] | |||||
| col_names = ["col_{0}".format(i) for i in range(len(columns))] | |||||
| type_check_list(columns, (str,), col_names) | |||||
| for num in freq_range: | |||||
| if num is not None and (not isinstance(num, int)): | |||||
| raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||||
| if freq_range is not None: | |||||
| type_check(freq_range, (tuple,), "freq_range") | |||||
| if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): | |||||
| if freq_range[0] > freq_range[1] or freq_range[0] < 0: | |||||
| raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") | |||||
| if len(freq_range) != 2: | |||||
| raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.") | |||||
| if top_k is not None and (not isinstance(top_k, int)): | |||||
| raise ValueError("top_k needs to be a positive integer.") | |||||
| for num in freq_range: | |||||
| if num is not None and (not isinstance(num, int)): | |||||
| raise ValueError( | |||||
| "freq_range needs to be either None or a tuple of 2 integers or an int and a None.") | |||||
| if isinstance(top_k, int) and top_k <= 0: | |||||
| raise ValueError("top_k needs to be a positive integer.") | |||||
| if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): | |||||
| if freq_range[0] > freq_range[1] or freq_range[0] < 0: | |||||
| raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") | |||||
| if special_first is None: | |||||
| special_first = True | |||||
| type_check(top_k, (int, type(None)), "top_k") | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| if isinstance(top_k, int): | |||||
| check_value(top_k, (0, INT32_MAX), "top_k") | |||||
| type_check(special_first, (bool,), "special_first") | |||||
| if not isinstance(special_first, bool): | |||||
| raise ValueError("special_first needs to be a boolean value.") | |||||
| if special_tokens is not None: | |||||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| check_unique_list_of_words(special_tokens, "special_tokens") | |||||
| kwargs["dataset"] = dataset | |||||
| kwargs["columns"] = columns | |||||
| kwargs["freq_range"] = freq_range | |||||
| kwargs["top_k"] = top_k | |||||
| kwargs["special_tokens"] = special_tokens | |||||
| kwargs["special_first"] = special_first | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -316,15 +200,7 @@ def check_ngram(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4] | |||||
| if "n" in kwargs: | |||||
| n = kwargs.get("n") | |||||
| if "left_pad" in kwargs: | |||||
| left_pad = kwargs.get("left_pad") | |||||
| if "right_pad" in kwargs: | |||||
| right_pad = kwargs.get("right_pad") | |||||
| if "separator" in kwargs: | |||||
| separator = kwargs.get("separator") | |||||
| [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs) | |||||
| if isinstance(n, int): | if isinstance(n, int): | ||||
| n = [n] | n = [n] | ||||
| @@ -332,15 +208,9 @@ def check_ngram(method): | |||||
| if not (isinstance(n, list) and n != []): | if not (isinstance(n, list) and n != []): | ||||
| raise ValueError("n needs to be a non-empty list of positive integers.") | raise ValueError("n needs to be a non-empty list of positive integers.") | ||||
| for gram in n: | |||||
| if not (isinstance(gram, int) and gram > 0): | |||||
| raise ValueError("n in ngram needs to be a positive number.") | |||||
| if left_pad is None: | |||||
| left_pad = ("", 0) | |||||
| if right_pad is None: | |||||
| right_pad = ("", 0) | |||||
| for i, gram in enumerate(n): | |||||
| type_check(gram, (int,), "gram[{0}]".format(i)) | |||||
| check_value(gram, (0, INT32_MAX), "gram_{}".format(i)) | |||||
| if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( | if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( | ||||
| left_pad[1], int)): | left_pad[1], int)): | ||||
| @@ -353,11 +223,7 @@ def check_ngram(method): | |||||
| if not (left_pad[1] >= 0 and right_pad[1] >= 0): | if not (left_pad[1] >= 0 and right_pad[1] >= 0): | ||||
| raise ValueError("padding width need to be positive numbers.") | raise ValueError("padding width need to be positive numbers.") | ||||
| if separator is None: | |||||
| separator = " " | |||||
| if not isinstance(separator, str): | |||||
| raise ValueError("separator needs to be a string.") | |||||
| type_check(separator, (str,), "separator") | |||||
| kwargs["n"] = n | kwargs["n"] = n | ||||
| kwargs["left_pad"] = left_pad | kwargs["left_pad"] = left_pad | ||||
| @@ -374,16 +240,8 @@ def check_pair_truncate(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| max_length = (list(args) + [None])[0] | |||||
| if "max_length" in kwargs: | |||||
| max_length = kwargs.get("max_length") | |||||
| if max_length is None: | |||||
| raise ValueError("max_length is not provided.") | |||||
| check_pos_int64(max_length) | |||||
| kwargs["max_length"] = max_length | |||||
| return method(self, **kwargs) | |||||
| parse_user_args(method, *args, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -393,22 +251,13 @@ def check_to_number(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| data_type = (list(args) + [None])[0] | |||||
| if "data_type" in kwargs: | |||||
| data_type = kwargs.get("data_type") | |||||
| if data_type is None: | |||||
| raise ValueError("data_type is a mandatory parameter but was not provided.") | |||||
| if not isinstance(data_type, typing.Type): | |||||
| raise TypeError("data_type is not a MindSpore data type.") | |||||
| [data_type], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(data_type, (typing.Type,), "data_type") | |||||
| if data_type not in mstype.number_type: | if data_type not in mstype.number_type: | ||||
| raise TypeError("data_type is not numeric data type.") | raise TypeError("data_type is not numeric data type.") | ||||
| kwargs["data_type"] = data_type | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -418,18 +267,11 @@ def check_python_tokenizer(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| tokenizer = (list(args) + [None])[0] | |||||
| if "tokenizer" in kwargs: | |||||
| tokenizer = kwargs.get("tokenizer") | |||||
| if tokenizer is None: | |||||
| raise ValueError("tokenizer is a mandatory parameter.") | |||||
| [tokenizer], _ = parse_user_args(method, *args, **kwargs) | |||||
| if not callable(tokenizer): | if not callable(tokenizer): | ||||
| raise TypeError("tokenizer is not a callable python function") | raise TypeError("tokenizer is not a callable python function") | ||||
| kwargs["tokenizer"] = tokenizer | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -18,6 +18,7 @@ from functools import wraps | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive | |||||
| # POS_INT_MIN is used to limit values from starting from 0 | # POS_INT_MIN is used to limit values from starting from 0 | ||||
| POS_INT_MIN = 1 | POS_INT_MIN = 1 | ||||
| @@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992 | |||||
| DOUBLE_MIN_INTEGER = -9007199254740992 | DOUBLE_MIN_INTEGER = -9007199254740992 | ||||
| def check_type(value, valid_type): | |||||
| if not isinstance(value, valid_type): | |||||
| raise ValueError("Wrong input type") | |||||
| def check_value(value, valid_range): | |||||
| if value < valid_range[0] or value > valid_range[1]: | |||||
| raise ValueError("Input is not within the required range") | |||||
| def check_range(values, valid_range): | |||||
| if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: | |||||
| raise ValueError("Input range is not valid") | |||||
| def check_positive(value): | |||||
| if value <= 0: | |||||
| raise ValueError("Input must greater than 0") | |||||
| def check_positive_float(value, valid_max=None): | |||||
| if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max): | |||||
| raise ValueError("Input need to be a valid positive float.") | |||||
| def check_bool(value): | |||||
| if not isinstance(value, bool): | |||||
| raise ValueError("Value needs to be a boolean.") | |||||
| def check_2tuple(value): | |||||
| if not (isinstance(value, tuple) and len(value) == 2): | |||||
| raise ValueError("Value needs to be a 2-tuple.") | |||||
| def check_list(value): | |||||
| if not isinstance(value, list): | |||||
| raise ValueError("The input needs to be a list.") | |||||
| def check_uint8(value): | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("The input needs to be a integer") | |||||
| check_value(value, [UINT8_MIN, UINT8_MAX]) | |||||
| def check_uint32(value): | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("The input needs to be a integer") | |||||
| check_value(value, [UINT32_MIN, UINT32_MAX]) | |||||
| def check_pos_int32(value): | |||||
| """Checks for int values starting from 1""" | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("The input needs to be a integer") | |||||
| check_value(value, [POS_INT_MIN, INT32_MAX]) | |||||
| def check_uint64(value): | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("The input needs to be a integer") | |||||
| check_value(value, [UINT64_MIN, UINT64_MAX]) | |||||
| def check_pos_int64(value): | |||||
| if not isinstance(value, int): | |||||
| raise ValueError("The input needs to be a integer") | |||||
| check_value(value, [UINT64_MIN, INT64_MAX]) | |||||
| def check_fill_value(method): | |||||
| """Wrapper method to check the parameters of fill_value.""" | |||||
| def check_pos_float32(value): | |||||
| check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER]) | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [fill_value], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(fill_value, (str, float, bool, int, bytes), "fill_value") | |||||
| return method(self, *args, **kwargs) | |||||
| def check_pos_float64(value): | |||||
| check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER]) | |||||
| return new_method | |||||
| def check_one_hot_op(method): | def check_one_hot_op(method): | ||||
| """Wrapper method to check the parameters of one hot op.""" | |||||
| """Wrapper method to check the parameters of one_hot_op.""" | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| args = (list(args) + 2 * [None])[:2] | |||||
| num_classes, smoothing_rate = args | |||||
| if "num_classes" in kwargs: | |||||
| num_classes = kwargs.get("num_classes") | |||||
| if "smoothing_rate" in kwargs: | |||||
| smoothing_rate = kwargs.get("smoothing_rate") | |||||
| if num_classes is None: | |||||
| raise ValueError("num_classes") | |||||
| check_pos_int32(num_classes) | |||||
| kwargs["num_classes"] = num_classes | |||||
| [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(num_classes, (int,), "num_classes") | |||||
| check_positive(num_classes) | |||||
| if smoothing_rate is not None: | if smoothing_rate is not None: | ||||
| check_value(smoothing_rate, [0., 1.]) | |||||
| kwargs["smoothing_rate"] = smoothing_rate | |||||
| check_value(smoothing_rate, [0., 1.], "smoothing_rate") | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -146,35 +74,12 @@ def check_num_classes(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| num_classes = (list(args) + [None])[0] | |||||
| if "num_classes" in kwargs: | |||||
| num_classes = kwargs.get("num_classes") | |||||
| if num_classes is None: | |||||
| raise ValueError("num_classes is not provided.") | |||||
| check_pos_int32(num_classes) | |||||
| kwargs["num_classes"] = num_classes | |||||
| return method(self, **kwargs) | |||||
| return new_method | |||||
| [num_classes], _ = parse_user_args(method, *args, **kwargs) | |||||
| def check_fill_value(method): | |||||
| """Wrapper method to check the parameters of fill value.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| fill_value = (list(args) + [None])[0] | |||||
| if "fill_value" in kwargs: | |||||
| fill_value = kwargs.get("fill_value") | |||||
| if fill_value is None: | |||||
| raise ValueError("fill_value is not provided.") | |||||
| if not isinstance(fill_value, (str, float, bool, int, bytes)): | |||||
| raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int") | |||||
| kwargs["fill_value"] = fill_value | |||||
| type_check(num_classes, (int,), "num_classes") | |||||
| check_positive(num_classes) | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -184,17 +89,11 @@ def check_de_type(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| data_type = (list(args) + [None])[0] | |||||
| if "data_type" in kwargs: | |||||
| data_type = kwargs.get("data_type") | |||||
| [data_type], _ = parse_user_args(method, *args, **kwargs) | |||||
| if data_type is None: | |||||
| raise ValueError("data_type is not provided.") | |||||
| if not isinstance(data_type, typing.Type): | |||||
| raise TypeError("data_type is not a MindSpore data type.") | |||||
| kwargs["data_type"] = data_type | |||||
| type_check(data_type, (typing.Type,), "data_type") | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -204,13 +103,11 @@ def check_slice_op(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args): | def new_method(self, *args): | ||||
| for i, arg in enumerate(args): | |||||
| if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): | |||||
| raise TypeError("Indexing of dim " + str(i) + "is not of valid type") | |||||
| for _, arg in enumerate(args): | |||||
| type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg") | |||||
| if isinstance(arg, list): | if isinstance(arg, list): | ||||
| for a in arg: | for a in arg: | ||||
| if not isinstance(a, int): | |||||
| raise TypeError("Index " + a + " is not an int") | |||||
| type_check(a, (int,), "a") | |||||
| return method(self, *args) | return method(self, *args) | ||||
| return new_method | return new_method | ||||
| @@ -221,36 +118,14 @@ def check_mask_op(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| operator, constant, dtype = (list(args) + 3 * [None])[:3] | |||||
| if "operator" in kwargs: | |||||
| operator = kwargs.get("operator") | |||||
| if "constant" in kwargs: | |||||
| constant = kwargs.get("constant") | |||||
| if "dtype" in kwargs: | |||||
| dtype = kwargs.get("dtype") | |||||
| if operator is None: | |||||
| raise ValueError("operator is not provided.") | |||||
| if constant is None: | |||||
| raise ValueError("constant is not provided.") | |||||
| [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs) | |||||
| from .c_transforms import Relational | from .c_transforms import Relational | ||||
| if not isinstance(operator, Relational): | |||||
| raise TypeError("operator is not a Relational operator enum.") | |||||
| type_check(operator, (Relational,), "operator") | |||||
| type_check(constant, (str, float, bool, int, bytes), "constant") | |||||
| type_check(dtype, (typing.Type,), "dtype") | |||||
| if not isinstance(constant, (str, float, bool, int, bytes)): | |||||
| raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") | |||||
| if dtype is not None: | |||||
| if not isinstance(dtype, typing.Type): | |||||
| raise TypeError("dtype is not a MindSpore data type.") | |||||
| kwargs["dtype"] = dtype | |||||
| kwargs["operator"] = operator | |||||
| kwargs["constant"] = constant | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -260,22 +135,12 @@ def check_pad_end(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| pad_shape, pad_value = (list(args) + 2 * [None])[:2] | |||||
| if "pad_shape" in kwargs: | |||||
| pad_shape = kwargs.get("pad_shape") | |||||
| if "pad_value" in kwargs: | |||||
| pad_value = kwargs.get("pad_value") | |||||
| if pad_shape is None: | |||||
| raise ValueError("pad_shape is not provided.") | |||||
| [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs) | |||||
| if pad_value is not None: | if pad_value is not None: | ||||
| if not isinstance(pad_value, (str, float, bool, int, bytes)): | |||||
| raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") | |||||
| kwargs["pad_value"] = pad_value | |||||
| if not isinstance(pad_shape, list): | |||||
| raise TypeError("pad_shape must be a list") | |||||
| type_check(pad_value, (str, float, bool, int, bytes), "pad_value") | |||||
| type_check(pad_shape, (list,), "pad_end") | |||||
| for dim in pad_shape: | for dim in pad_shape: | ||||
| if dim is not None: | if dim is not None: | ||||
| @@ -284,9 +149,7 @@ def check_pad_end(method): | |||||
| else: | else: | ||||
| raise TypeError("a value in the list is not an integer.") | raise TypeError("a value in the list is not an integer.") | ||||
| kwargs["pad_shape"] = pad_shape | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -296,31 +159,24 @@ def check_concat_type(method): | |||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| axis, prepend, append = (list(args) + 3 * [None])[:3] | |||||
| if "prepend" in kwargs: | |||||
| prepend = kwargs.get("prepend") | |||||
| if "append" in kwargs: | |||||
| append = kwargs.get("append") | |||||
| if "axis" in kwargs: | |||||
| axis = kwargs.get("axis") | |||||
| [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs) | |||||
| if axis is not None: | if axis is not None: | ||||
| if not isinstance(axis, int): | |||||
| raise TypeError("axis type is not valid, must be an integer.") | |||||
| type_check(axis, (int,), "axis") | |||||
| if axis not in (0, -1): | if axis not in (0, -1): | ||||
| raise ValueError("only 1D concatenation supported.") | raise ValueError("only 1D concatenation supported.") | ||||
| kwargs["axis"] = axis | |||||
| if prepend is not None: | if prepend is not None: | ||||
| if not isinstance(prepend, (type(None), np.ndarray)): | |||||
| raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") | |||||
| kwargs["prepend"] = prepend | |||||
| type_check(prepend, (np.ndarray,), "prepend") | |||||
| if len(prepend.shape) != 1: | |||||
| raise ValueError("can only prepend 1D arrays.") | |||||
| if append is not None: | if append is not None: | ||||
| if not isinstance(append, (type(None), np.ndarray)): | |||||
| raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") | |||||
| kwargs["append"] = append | |||||
| type_check(append, (np.ndarray,), "append") | |||||
| if len(append.shape) != 1: | |||||
| raise ValueError("can only append 1D arrays.") | |||||
| return method(self, **kwargs) | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | return new_method | ||||
| @@ -40,12 +40,14 @@ Examples: | |||||
| >>> dataset = dataset.map(input_columns="image", operations=transforms_list) | >>> dataset = dataset.map(input_columns="image", operations=transforms_list) | ||||
| >>> dataset = dataset.map(input_columns="label", operations=onehot_op) | >>> dataset = dataset.map(input_columns="label", operations=onehot_op) | ||||
| """ | """ | ||||
| import numbers | |||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from .utils import Inter, Border | from .utils import Inter, Border | ||||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | ||||
| check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ | |||||
| check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp | |||||
| check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ | |||||
| check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ | |||||
| FLOAT_MAX_INTEGER | |||||
| DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, | ||||
| Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, | ||||
| @@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, | |||||
| Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} | Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} | ||||
| def parse_padding(padding): | |||||
| if isinstance(padding, numbers.Number): | |||||
| padding = [padding] * 4 | |||||
| if len(padding) == 2: | |||||
| left = right = padding[0] | |||||
| top = bottom = padding[1] | |||||
| padding = (left, top, right, bottom,) | |||||
| if isinstance(padding, list): | |||||
| padding = tuple(padding) | |||||
| return padding | |||||
| class Decode(cde.DecodeOp): | class Decode(cde.DecodeOp): | ||||
| """ | """ | ||||
| Decode the input image in RGB mode. | Decode the input image in RGB mode. | ||||
| @@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp): | |||||
| @check_random_crop | @check_random_crop | ||||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | ||||
| self.size = size | |||||
| self.padding = padding | |||||
| self.pad_if_needed = pad_if_needed | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode.value | |||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| if padding is None: | if padding is None: | ||||
| padding = (0, 0, 0, 0) | padding = (0, 0, 0, 0) | ||||
| else: | |||||
| padding = parse_padding(padding) | |||||
| if isinstance(fill_value, int): # temporary fix | if isinstance(fill_value, int): # temporary fix | ||||
| fill_value = tuple([fill_value] * 3) | fill_value = tuple([fill_value] * 3) | ||||
| border_type = DE_C_BORDER_TYPE[padding_mode] | border_type = DE_C_BORDER_TYPE[padding_mode] | ||||
| self.size = size | |||||
| self.padding = padding | |||||
| self.pad_if_needed = pad_if_needed | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode.value | |||||
| super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | ||||
| @@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): | |||||
| @check_random_crop | @check_random_crop | ||||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | ||||
| self.size = size | |||||
| self.padding = padding | |||||
| self.pad_if_needed = pad_if_needed | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode.value | |||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| if padding is None: | if padding is None: | ||||
| padding = (0, 0, 0, 0) | padding = (0, 0, 0, 0) | ||||
| else: | |||||
| padding = parse_padding(padding) | |||||
| if isinstance(fill_value, int): # temporary fix | if isinstance(fill_value, int): # temporary fix | ||||
| fill_value = tuple([fill_value] * 3) | fill_value = tuple([fill_value] * 3) | ||||
| border_type = DE_C_BORDER_TYPE[padding_mode] | border_type = DE_C_BORDER_TYPE[padding_mode] | ||||
| self.size = size | |||||
| self.padding = padding | |||||
| self.pad_if_needed = pad_if_needed | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode.value | |||||
| super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) | ||||
| @@ -292,6 +319,8 @@ class Resize(cde.ResizeOp): | |||||
| @check_resize_interpolation | @check_resize_interpolation | ||||
| def __init__(self, size, interpolation=Inter.LINEAR): | def __init__(self, size, interpolation=Inter.LINEAR): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| self.interpolation = interpolation | self.interpolation = interpolation | ||||
| interpoltn = DE_C_INTER_MODE[interpolation] | interpoltn = DE_C_INTER_MODE[interpolation] | ||||
| @@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): | |||||
| @check_random_resize_crop | @check_random_resize_crop | ||||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | ||||
| interpolation=Inter.BILINEAR, max_attempts=10): | interpolation=Inter.BILINEAR, max_attempts=10): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| self.scale = scale | self.scale = scale | ||||
| self.ratio = ratio | self.ratio = ratio | ||||
| @@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp): | |||||
| @check_random_resize_crop | @check_random_resize_crop | ||||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | ||||
| interpolation=Inter.BILINEAR, max_attempts=10): | interpolation=Inter.BILINEAR, max_attempts=10): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| self.scale = scale | self.scale = scale | ||||
| self.ratio = ratio | self.ratio = ratio | ||||
| @@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp): | |||||
| @check_crop | @check_crop | ||||
| def __init__(self, size): | def __init__(self, size): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| super().__init__(*size) | super().__init__(*size) | ||||
| @@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp): | |||||
| @check_random_color_adjust | @check_random_color_adjust | ||||
| def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)): | def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)): | ||||
| brightness = self.expand_values(brightness) | |||||
| contrast = self.expand_values(contrast) | |||||
| saturation = self.expand_values(saturation) | |||||
| hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False) | |||||
| self.brightness = brightness | self.brightness = brightness | ||||
| self.contrast = contrast | self.contrast = contrast | ||||
| self.saturation = saturation | self.saturation = saturation | ||||
| self.hue = hue | self.hue = hue | ||||
| super().__init__(*brightness, *contrast, *saturation, *hue) | super().__init__(*brightness, *contrast, *saturation, *hue) | ||||
| def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): | |||||
| if isinstance(value, numbers.Number): | |||||
| value = [center - value, center + value] | |||||
| if non_negative: | |||||
| value[0] = max(0, value[0]) | |||||
| check_range(value, bound) | |||||
| return (value[0], value[1]) | |||||
| class RandomRotation(cde.RandomRotationOp): | class RandomRotation(cde.RandomRotationOp): | ||||
| """ | """ | ||||
| @@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp): | |||||
| self.expand = expand | self.expand = expand | ||||
| self.center = center | self.center = center | ||||
| self.fill_value = fill_value | self.fill_value = fill_value | ||||
| if isinstance(degrees, numbers.Number): | |||||
| degrees = (-degrees, degrees) | |||||
| if center is None: | if center is None: | ||||
| center = (-1, -1) | center = (-1, -1) | ||||
| if isinstance(fill_value, int): # temporary fix | if isinstance(fill_value, int): # temporary fix | ||||
| @@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): | |||||
| @check_random_resize_crop | @check_random_resize_crop | ||||
| def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), | ||||
| interpolation=Inter.BILINEAR, max_attempts=10): | interpolation=Inter.BILINEAR, max_attempts=10): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| self.scale = scale | self.scale = scale | ||||
| self.ratio = ratio | self.ratio = ratio | ||||
| @@ -623,12 +676,14 @@ class Pad(cde.PadOp): | |||||
| @check_pad | @check_pad | ||||
| def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | ||||
| self.padding = padding | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode | |||||
| padding = parse_padding(padding) | |||||
| if isinstance(fill_value, int): # temporary fix | if isinstance(fill_value, int): # temporary fix | ||||
| fill_value = tuple([fill_value] * 3) | fill_value = tuple([fill_value] * 3) | ||||
| padding_mode = DE_C_BORDER_TYPE[padding_mode] | padding_mode = DE_C_BORDER_TYPE[padding_mode] | ||||
| self.padding = padding | |||||
| self.fill_value = fill_value | |||||
| self.padding_mode = padding_mode | |||||
| super().__init__(*padding, padding_mode, *fill_value) | super().__init__(*padding, padding_mode, *fill_value) | ||||
| @@ -28,6 +28,7 @@ import numpy as np | |||||
| from PIL import Image | from PIL import Image | ||||
| from . import py_transforms_util as util | from . import py_transforms_util as util | ||||
| from .c_transforms import parse_padding | |||||
| from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ | ||||
| check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ | ||||
| check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ | check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ | ||||
| @@ -295,6 +296,10 @@ class RandomCrop: | |||||
| @check_random_crop | @check_random_crop | ||||
| def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): | ||||
| if padding is None: | |||||
| padding = (0, 0, 0, 0) | |||||
| else: | |||||
| padding = parse_padding(padding) | |||||
| self.size = size | self.size = size | ||||
| self.padding = padding | self.padding = padding | ||||
| self.pad_if_needed = pad_if_needed | self.pad_if_needed = pad_if_needed | ||||
| @@ -753,6 +758,8 @@ class TenCrop: | |||||
| @check_ten_crop | @check_ten_crop | ||||
| def __init__(self, size, use_vertical_flip=False): | def __init__(self, size, use_vertical_flip=False): | ||||
| if isinstance(size, int): | |||||
| size = (size, size) | |||||
| self.size = size | self.size = size | ||||
| self.use_vertical_flip = use_vertical_flip | self.use_vertical_flip = use_vertical_flip | ||||
| @@ -877,6 +884,8 @@ class Pad: | |||||
| @check_pad | @check_pad | ||||
| def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): | ||||
| parse_padding(padding) | |||||
| self.padding = padding | self.padding = padding | ||||
| self.fill_value = fill_value | self.fill_value = fill_value | ||||
| self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] | self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] | ||||
| @@ -1129,56 +1138,23 @@ class RandomAffine: | |||||
| def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0): | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0): | ||||
| # Parameter checking | # Parameter checking | ||||
| # rotation | # rotation | ||||
| if isinstance(degrees, numbers.Number): | |||||
| if degrees < 0: | |||||
| raise ValueError("If degrees is a single number, it must be positive.") | |||||
| self.degrees = (-degrees, degrees) | |||||
| elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: | |||||
| self.degrees = degrees | |||||
| else: | |||||
| raise TypeError("If degrees is a list or tuple, it must be of length 2.") | |||||
| # translation | |||||
| if translate is not None: | |||||
| if isinstance(translate, (tuple, list)) and len(translate) == 2: | |||||
| for t in translate: | |||||
| if t < 0.0 or t > 1.0: | |||||
| raise ValueError("translation values should be between 0 and 1") | |||||
| else: | |||||
| raise TypeError("translate should be a list or tuple of length 2.") | |||||
| self.translate = translate | |||||
| # scale | |||||
| if scale is not None: | |||||
| if isinstance(scale, (tuple, list)) and len(scale) == 2: | |||||
| for s in scale: | |||||
| if s <= 0: | |||||
| raise ValueError("scale values should be positive") | |||||
| else: | |||||
| raise TypeError("scale should be a list or tuple of length 2.") | |||||
| self.scale_ranges = scale | |||||
| # shear | |||||
| if shear is not None: | if shear is not None: | ||||
| if isinstance(shear, numbers.Number): | if isinstance(shear, numbers.Number): | ||||
| if shear < 0: | |||||
| raise ValueError("If shear is a single number, it must be positive.") | |||||
| self.shear = (-1 * shear, shear) | |||||
| elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): | |||||
| # X-Axis shear with [min, max] | |||||
| shear = (-1 * shear, shear) | |||||
| else: | |||||
| if len(shear) == 2: | if len(shear) == 2: | ||||
| self.shear = [shear[0], shear[1], 0., 0.] | |||||
| shear = [shear[0], shear[1], 0., 0.] | |||||
| elif len(shear) == 4: | elif len(shear) == 4: | ||||
| self.shear = [s for s in shear] | |||||
| else: | |||||
| raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") | |||||
| else: | |||||
| self.shear = shear | |||||
| shear = [s for s in shear] | |||||
| # resample | |||||
| self.resample = DE_PY_INTER_MODE[resample] | |||||
| if isinstance(degrees, numbers.Number): | |||||
| degrees = (-degrees, degrees) | |||||
| # fill_value | |||||
| self.degrees = degrees | |||||
| self.translate = translate | |||||
| self.scale_ranges = scale | |||||
| self.shear = shear | |||||
| self.resample = DE_PY_INTER_MODE[resample] | |||||
| self.fill_value = fill_value | self.fill_value = fill_value | ||||
| def __call__(self, img): | def __call__(self, img): | ||||
| @@ -15,13 +15,15 @@ | |||||
| """ | """ | ||||
| Testing the bounding box augment op in DE | Testing the bounding box augment op in DE | ||||
| """ | """ | ||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.log as logger | import mindspore.log as logger | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | import mindspore.dataset.transforms.vision.c_transforms as c_vision | ||||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||||
| GENERATE_GOLDEN = False | GENERATE_GOLDEN = False | ||||
| # updated VOC dataset with correct annotations | # updated VOC dataset with correct annotations | ||||
| @@ -241,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c(): | |||||
| operations=[test_op]) # Add column for "annotation" | operations=[test_op]) # Add column for "annotation" | ||||
| except ValueError as error: | except ValueError as error: | ||||
| logger.info("Got an exception in DE: {}".format(str(error))) | logger.info("Got an exception in DE: {}".format(str(error))) | ||||
| assert "Input is not" in str(error) | |||||
| assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error) | |||||
| def test_bounding_box_augment_invalid_bounds_c(): | def test_bounding_box_augment_invalid_bounds_c(): | ||||
| @@ -17,6 +17,7 @@ import pytest | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| # generates 1 column [0], [0, 1], ..., [0, ..., n-1] | # generates 1 column [0], [0, 1], ..., [0, ..., n-1] | ||||
| def generate_sequential(n): | def generate_sequential(n): | ||||
| for i in range(n): | for i in range(n): | ||||
| @@ -99,12 +100,12 @@ def test_bucket_batch_invalid_input(): | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | ||||
| None, None, invalid_type_pad_to_bucket_boundary) | None, None, invalid_type_pad_to_bucket_boundary) | ||||
| assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value) | |||||
| assert "Argument pad_to_bucket_boundary with value \"\" is not of type (<class \'bool\'>,)." in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | with pytest.raises(TypeError) as info: | ||||
| _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, | ||||
| None, None, False, invalid_type_drop_remainder) | None, None, False, invalid_type_drop_remainder) | ||||
| assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value) | |||||
| assert "Argument drop_remainder with value \"\" is not of type (<class 'bool'>,)." in str(info.value) | |||||
| def test_bucket_batch_multi_bucket_no_padding(): | def test_bucket_batch_multi_bucket_no_padding(): | ||||
| @@ -272,7 +273,6 @@ def test_bucket_batch_default_pad(): | |||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], | ||||
| [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] | ||||
| output = [] | output = [] | ||||
| for data in dataset.create_dict_iterator(): | for data in dataset.create_dict_iterator(): | ||||
| output.append(data["col1"].tolist()) | output.append(data["col1"].tolist()) | ||||
| @@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis(): | |||||
| def test_concatenate_op_incorrect_input_dim(): | def test_concatenate_op_incorrect_input_dim(): | ||||
| def gen(): | |||||
| yield (np.array(["ss", "ad"], dtype='S'),) | |||||
| prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') | prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') | ||||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||||
| concatenate_op = data_trans.Concatenate(0, prepend_tensor) | |||||
| data = data.map(input_columns=["col"], operations=concatenate_op) | |||||
| with pytest.raises(RuntimeError) as error_info: | |||||
| for _ in data: | |||||
| pass | |||||
| assert "Only 1D tensors supported" in repr(error_info.value) | |||||
| with pytest.raises(ValueError) as error_info: | |||||
| data_trans.Concatenate(0, prepend_tensor) | |||||
| assert "can only prepend 1D arrays." in repr(error_info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -28,9 +28,9 @@ def test_exception_01(): | |||||
| """ | """ | ||||
| logger.info("test_exception_01") | logger.info("test_exception_01") | ||||
| data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) | data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) | ||||
| with pytest.raises(ValueError) as info: | |||||
| data = data.map(input_columns=["image"], operations=vision.Resize(100, 100)) | |||||
| assert "Invalid interpolation mode." in str(info.value) | |||||
| with pytest.raises(TypeError) as info: | |||||
| data.map(input_columns=["image"], operations=vision.Resize(100, 100)) | |||||
| assert "Argument interpolation with value 100 is not of type (<enum 'Inter'>,)" in str(info.value) | |||||
| def test_exception_02(): | def test_exception_02(): | ||||
| @@ -40,8 +40,8 @@ def test_exception_02(): | |||||
| logger.info("test_exception_02") | logger.info("test_exception_02") | ||||
| num_samples = -1 | num_samples = -1 | ||||
| with pytest.raises(ValueError) as info: | with pytest.raises(ValueError) as info: | ||||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||||
| assert "num_samples cannot be less than 0" in str(info.value) | |||||
| ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||||
| assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value) | |||||
| num_samples = 1 | num_samples = 1 | ||||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | ||||
| @@ -23,7 +23,8 @@ import mindspore.dataset.text as text | |||||
| def test_demo_basic_from_dataset(): | def test_demo_basic_from_dataset(): | ||||
| """ this is a tutorial on how from_dataset should be used in a normal use case""" | """ this is a tutorial on how from_dataset should be used in a normal use case""" | ||||
| data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | ||||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["<pad>", "<unk>"], | |||||
| vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, | |||||
| special_tokens=["<pad>", "<unk>"], | |||||
| special_first=True) | special_first=True) | ||||
| data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) | ||||
| res = [] | res = [] | ||||
| @@ -127,15 +128,16 @@ def test_from_dataset_exceptions(): | |||||
| data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) | ||||
| vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) | vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) | ||||
| assert isinstance(vocab.text.Vocab) | assert isinstance(vocab.text.Vocab) | ||||
| except ValueError as e: | |||||
| except (TypeError, ValueError, RuntimeError) as e: | |||||
| assert s in str(e), str(e) | assert s in str(e), str(e) | ||||
| test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers") | |||||
| test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer") | |||||
| test_config(23, (2, 3), 1.2345, "columns need to be a list of strings") | |||||
| test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b") | |||||
| test_config("text", (2, 3), 0, "top_k needs to be a positive integer") | |||||
| test_config([123], (2, 3), 0, "columns need to be a list of strings") | |||||
| test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") | |||||
| test_config("text", (2, 3), 1.2345, | |||||
| "Argument top_k with value 1.2345 is not of type (<class 'int'>, <class 'NoneType'>)") | |||||
| test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (<class 'str'>,)") | |||||
| test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") | |||||
| test_config("text", (2, 3), 0, "top_k needs to be positive number") | |||||
| test_config([123], (2, 3), 0, "top_k needs to be positive number") | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False): | |||||
| if plot: | if plot: | ||||
| visualize_list(image, image_transformed) | visualize_list(image, image_transformed) | ||||
| def test_linear_transformation_md5(): | def test_linear_transformation_md5(): | ||||
| """ | """ | ||||
| Test LinearTransformation op: valid params (transformation_matrix, mean_vector) | Test LinearTransformation op: valid params (transformation_matrix, mean_vector) | ||||
| @@ -102,6 +103,7 @@ def test_linear_transformation_md5(): | |||||
| filename = "linear_transformation_01_result.npz" | filename = "linear_transformation_01_result.npz" | ||||
| save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) | ||||
| def test_linear_transformation_exception_01(): | def test_linear_transformation_exception_01(): | ||||
| """ | """ | ||||
| Test LinearTransformation op: transformation_matrix is not provided | Test LinearTransformation op: transformation_matrix is not provided | ||||
| @@ -126,9 +128,10 @@ def test_linear_transformation_exception_01(): | |||||
| ] | ] | ||||
| transform = py_vision.ComposeOp(transforms) | transform = py_vision.ComposeOp(transforms) | ||||
| data1 = data1.map(input_columns=["image"], operations=transform()) | data1 = data1.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | |||||
| except TypeError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "not provided" in str(e) | |||||
| assert "Argument transformation_matrix with value None is not of type (<class 'numpy.ndarray'>,)" in str(e) | |||||
| def test_linear_transformation_exception_02(): | def test_linear_transformation_exception_02(): | ||||
| """ | """ | ||||
| @@ -154,9 +157,10 @@ def test_linear_transformation_exception_02(): | |||||
| ] | ] | ||||
| transform = py_vision.ComposeOp(transforms) | transform = py_vision.ComposeOp(transforms) | ||||
| data1 = data1.map(input_columns=["image"], operations=transform()) | data1 = data1.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | |||||
| except TypeError as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "not provided" in str(e) | |||||
| assert "Argument mean_vector with value None is not of type (<class 'numpy.ndarray'>,)" in str(e) | |||||
| def test_linear_transformation_exception_03(): | def test_linear_transformation_exception_03(): | ||||
| """ | """ | ||||
| @@ -187,6 +191,7 @@ def test_linear_transformation_exception_03(): | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "square matrix" in str(e) | assert "square matrix" in str(e) | ||||
| def test_linear_transformation_exception_04(): | def test_linear_transformation_exception_04(): | ||||
| """ | """ | ||||
| Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix | Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix | ||||
| @@ -199,7 +204,7 @@ def test_linear_transformation_exception_04(): | |||||
| weight = 50 | weight = 50 | ||||
| dim = 3 * height * weight | dim = 3 * height * weight | ||||
| transformation_matrix = np.ones([dim, dim]) | transformation_matrix = np.ones([dim, dim]) | ||||
| mean_vector = np.zeros(dim-1) | |||||
| mean_vector = np.zeros(dim - 1) | |||||
| # Generate dataset | # Generate dataset | ||||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | ||||
| @@ -216,6 +221,7 @@ def test_linear_transformation_exception_04(): | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "should match" in str(e) | assert "should match" in str(e) | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_linear_transformation_op(plot=True) | test_linear_transformation_op(plot=True) | ||||
| test_linear_transformation_md5() | test_linear_transformation_md5() | ||||
| @@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards(): | |||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||||
| with pytest.raises(Exception) as error_info: | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| def test_minddataset_invalidate_shard_id(): | def test_minddataset_invalidate_shard_id(): | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||||
| with pytest.raises(Exception) as error_info: | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 0).' in repr(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||||
| with pytest.raises(Exception) as error_info: | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||||
| with pytest.raises(Exception, match="shard_id is invalid, "): | |||||
| with pytest.raises(Exception) as error_info: | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(): | for _ in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||
| assert 'Input shard_id is not within the required interval of (0 to 1).' in repr(error_info) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| @@ -15,9 +15,9 @@ | |||||
| """ | """ | ||||
| Testing Ngram in mindspore.dataset | Testing Ngram in mindspore.dataset | ||||
| """ | """ | ||||
| import numpy as np | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.text as text | import mindspore.dataset.text as text | ||||
| import numpy as np | |||||
| def test_multiple_ngrams(): | def test_multiple_ngrams(): | ||||
| @@ -61,7 +61,7 @@ def test_simple_ngram(): | |||||
| yield (np.array(line.split(" "), dtype='S'),) | yield (np.array(line.split(" "), dtype='S'),) | ||||
| dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) | dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) | ||||
| dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None)) | |||||
| dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" ")) | |||||
| i = 0 | i = 0 | ||||
| for data in dataset.create_dict_iterator(): | for data in dataset.create_dict_iterator(): | ||||
| @@ -72,7 +72,7 @@ def test_simple_ngram(): | |||||
| def test_corner_cases(): | def test_corner_cases(): | ||||
| """ testing various corner cases and exceptions""" | """ testing various corner cases and exceptions""" | ||||
| def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None): | |||||
| def test_config(input_line, output_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): | |||||
| def gen(texts): | def gen(texts): | ||||
| yield (np.array(texts.split(" "), dtype='S'),) | yield (np.array(texts.split(" "), dtype='S'),) | ||||
| @@ -93,7 +93,7 @@ def test_corner_cases(): | |||||
| try: | try: | ||||
| test_config("Yours to Discover", "", [0, [1]]) | test_config("Yours to Discover", "", [0, [1]]) | ||||
| except Exception as e: | except Exception as e: | ||||
| assert "ngram needs to be a positive number" in str(e) | |||||
| assert "Argument gram[1] with value [1] is not of type (<class 'int'>,)" in str(e) | |||||
| # test empty n | # test empty n | ||||
| try: | try: | ||||
| test_config("Yours to Discover", "", []) | test_config("Yours to Discover", "", []) | ||||
| @@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py(): | |||||
| _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not within the required range" in str(e) | |||||
| assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) | |||||
| def test_normalize_grayscale_md5_01(): | def test_normalize_grayscale_md5_01(): | ||||
| @@ -61,6 +61,10 @@ def test_pad_end_exceptions(): | |||||
| pad_compare([3, 4, 5], ["2"], 1, []) | pad_compare([3, 4, 5], ["2"], 1, []) | ||||
| assert "a value in the list is not an integer." in str(info.value) | assert "a value in the list is not an integer." in str(info.value) | ||||
| with pytest.raises(TypeError) as info: | |||||
| pad_compare([1, 2], 3, -1, [1, 2, -1]) | |||||
| assert "Argument pad_end with value 3 is not of type (<class 'list'>,)" in str(info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_pad_end_basics() | test_pad_end_basics() | ||||
| @@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees(): | |||||
| _ = py_vision.RandomAffine(degrees=-15) | _ = py_vision.RandomAffine(degrees=-15) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "If degrees is a single number, it cannot be negative." | |||||
| assert str(e) == "Input degrees is not within the required interval of (0 to inf)." | |||||
| def test_random_affine_exception_translation_range(): | def test_random_affine_exception_translation_range(): | ||||
| @@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range(): | |||||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) | _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "translation values should be between 0 and 1" | |||||
| assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)." | |||||
| def test_random_affine_exception_scale_value(): | def test_random_affine_exception_scale_value(): | ||||
| @@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value(): | |||||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) | _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "scale values should be positive" | |||||
| assert str(e) == "Input scale[0] must be greater than 0." | |||||
| def test_random_affine_exception_shear_value(): | def test_random_affine_exception_shear_value(): | ||||
| @@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value(): | |||||
| _ = py_vision.RandomAffine(degrees=15, shear=-5) | _ = py_vision.RandomAffine(degrees=15, shear=-5) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "If shear is a single number, it must be positive." | |||||
| assert str(e) == "Input shear must be greater than 0." | |||||
| def test_random_affine_exception_degrees_size(): | def test_random_affine_exception_degrees_size(): | ||||
| @@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size(): | |||||
| _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) | _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) | ||||
| except TypeError as e: | except TypeError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "translate should be a list or tuple of length 2." | |||||
| assert str( | |||||
| e) == "Argument translate with value 0.1 is not of type (<class 'list'>," \ | |||||
| " <class 'tuple'>)." | |||||
| def test_random_affine_exception_scale_size(): | def test_random_affine_exception_scale_size(): | ||||
| @@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size(): | |||||
| _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) | _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) | ||||
| except TypeError as e: | except TypeError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "scale should be a list or tuple of length 2." | |||||
| assert str(e) == "Argument scale with value 0.5 is not of type (<class 'tuple'>," \ | |||||
| " <class 'list'>)." | |||||
| def test_random_affine_exception_shear_size(): | def test_random_affine_exception_shear_size(): | ||||
| @@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size(): | |||||
| _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) | _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) | ||||
| except TypeError as e: | except TypeError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4." | |||||
| assert str(e) == "shear must be of length 2 or 4." | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -97,7 +97,7 @@ def test_random_color_md5(): | |||||
| data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) | ||||
| transforms = F.ComposeOp([F.Decode(), | transforms = F.ComposeOp([F.Decode(), | ||||
| F.RandomColor((0.5, 1.5)), | |||||
| F.RandomColor((0.1, 1.9)), | |||||
| F.ToTensor()]) | F.ToTensor()]) | ||||
| data = data.map(input_columns="image", operations=transforms()) | data = data.map(input_columns="image", operations=transforms()) | ||||
| @@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c(): | |||||
| data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input range is not valid" in str(e) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||||
| def test_random_crop_and_resize_04_py(): | def test_random_crop_and_resize_04_py(): | ||||
| @@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input range is not valid" in str(e) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||||
| def test_random_crop_and_resize_05_c(): | def test_random_crop_and_resize_05_c(): | ||||
| @@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c(): | |||||
| data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input range is not valid" in str(e) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||||
| def test_random_crop_and_resize_05_py(): | def test_random_crop_and_resize_05_py(): | ||||
| @@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input range is not valid" in str(e) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(e) | |||||
| def test_random_crop_and_resize_comp(plot=False): | def test_random_crop_and_resize_comp(plot=False): | ||||
| @@ -159,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input range is not valid" in str(err) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(err) | |||||
| def test_random_resized_crop_with_bbox_op_invalid2_c(): | def test_random_resized_crop_with_bbox_op_invalid2_c(): | ||||
| @@ -185,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input range is not valid" in str(err) | |||||
| assert "Input is not within the required interval of (0 to 16777216)." in str(err) | |||||
| def test_random_resized_crop_with_bbox_op_bad_c(): | def test_random_resized_crop_with_bbox_op_bad_c(): | ||||
| @@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not within the required range" in str(e) | |||||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_random_grayscale_valid_prob(True) | test_random_grayscale_valid_prob(True) | ||||
| @@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c(): | |||||
| data = data.map(input_columns=["image"], operations=random_horizontal_op) | data = data.map(input_columns=["image"], operations=random_horizontal_op) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not" in str(e) | |||||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||||
| def test_random_horizontal_invalid_prob_py(): | def test_random_horizontal_invalid_prob_py(): | ||||
| @@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not" in str(e) | |||||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) | |||||
| def test_random_horizontal_comp(plot=False): | def test_random_horizontal_comp(plot=False): | ||||
| @@ -190,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): | |||||
| operations=[test_op]) # Add column for "annotation" | operations=[test_op]) # Add column for "annotation" | ||||
| except ValueError as error: | except ValueError as error: | ||||
| logger.info("Got an exception in DE: {}".format(str(error))) | logger.info("Got an exception in DE: {}".format(str(error))) | ||||
| assert "Input is not" in str(error) | |||||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) | |||||
| def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): | def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): | ||||
| @@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range(): | |||||
| _ = py_vision.RandomPerspective(distortion_scale=1.5) | _ = py_vision.RandomPerspective(distortion_scale=1.5) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "Input is not within the required range" | |||||
| assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)." | |||||
| def test_random_perspective_exception_prob_range(): | def test_random_perspective_exception_prob_range(): | ||||
| @@ -119,7 +119,7 @@ def test_random_perspective_exception_prob_range(): | |||||
| _ = py_vision.RandomPerspective(prob=1.2) | _ = py_vision.RandomPerspective(prob=1.2) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert str(e) == "Input is not within the required range" | |||||
| assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)." | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -163,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input is not" in str(err) | |||||
| assert "Input is not within the required interval of (1 to 16777216)." in str(err) | |||||
| try: | try: | ||||
| # one of the size values is zero | # one of the size values is zero | ||||
| @@ -171,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input is not" in str(err) | |||||
| assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err) | |||||
| try: | try: | ||||
| # negative value for resize | # negative value for resize | ||||
| @@ -179,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input is not" in str(err) | |||||
| assert "Input is not within the required interval of (1 to 16777216)." in str(err) | |||||
| try: | try: | ||||
| # invalid input shape | # invalid input shape | ||||
| @@ -97,7 +97,7 @@ def test_random_sharpness_md5(): | |||||
| # define map operations | # define map operations | ||||
| transforms = [ | transforms = [ | ||||
| F.Decode(), | F.Decode(), | ||||
| F.RandomSharpness((0.5, 1.5)), | |||||
| F.RandomSharpness((0.1, 1.9)), | |||||
| F.ToTensor() | F.ToTensor() | ||||
| ] | ] | ||||
| transform = F.ComposeOp(transforms) | transform = F.ComposeOp(transforms) | ||||
| @@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c(): | |||||
| data = data.map(input_columns=["image"], operations=random_horizontal_op) | data = data.map(input_columns=["image"], operations=random_horizontal_op) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not" in str(e) | |||||
| assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) | |||||
| def test_random_vertical_invalid_prob_py(): | def test_random_vertical_invalid_prob_py(): | ||||
| @@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py(): | |||||
| data = data.map(input_columns=["image"], operations=transform()) | data = data.map(input_columns=["image"], operations=transform()) | ||||
| except ValueError as e: | except ValueError as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "Input is not" in str(e) | |||||
| assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) | |||||
| def test_random_vertical_comp(plot=False): | def test_random_vertical_comp(plot=False): | ||||
| @@ -191,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): | |||||
| except ValueError as err: | except ValueError as err: | ||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "Input is not" in str(err) | |||||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err) | |||||
| def test_random_vertical_flip_with_bbox_op_bad_c(): | def test_random_vertical_flip_with_bbox_op_bad_c(): | ||||
| @@ -150,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c(): | |||||
| # invalid interpolation value | # invalid interpolation value | ||||
| c_vision.ResizeWithBBox(400, interpolation="invalid") | c_vision.ResizeWithBBox(400, interpolation="invalid") | ||||
| except ValueError as err: | |||||
| except TypeError as err: | |||||
| logger.info("Got an exception in DE: {}".format(str(err))) | logger.info("Got an exception in DE: {}".format(str(err))) | ||||
| assert "interpolation" in str(err) | assert "interpolation" in str(err) | ||||
| @@ -154,7 +154,7 @@ def test_shuffle_exception_01(): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "buffer_size" in str(e) | |||||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||||
| def test_shuffle_exception_02(): | def test_shuffle_exception_02(): | ||||
| @@ -172,7 +172,7 @@ def test_shuffle_exception_02(): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "buffer_size" in str(e) | |||||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||||
| def test_shuffle_exception_03(): | def test_shuffle_exception_03(): | ||||
| @@ -190,7 +190,7 @@ def test_shuffle_exception_03(): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "buffer_size" in str(e) | |||||
| assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) | |||||
| def test_shuffle_exception_05(): | def test_shuffle_exception_05(): | ||||
| @@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False): | |||||
| logger.info("dtype of image_2: {}".format(image_2.dtype)) | logger.info("dtype of image_2: {}".format(image_2.dtype)) | ||||
| if plot: | if plot: | ||||
| visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) | |||||
| visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) | |||||
| # The output data should be of a 4D tensor shape, a stack of 10 images. | # The output data should be of a 4D tensor shape, a stack of 10 images. | ||||
| assert len(image_2.shape) == 4 | assert len(image_2.shape) == 4 | ||||
| @@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg(): | |||||
| vision.TenCrop(0), | vision.TenCrop(0), | ||||
| lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images | lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images | ||||
| ] | ] | ||||
| error_msg = "Input is not within the required range" | |||||
| error_msg = "Input is not within the required interval of (1 to 16777216)." | |||||
| assert error_msg == str(info.value) | assert error_msg == str(info.value) | ||||
| with pytest.raises(ValueError) as info: | with pytest.raises(ValueError) as info: | ||||
| @@ -169,7 +169,9 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "operations" in str(e) | |||||
| assert "Argument tensor_op_5 with value" \ | |||||
| " <mindspore.dataset.transforms.vision.py_transforms.Invert" in str(e) | |||||
| assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,)" in str(e) | |||||
| def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | ||||
| @@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "num_ops" in str(e) | |||||
| assert "Input num_ops must be greater than 0" in str(e) | |||||
| def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | ||||
| @@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): | |||||
| except Exception as e: | except Exception as e: | ||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
| assert "integer" in str(e) | |||||
| assert "Argument num_ops with value 2.5 is not of type (<class 'int'>,)" in str(e) | |||||
| def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): | def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): | ||||
| @@ -314,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= | |||||
| if len(orig) != len(aug) or not orig: | if len(orig) != len(aug) or not orig: | ||||
| return | return | ||||
| batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together | |||||
| batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together | |||||
| split_point = batch_size * plot_rows | split_point = batch_size * plot_rows | ||||
| orig, aug = np.array(orig), np.array(aug) | orig, aug = np.array(orig), np.array(aug) | ||||
| if len(orig) > plot_rows: | if len(orig) > plot_rows: | ||||
| # Create batches of required size and add remainder to last batch | # Create batches of required size and add remainder to last batch | ||||
| orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added | |||||
| orig = np.split(orig[:split_point], batch_size) + ( | |||||
| [orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added | |||||
| aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) | aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) | ||||
| else: | else: | ||||
| orig = [orig] | orig = [orig] | ||||
| @@ -336,7 +337,8 @@ def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows= | |||||
| for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): | ||||
| cur_ix = base_ix + x | cur_ix = base_ix + x | ||||
| (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row | |||||
| # select plotting axes based on number of image rows on plot - else case when 1 row | |||||
| (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) | |||||
| axA.imshow(dataA["image"]) | axA.imshow(dataA["image"]) | ||||
| add_bounding_boxes(axA, dataA[annot_name]) | add_bounding_boxes(axA, dataA[annot_name]) | ||||