From: @luoyang42 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -846,6 +846,12 @@ class SoftDvppDecodeRandomCropResizeJpeg : public TensorTransform { | |||||
| /// \param[in] size A vector representing the output size of the resized image. | /// \param[in] size A vector representing the output size of the resized image. | ||||
| /// If size is a single value, smaller edge of the image will be resized to this value with | /// If size is a single value, smaller edge of the image will be resized to this value with | ||||
| /// the same image aspect ratio. If size has 2 values, it should be (height, width). | /// the same image aspect ratio. If size has 2 values, it should be (height, width). | ||||
| /// \param[in] scale Range [min, max) of respective size of the original | |||||
| /// size to be cropped (default=(0.08, 1.0)). | |||||
| /// \param[in] ratio Range [min, max) of aspect ratio to be cropped | |||||
| /// (default=(3. / 4., 4. / 3.)). | |||||
| /// \param[in] max_attempts The maximum number of attempts to propose a valid | |||||
| /// crop_area (default=10). If exceeded, fall back to use center_crop instead. | |||||
| SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, | SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0}, | ||||
| std::vector<float> ratio = {3. / 4., 4. / 3.}, int32_t max_attempts = 10); | std::vector<float> ratio = {3. / 4., 4. / 3.}, int32_t max_attempts = 10); | ||||
| @@ -72,6 +72,10 @@ bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel) | |||||
| Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code) { | Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code) { | ||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); | std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); | ||||
| if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Flip: input tensor is not in shape of <H,W,C> or <H,W>."); | |||||
| } | |||||
| std::shared_ptr<CVTensor> output_cv; | std::shared_ptr<CVTensor> output_cv; | ||||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | ||||
| @@ -583,9 +587,13 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||||
| if (!input_cv->mat().data) { | if (!input_cv->mat().data) { | ||||
| RETURN_STATUS_UNEXPECTED("Rotate: load image failed."); | RETURN_STATUS_UNEXPECTED("Rotate: load image failed."); | ||||
| } | } | ||||
| if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Rotate: input tensor is not in shape of <H,W,C> or <H,W>."); | |||||
| } | |||||
| cv::Mat input_img = input_cv->mat(); | cv::Mat input_img = input_cv->mat(); | ||||
| if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) { | if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) { | ||||
| RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center not precise"); | |||||
| RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center is not precise."); | |||||
| } | } | ||||
| // default to center of image | // default to center of image | ||||
| if (fx == -1 && fy == -1) { | if (fx == -1 && fy == -1) { | ||||
| @@ -728,7 +736,7 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||||
| } | } | ||||
| int num_channels = input_cv->shape()[2]; | int num_channels = input_cv->shape()[2]; | ||||
| if (input_cv->Rank() != 3 || num_channels != 3) { | if (input_cv->Rank() != 3 || num_channels != 3) { | ||||
| RETURN_STATUS_UNEXPECTED("AdjustBrightness: image shape is not <H,W,C>."); | |||||
| RETURN_STATUS_UNEXPECTED("AdjustBrightness: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| std::shared_ptr<CVTensor> output_cv; | std::shared_ptr<CVTensor> output_cv; | ||||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | ||||
| @@ -749,7 +757,7 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens | |||||
| } | } | ||||
| int num_channels = input_cv->shape()[2]; | int num_channels = input_cv->shape()[2]; | ||||
| if (input_cv->Rank() != 3 || num_channels != 3) { | if (input_cv->Rank() != 3 || num_channels != 3) { | ||||
| RETURN_STATUS_UNEXPECTED("AdjustContrast: image shape is not <H,W,C>."); | |||||
| RETURN_STATUS_UNEXPECTED("AdjustContrast: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| cv::Mat gray, output_img; | cv::Mat gray, output_img; | ||||
| cv::cvtColor(input_img, gray, CV_RGB2GRAY); | cv::cvtColor(input_img, gray, CV_RGB2GRAY); | ||||
| @@ -854,7 +862,7 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||||
| } | } | ||||
| int num_channels = input_cv->shape()[2]; | int num_channels = input_cv->shape()[2]; | ||||
| if (input_cv->Rank() != 3 || num_channels != 3) { | if (input_cv->Rank() != 3 || num_channels != 3) { | ||||
| RETURN_STATUS_UNEXPECTED("AdjustSaturation: image shape is not <H,W,C>."); | |||||
| RETURN_STATUS_UNEXPECTED("AdjustSaturation: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| std::shared_ptr<CVTensor> output_cv; | std::shared_ptr<CVTensor> output_cv; | ||||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | ||||
| @@ -882,7 +890,7 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * | |||||
| } | } | ||||
| int num_channels = input_cv->shape()[2]; | int num_channels = input_cv->shape()[2]; | ||||
| if (input_cv->Rank() != 3 || num_channels != 3) { | if (input_cv->Rank() != 3 || num_channels != 3) { | ||||
| RETURN_STATUS_UNEXPECTED("AdjustHue: image shape is not <H,W,C>."); | |||||
| RETURN_STATUS_UNEXPECTED("AdjustHue: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| std::shared_ptr<CVTensor> output_cv; | std::shared_ptr<CVTensor> output_cv; | ||||
| RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); | ||||
| @@ -956,7 +964,7 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp | |||||
| RETURN_STATUS_UNEXPECTED("CutOut: load image failed."); | RETURN_STATUS_UNEXPECTED("CutOut: load image failed."); | ||||
| } | } | ||||
| if (input_cv->Rank() != 3 || num_channels != 3) { | if (input_cv->Rank() != 3 || num_channels != 3) { | ||||
| RETURN_STATUS_UNEXPECTED("CutOut: image shape is not <H,W,C> or <H,W>."); | |||||
| RETURN_STATUS_UNEXPECTED("CutOut: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| cv::Mat input_img = input_cv->mat(); | cv::Mat input_img = input_cv->mat(); | ||||
| int32_t image_h = input_cv->shape()[0]; | int32_t image_h = input_cv->shape()[0]; | ||||
| @@ -1016,6 +1024,12 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output | |||||
| try { | try { | ||||
| // input image | // input image | ||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | ||||
| // validate rank | |||||
| if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Pad: input tensor is not in shape of <H,W,C> or <H,W>."); | |||||
| } | |||||
| // get the border type in openCV | // get the border type in openCV | ||||
| auto b_type = GetCVBorderType(border_types); | auto b_type = GetCVBorderType(border_types); | ||||
| // output image | // output image | ||||
| @@ -1106,6 +1120,10 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||||
| InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { | InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { | ||||
| try { | try { | ||||
| std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); | ||||
| if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | |||||
| cv::Mat affine_mat(mat); | cv::Mat affine_mat(mat); | ||||
| affine_mat = affine_mat.reshape(1, {2, 3}); | affine_mat = affine_mat.reshape(1, {2, 3}); | ||||
| @@ -25,8 +25,8 @@ RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_ | |||||
| Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) { | ||||
| IO_CHECK(in, out); | IO_CHECK(in, out); | ||||
| if (in->Rank() != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("RandomColor: image shape is not <H,W,C>."); | |||||
| if (in->Rank() != 3 || in->shape()[2] != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("RandomColor: image shape is not <H,W,C> or channel is not 3."); | |||||
| } | } | ||||
| // 0.5 pixel precision assuming an 8 bit image | // 0.5 pixel precision assuming an 8 bit image | ||||
| const auto eps = 0.00195; | const auto eps = 0.00195; | ||||
| @@ -34,8 +34,8 @@ Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt | |||||
| RETURN_STATUS_UNEXPECTED("Sharpness: load image failed."); | RETURN_STATUS_UNEXPECTED("Sharpness: load image failed."); | ||||
| } | } | ||||
| if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Sharpness: image shape is not <H,W,C> or <H,W>"); | |||||
| if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) { | |||||
| RETURN_STATUS_UNEXPECTED("Sharpness: input tensor is not in shape of <H,W,C> or <H,W>."); | |||||
| } | } | ||||
| /// creating a smoothing filter. 1, 1, 1, | /// creating a smoothing filter. 1, 1, 1, | ||||
| @@ -39,16 +39,6 @@ Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr | |||||
| RETURN_STATUS_UNEXPECTED("Solarize: load image failed."); | RETURN_STATUS_UNEXPECTED("Solarize: load image failed."); | ||||
| } | } | ||||
| if (input_cv->Rank() != 2 && input_cv->Rank() != 3) { | |||||
| RETURN_STATUS_UNEXPECTED("Solarize: image shape is not <H,W,C> or <H,W>."); | |||||
| } | |||||
| if (input_cv->Rank() == 3) { | |||||
| int num_channels = input_cv->shape()[2]; | |||||
| if (num_channels != 3 && num_channels != 1) { | |||||
| RETURN_STATUS_UNEXPECTED("Solarize: image shape is not <H,W,C>."); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<CVTensor> mask_mat_tensor; | std::shared_ptr<CVTensor> mask_mat_tensor; | ||||
| std::shared_ptr<CVTensor> output_cv_tensor; | std::shared_ptr<CVTensor> output_cv_tensor; | ||||
| RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_cv->mat(), &mask_mat_tensor)); | RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_cv->mat(), &mask_mat_tensor)); | ||||
| @@ -160,8 +160,8 @@ Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true)); | |||||
| RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true)); | |||||
| RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[0], {0}, true)); | |||||
| RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[1], {0}, true)); | |||||
| if (ratio[1] < ratio[0]) { | if (ratio[1] < ratio[0]) { | ||||
| std::string err_msg = op_name + ": ratio must be in the format of (min, max)."; | std::string err_msg = op_name + ": ratio must be in the format of (min, max)."; | ||||
| MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio; | MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio; | ||||
| @@ -187,7 +187,7 @@ BoundingBoxAugmentOperation::BoundingBoxAugmentOperation(std::shared_ptr<TensorO | |||||
| Status BoundingBoxAugmentOperation::ValidateParams() { | Status BoundingBoxAugmentOperation::ValidateParams() { | ||||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("BoundingBoxAugment", {transform_})); | RETURN_IF_NOT_OK(ValidateVectorTransforms("BoundingBoxAugment", {transform_})); | ||||
| RETURN_IF_NOT_OK(ValidateProbability("BoundingBoxAugment", ratio_)); | |||||
| RETURN_IF_NOT_OK(ValidateScalar("BoundingBoxAugment", "ratio", ratio_, {0.0, 1.0}, false, false)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1566,7 +1566,8 @@ Status UniformAugOperation::ValidateParams() { | |||||
| // transforms | // transforms | ||||
| RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAug", transforms_)); | RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAug", transforms_)); | ||||
| if (num_ops_ > transforms_.size()) { | if (num_ops_ > transforms_.size()) { | ||||
| std::string err_msg = "UniformAug: num_ops is greater than transforms size, but got: " + std::to_string(num_ops_); | |||||
| std::string err_msg = | |||||
| "UniformAug: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_); | |||||
| MS_LOG(ERROR) << err_msg; | MS_LOG(ERROR) << err_msg; | ||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| @@ -387,5 +387,13 @@ def check_tensor_op(param, param_name): | |||||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | ||||
| def check_c_tensor_op(param, param_name): | |||||
| """check whether param is a tensor op or a callable Python function but not a py_transform""" | |||||
| if callable(param) and getattr(param, 'parse', True): | |||||
| raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name)) | |||||
| if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None): | |||||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | |||||
| def replace_none(value, default): | def replace_none(value, default): | ||||
| return value if value is not None else default | return value if value is not None else default | ||||
| @@ -34,5 +34,5 @@ __all__ = ["CelebADataset", "Cifar100Dataset", "Cifar10Dataset", "CLUEDataset", | |||||
| "GeneratorDataset", "GraphData", "ImageFolderDataset", "ManifestDataset", "MindDataset", "MnistDataset", | "GeneratorDataset", "GraphData", "ImageFolderDataset", "ManifestDataset", "MindDataset", "MnistDataset", | ||||
| "NumpySlicesDataset", "PaddedDataset", "TextFileDataset", "TFRecordDataset", "VOCDataset", | "NumpySlicesDataset", "PaddedDataset", "TextFileDataset", "TFRecordDataset", "VOCDataset", | ||||
| "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", | "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", | ||||
| "WeightedRandomSampler", | |||||
| "WeightedRandomSampler", "SubsetSampler", | |||||
| "config", "DatasetCache", "Schema", "zip"] | "config", "DatasetCache", "Schema", "zip"] | ||||
| @@ -52,8 +52,8 @@ import mindspore.common.dtype as mstype | |||||
| from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType | from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType | ||||
| from .validators import check_lookup, check_jieba_add_dict, \ | from .validators import check_lookup, check_jieba_add_dict, \ | ||||
| check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \ | check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \ | ||||
| check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate, \ | |||||
| check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow | |||||
| check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \ | |||||
| check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow | |||||
| from ..core.datatypes import mstype_to_detype | from ..core.datatypes import mstype_to_detype | ||||
| from ..core.validator_helpers import replace_none | from ..core.validator_helpers import replace_none | ||||
| from ..transforms.c_transforms import TensorOperation | from ..transforms.c_transforms import TensorOperation | ||||
| @@ -756,6 +756,7 @@ if platform.system().lower() != 'windows': | |||||
| >>> text_file_dataset = text_file_dataset.map(operations=replace_op) | >>> text_file_dataset = text_file_dataset.map(operations=replace_op) | ||||
| """ | """ | ||||
| @check_regex_replace | |||||
| def __init__(self, pattern, replace, replace_all=True): | def __init__(self, pattern, replace, replace_all=True): | ||||
| self.pattern = pattern | self.pattern = pattern | ||||
| self.replace = replace | self.replace = replace | ||||
| @@ -216,6 +216,20 @@ def check_wordpiece_tokenizer(method): | |||||
| return new_method | return new_method | ||||
| def check_regex_replace(method): | |||||
| """Wrapper method to check the parameter of RegexReplace.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [pattern, replace, replace_all], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(pattern, (str,), "pattern") | |||||
| type_check(replace, (str,), "replace") | |||||
| type_check(replace_all, (bool,), "replace_all") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
| def check_regex_tokenizer(method): | def check_regex_tokenizer(method): | ||||
| """Wrapper method to check the parameter of RegexTokenizer.""" | """Wrapper method to check the parameter of RegexTokenizer.""" | ||||
| @@ -133,9 +133,9 @@ class _SliceOption(cde.SliceOption): | |||||
| 1. :py:obj:`int`: Slice this index only along the dimension. Negative index is supported. | 1. :py:obj:`int`: Slice this index only along the dimension. Negative index is supported. | ||||
| 2. :py:obj:`list(int)`: Slice these indices along the dimension. Negative indices are supported. | 2. :py:obj:`list(int)`: Slice these indices along the dimension. Negative indices are supported. | ||||
| 3. :py:obj:`slice`: Slice the generated indices from the slice object along the dimension. | 3. :py:obj:`slice`: Slice the generated indices from the slice object along the dimension. | ||||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing. | |||||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing. | |||||
| 6. :py:obj:`boolean`: Slice the whole dimension. Similar to `:` in Python indexing. | |||||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing. | |||||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing. | |||||
| 6. :py:obj:`boolean`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing. | |||||
| """ | """ | ||||
| @check_slice_option | @check_slice_option | ||||
| @@ -165,8 +165,8 @@ class Slice(cde.SliceOp): | |||||
| 2. :py:obj:`list(int)`: Slice these indices along the first dimension. Negative indices are supported. | 2. :py:obj:`list(int)`: Slice these indices along the first dimension. Negative indices are supported. | ||||
| 3. :py:obj:`slice`: Slice the generated indices from the slice object along the first dimension. | 3. :py:obj:`slice`: Slice the generated indices from the slice object along the first dimension. | ||||
| Similar to start:stop:step. | Similar to start:stop:step. | ||||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing. | |||||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing. | |||||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing. | |||||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension, same result with `None`. | |||||
| Examples: | Examples: | ||||
| >>> # Data before | >>> # Data before | ||||
| @@ -271,7 +271,7 @@ class Decode(ImageTensorOperation): | |||||
| img (NumPy), Decoded image. | img (NumPy), Decoded image. | ||||
| """ | """ | ||||
| if not isinstance(img, np.ndarray) or img.ndim != 1 or img.dtype.type is np.str_: | if not isinstance(img, np.ndarray) or img.ndim != 1 or img.dtype.type is np.str_: | ||||
| raise TypeError("Input should be an encoded image with 1-D NumPy type, got {}.".format(type(img))) | |||||
| raise TypeError("Input should be an encoded image in 1-D NumPy format, got {}.".format(type(img))) | |||||
| return super().__call__(img) | return super().__call__(img) | ||||
| def parse(self): | def parse(self): | ||||
| @@ -763,6 +763,14 @@ class RandomCropDecodeResize(ImageTensorOperation): | |||||
| DE_C_INTER_MODE[self.interpolation], | DE_C_INTER_MODE[self.interpolation], | ||||
| self.max_attempts) | self.max_attempts) | ||||
| def __call__(self, img): | |||||
| if not isinstance(img, np.ndarray): | |||||
| raise TypeError("Input should be an encoded image in 1-D NumPy format, got {}.".format(type(img))) | |||||
| if img.ndim != 1 or img.dtype.type is not np.uint8: | |||||
| raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " + | |||||
| "got format:{}, dtype:{}.".format(type(img), img.dtype.type)) | |||||
| super().__call__(img=img) | |||||
| class RandomCropWithBBox(ImageTensorOperation): | class RandomCropWithBBox(ImageTensorOperation): | ||||
| """ | """ | ||||
| @@ -1164,8 +1172,8 @@ class RandomSharpness(ImageTensorOperation): | |||||
| degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image. | degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image. | ||||
| Args: | Args: | ||||
| degrees (tuple, optional): Range of random sharpness adjustment degrees. It should be in (min, max) format. | |||||
| If min=max, then it is a single fixed magnitude operation (default = (0.1, 1.9)). | |||||
| degrees (Union[list, tuple], optional): Range of random sharpness adjustment degrees. It should be in | |||||
| (min, max) format. If min=max, then it is a single fixed magnitude operation (default = (0.1, 1.9)). | |||||
| Raises: | Raises: | ||||
| TypeError : If degrees is not a list or tuple. | TypeError : If degrees is not a list or tuple. | ||||
| @@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation | |||||
| from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ | from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ | ||||
| check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \ | check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \ | ||||
| check_tensor_op, UINT8_MAX, check_value_normalize_std | |||||
| check_c_tensor_op, UINT8_MAX, check_value_normalize_std | |||||
| from .utils import Inter, Border, ImageBatchFormat | from .utils import Inter, Border, ImageBatchFormat | ||||
| @@ -727,7 +727,7 @@ def check_random_select_subpolicy_op(method): | |||||
| raise ValueError("policy[{0}] can not be empty.".format(sub_ind)) | raise ValueError("policy[{0}] can not be empty.".format(sub_ind)) | ||||
| for op_ind, tp in enumerate(sub): | for op_ind, tp in enumerate(sub): | ||||
| check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind)) | check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind)) | ||||
| check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind)) | |||||
| check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind)) | |||||
| check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind)) | check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind)) | ||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| @@ -43,11 +43,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) { | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create objects for the tensor ops | // Create objects for the tensor ops | ||||
| std::shared_ptr<TensorTransform> decode_op = std::make_shared<vision::Decode>(); | |||||
| std::shared_ptr<TensorTransform> random_horizontal_flip_op = std::make_shared<vision::RandomHorizontalFlip>(0.5); | std::shared_ptr<TensorTransform> random_horizontal_flip_op = std::make_shared<vision::RandomHorizontalFlip>(0.5); | ||||
| EXPECT_NE(random_horizontal_flip_op, nullptr); | EXPECT_NE(random_horizontal_flip_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"}); | |||||
| ds = ds->Map({decode_op, random_horizontal_flip_op}, {}, {}, {"image"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -95,14 +95,14 @@ def test_eager_exceptions(): | |||||
| img = C.Decode()(img) | img = C.Decode()(img) | ||||
| assert False | assert False | ||||
| except TypeError as e: | except TypeError as e: | ||||
| assert "Input should be an encoded image with 1-D NumPy type" in str(e) | |||||
| assert "Input should be an encoded image in 1-D NumPy format" in str(e) | |||||
| try: | try: | ||||
| img = np.array(["a", "b", "c"]) | img = np.array(["a", "b", "c"]) | ||||
| img = C.Decode()(img) | img = C.Decode()(img) | ||||
| assert False | assert False | ||||
| except TypeError as e: | except TypeError as e: | ||||
| assert "Input should be an encoded image with 1-D NumPy type" in str(e) | |||||
| assert "Input should be an encoded image in 1-D NumPy format" in str(e) | |||||
| try: | try: | ||||
| img = cv2.imread("../data/dataset/apple.jpg") | img = cv2.imread("../data/dataset/apple.jpg") | ||||
| @@ -239,7 +239,7 @@ def test_random_color_c_errors(): | |||||
| with pytest.raises(RuntimeError) as error_info: | with pytest.raises(RuntimeError) as error_info: | ||||
| for _ in enumerate(mnist_ds): | for _ in enumerate(mnist_ds): | ||||
| pass | pass | ||||
| assert "Invalid number of channels in input image" in str(error_info.value) | |||||
| assert "image shape is not <H,W,C> or channel is not 3" in str(error_info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||