From aa4bf9a3f215f4c453f3a87d6c33dca88295657d Mon Sep 17 00:00:00 2001 From: luoyang Date: Fri, 27 Nov 2020 22:06:26 +0800 Subject: [PATCH] [MD] Support python-eager: Resize, Rescale, Normalize, HWC2CHW, Pad --- .../ccsrc/minddata/dataset/api/CMakeLists.txt | 1 + .../dataset/include/execute_binding.cc | 37 +++++++ .../dataset/api/python/pybind_conversion.cc | 12 +++ .../dataset/api/python/pybind_conversion.h | 2 + mindspore/dataset/engine/datasets.py | 2 +- .../dataset/transforms/py_transforms_util.py | 8 +- mindspore/dataset/vision/c_transforms.py | 89 ++++++++++++++- tests/ut/python/dataset/test_compose.py | 6 +- tests/ut/python/dataset/test_eager_vision.py | 102 ++++++++++++++++++ 9 files changed, 247 insertions(+), 12 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc create mode 100644 tests/ut/python/dataset/test_eager_vision.py diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt index e298c08290..1f213c7620 100644 --- a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt @@ -6,6 +6,7 @@ if (ENABLE_PYTHON) python/pybind_conversion.cc python/bindings/dataset/include/datasets_bindings.cc python/bindings/dataset/include/iterator_bindings.cc + python/bindings/dataset/include/execute_binding.cc python/bindings/dataset/include/schema_bindings.cc python/bindings/dataset/engine/cache/bindings.cc python/bindings/dataset/core/bindings.cc diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc new file mode 100644 index 0000000000..0f9111b652 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pybind11/pybind11.h" + +#include "minddata/dataset/api/python/pybind_conversion.h" +#include "minddata/dataset/api/python/pybind_register.h" +#include "minddata/dataset/include/execute.h" + +namespace mindspore { +namespace dataset { + +PYBIND_REGISTER(Execute, 0, ([](const py::module *m) { + (void)py::class_>(*m, "Execute") + .def(py::init([](py::object operation) { + auto execute = std::make_shared(toTensorOperation(operation)); + return execute; + })) + .def("__call__", [](Execute &self, std::shared_ptr in) { + std::shared_ptr out = self(in); + return out; + }); + })); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc index 3e1b223db1..ad4eb018a4 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.cc @@ -104,6 +104,18 @@ std::vector> toTensorOperations(std::optional

toTensorOperation(py::handle operation) { + std::shared_ptr op; + std::shared_ptr tensor_op; + if (py::isinstance(operation)) { + tensor_op = operation.cast>(); + } else { + THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }()); + } + op = std::make_shared(tensor_op); + return op; +} + std::vector> toDatasetNode(std::shared_ptr self, py::list datasets) { std::vector> vector; vector.push_back(self); diff --git a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.h b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.h index 9a2f03843e..2b6788affa 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.h +++ b/mindspore/ccsrc/minddata/dataset/api/python/pybind_conversion.h @@ -59,6 +59,8 @@ std::vector> toPairVector(const py::list list); std::vector> toTensorOperations(std::optional operations); +std::shared_ptr toTensorOperation(py::handle operation); + std::vector> toDatasetNode(std::shared_ptr self, py::list datasets); std::shared_ptr toSamplerObj(std::optional py_sampler, bool isMindDataset = false); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 2a9e4a4bc7..24ceced228 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2218,7 +2218,7 @@ class MapDataset(Dataset): # wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations new_ops, start_ind, end_ind = [], 0, 0 for i, op in enumerate(operations): - if not callable(op): + if str(op).find("c_transform") >= 0: # reset counts if start_ind != end_ind: new_ops.append(py_transforms.Compose(operations[start_ind:end_ind])) diff --git a/mindspore/dataset/transforms/py_transforms_util.py b/mindspore/dataset/transforms/py_transforms_util.py index ed5eba5eb1..c164fab851 100644 --- a/mindspore/dataset/transforms/py_transforms_util.py +++ b/mindspore/dataset/transforms/py_transforms_util.py @@ -36,11 +36,11 @@ def compose(transforms, *args): Compose a list of transforms and apply on the image. Args: - img (numpy.ndarray): An image in Numpy ndarray. + img (numpy.ndarray): An image in NumPy ndarray. transforms (list): A list of transform Class objects to be composed. Returns: - img (numpy.ndarray), An augmented image in Numpy ndarray. + img (numpy.ndarray), An augmented image in NumPy ndarray. """ if all_numpy(args): for transform in transforms: @@ -49,8 +49,8 @@ def compose(transforms, *args): if all_numpy(args): return args - raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args))) - raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args))) + raise TypeError('args should be NumPy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args))) + raise TypeError('args should be NumPy ndarray. Got {}.'.format(type(args))) def one_hot_encoding(label, num_classes, epsilon): diff --git a/mindspore/dataset/vision/c_transforms.py b/mindspore/dataset/vision/c_transforms.py index 6a61b9d1cf..c2b822c9ab 100644 --- a/mindspore/dataset/vision/c_transforms.py +++ b/mindspore/dataset/vision/c_transforms.py @@ -44,6 +44,8 @@ Examples: >>> data1 = data1.map(operations=onehot_op, input_columns="label") """ import numbers +import numpy as np +from PIL import Image import mindspore._c_dataengine as cde from .utils import Inter, Border, ImageBatchFormat @@ -280,6 +282,22 @@ class Normalize(cde.NormalizeOp): self.std = std super().__init__(*mean, *std) + def __call__(self, img): + """ + Call method. + + Args: + img (NumPy or PIL image): Image array to be normalized. + + Returns: + img (NumPy), Normalized Image array. + """ + if not isinstance(img, (np.ndarray, Image.Image)): + raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) + normalize = cde.Execute(cde.NormalizeOp(*self.mean, *self.std)) + img = normalize(cde.Tensor(np.asarray(img))) + return img.as_array() + class RandomAffine(cde.RandomAffineOp): """ @@ -676,13 +694,29 @@ class Resize(cde.ResizeOp): @check_resize_interpolation def __init__(self, size, interpolation=Inter.LINEAR): + if isinstance(size, int): + size = (size, 0) self.size = size self.interpolation = interpolation interpoltn = DE_C_INTER_MODE[interpolation] - if isinstance(size, int): - size = (size, 0) super().__init__(*size, interpoltn) + def __call__(self, img): + """ + Call method. + + Args: + img (NumPy or PIL image): Image to be resized. + + Returns: + img (NumPy), Resized image. + """ + if not isinstance(img, (np.ndarray, Image.Image)): + raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) + resize = cde.Execute(cde.ResizeOp(*self.size, DE_C_INTER_MODE[self.interpolation])) + img = resize(cde.Tensor(np.asarray(img))) + return img.as_array() + class ResizeWithBBox(cde.ResizeWithBBoxOp): """ @@ -995,6 +1029,22 @@ class Rescale(cde.RescaleOp): self.shift = shift super().__init__(rescale, shift) + def __call__(self, img): + """ + Call method. + + Args: + img (NumPy or PIL image): Image to be rescaled. + + Returns: + img (NumPy), Rescaled image. + """ + if not isinstance(img, (np.ndarray, Image.Image)): + raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) + rescale = cde.Execute(cde.RescaleOp(self.rescale, self.shift)) + img = rescale(cde.Tensor(np.asarray(img))) + return img.as_array() + class RandomResize(cde.RandomResizeOp): """ @@ -1067,6 +1117,22 @@ class HWC2CHW(cde.ChannelSwapOp): >>> data1 = data1.map(operations=transforms_list, input_columns=["image"]) """ + def __call__(self, img): + """ + Call method. + + Args: + img (NumPy or PIL image): Image array, of shape (H, W, C), to have channels swapped. + + Returns: + img (NumPy), Image array, of shape (C, H, W), with channels swapped. + """ + if not isinstance(img, (np.ndarray, Image.Image)): + raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) + hwc2chw = cde.Execute(cde.ChannelSwapOp()) + img = hwc2chw(cde.Tensor(np.asarray(img))) + return img.as_array() + class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): """ @@ -1156,13 +1222,28 @@ class Pad(cde.PadOp): padding = parse_padding(padding) if isinstance(fill_value, int): fill_value = tuple([fill_value] * 3) - padding_mode = DE_C_BORDER_TYPE[padding_mode] - self.padding = padding self.fill_value = fill_value self.padding_mode = padding_mode + padding_mode = DE_C_BORDER_TYPE[padding_mode] super().__init__(*padding, padding_mode, *fill_value) + def __call__(self, img): + """ + Call method. + + Args: + img (NumPy or PIL image): Image to be padded. + + Returns: + img (NumPy), Padded image. + """ + if not isinstance(img, (np.ndarray, Image.Image)): + raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img))) + pad = cde.Execute(cde.PadOp(*self.padding, DE_C_BORDER_TYPE[self.padding_mode], *self.fill_value)) + img = pad(cde.Tensor(np.asarray(img))) + return img.as_array() + class UniformAugment(cde.UniformAugOp): """ diff --git a/tests/ut/python/dataset/test_compose.py b/tests/ut/python/dataset/test_compose.py index f7650f6ce3..1fa2708e64 100644 --- a/tests/ut/python/dataset/test_compose.py +++ b/tests/ut/python/dataset/test_compose.py @@ -235,15 +235,15 @@ def test_py_transforms_with_c_vision(): return res with pytest.raises(ValueError) as error_info: - test_config(py_transforms.RandomApply([c_vision.Resize(200)])) + test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)])) assert "transforms[0] is not callable." in str(error_info.value) with pytest.raises(ValueError) as error_info: - test_config(py_transforms.RandomChoice([c_vision.Resize(200)])) + test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)])) assert "transforms[0] is not callable." in str(error_info.value) with pytest.raises(ValueError) as error_info: - test_config(py_transforms.RandomOrder([np.array, c_vision.Resize(200)])) + test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)])) assert "transforms[1] is not callable." in str(error_info.value) with pytest.raises(RuntimeError) as error_info: diff --git a/tests/ut/python/dataset/test_eager_vision.py b/tests/ut/python/dataset/test_eager_vision.py new file mode 100644 index 0000000000..f9d6535645 --- /dev/null +++ b/tests/ut/python/dataset/test_eager_vision.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import cv2 +from PIL import Image +import mindspore.dataset.vision.c_transforms as C +from mindspore import log as logger + +def test_eager_resize(): + img = cv2.imread("../data/dataset/apple.jpg") + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + + img = C.Resize(size=(32, 32))(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + + assert img.shape == (32, 32, 3) + +def test_eager_rescale(): + img = cv2.imread("../data/dataset/apple.jpg") + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + pixel = img[0][0][0] + + rescale_factor = 0.5 + img = C.Rescale(rescale=rescale_factor, shift=0)(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + pixel_rescaled = img[0][0][0] + + assert pixel*rescale_factor == pixel_rescaled + +def test_eager_normalize(): + img = Image.open("../data/dataset/apple.jpg").convert("RGB") + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size)) + pixel = img.getpixel((0, 0))[0] + + mean_vec = [100, 100, 100] + std_vec = [2, 2, 2] + img = C.Normalize(mean=mean_vec, std=std_vec)(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + pixel_normalized = img[0][0][0] + + assert (pixel - mean_vec[0]) / std_vec[0] == pixel_normalized + +def test_eager_HWC2CHW(): + img = cv2.imread("../data/dataset/apple.jpg") + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + channel = img.shape + + img = C.HWC2CHW()(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape)) + channel_swaped = img.shape + + assert channel == (channel_swaped[1], channel_swaped[2], channel_swaped[0]) + +def test_eager_pad(): + img = Image.open("../data/dataset/apple.jpg").convert("RGB") + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size)) + + img = C.Resize(size=(32, 32))(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size)) + size = img.shape + + pad = 4 + img = C.Pad(padding=pad)(img) + logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size)) + size_padded = img.shape + + assert size_padded == (size[0] + 2 * pad, size[1] + 2 * pad, size[2]) + +def test_eager_exceptions(): + try: + img = cv2.imread("../data/dataset/apple.jpg") + img = C.Resize(size=(-32, 32))(img) + assert False + except ValueError as e: + assert "not within the required interval" in str(e) + + try: + img = "../data/dataset/apple.jpg" + img = C.Pad(padding=4)(img) + assert False + except TypeError as e: + assert "Input should be NumPy or PIL image" in str(e) + +if __name__ == '__main__': + test_eager_resize() + test_eager_rescale() + test_eager_normalize() + test_eager_HWC2CHW() + test_eager_pad() + test_eager_exceptions() + \ No newline at end of file