| @@ -6,6 +6,7 @@ if (ENABLE_PYTHON) | |||
| python/pybind_conversion.cc | |||
| python/bindings/dataset/include/datasets_bindings.cc | |||
| python/bindings/dataset/include/iterator_bindings.cc | |||
| python/bindings/dataset/include/execute_binding.cc | |||
| python/bindings/dataset/include/schema_bindings.cc | |||
| python/bindings/dataset/engine/cache/bindings.cc | |||
| python/bindings/dataset/core/bindings.cc | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/include/execute.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER(Execute, 0, ([](const py::module *m) { | |||
| (void)py::class_<Execute, std::shared_ptr<Execute>>(*m, "Execute") | |||
| .def(py::init([](py::object operation) { | |||
| auto execute = std::make_shared<Execute>(toTensorOperation(operation)); | |||
| return execute; | |||
| })) | |||
| .def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) { | |||
| std::shared_ptr<Tensor> out = self(in); | |||
| return out; | |||
| }); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -104,6 +104,18 @@ std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<p | |||
| return vector; | |||
| } | |||
| std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation) { | |||
| std::shared_ptr<TensorOperation> op; | |||
| std::shared_ptr<TensorOp> tensor_op; | |||
| if (py::isinstance<TensorOp>(operation)) { | |||
| tensor_op = operation.cast<std::shared_ptr<TensorOp>>(); | |||
| } else { | |||
| THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: input operation is not a tensor_op."); }()); | |||
| } | |||
| op = std::make_shared<transforms::PreBuiltOperation>(tensor_op); | |||
| return op; | |||
| } | |||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) { | |||
| std::vector<std::shared_ptr<DatasetNode>> vector; | |||
| vector.push_back(self); | |||
| @@ -59,6 +59,8 @@ std::vector<std::pair<int, int>> toPairVector(const py::list list); | |||
| std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations); | |||
| std::shared_ptr<TensorOperation> toTensorOperation(py::handle operation); | |||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets); | |||
| std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false); | |||
| @@ -2218,7 +2218,7 @@ class MapDataset(Dataset): | |||
| # wraps adjacent Python operations in a Compose to allow mixing of Python and C++ operations | |||
| new_ops, start_ind, end_ind = [], 0, 0 | |||
| for i, op in enumerate(operations): | |||
| if not callable(op): | |||
| if str(op).find("c_transform") >= 0: | |||
| # reset counts | |||
| if start_ind != end_ind: | |||
| new_ops.append(py_transforms.Compose(operations[start_ind:end_ind])) | |||
| @@ -36,11 +36,11 @@ def compose(transforms, *args): | |||
| Compose a list of transforms and apply on the image. | |||
| Args: | |||
| img (numpy.ndarray): An image in Numpy ndarray. | |||
| img (numpy.ndarray): An image in NumPy ndarray. | |||
| transforms (list): A list of transform Class objects to be composed. | |||
| Returns: | |||
| img (numpy.ndarray), An augmented image in Numpy ndarray. | |||
| img (numpy.ndarray), An augmented image in NumPy ndarray. | |||
| """ | |||
| if all_numpy(args): | |||
| for transform in transforms: | |||
| @@ -49,8 +49,8 @@ def compose(transforms, *args): | |||
| if all_numpy(args): | |||
| return args | |||
| raise TypeError('args should be Numpy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args))) | |||
| raise TypeError('args should be Numpy ndarray. Got {}.'.format(type(args))) | |||
| raise TypeError('args should be NumPy ndarray. Got {}. Append ToTensor() to transforms.'.format(type(args))) | |||
| raise TypeError('args should be NumPy ndarray. Got {}.'.format(type(args))) | |||
| def one_hot_encoding(label, num_classes, epsilon): | |||
| @@ -44,6 +44,8 @@ Examples: | |||
| >>> data1 = data1.map(operations=onehot_op, input_columns="label") | |||
| """ | |||
| import numbers | |||
| import numpy as np | |||
| from PIL import Image | |||
| import mindspore._c_dataengine as cde | |||
| from .utils import Inter, Border, ImageBatchFormat | |||
| @@ -280,6 +282,22 @@ class Normalize(cde.NormalizeOp): | |||
| self.std = std | |||
| super().__init__(*mean, *std) | |||
def __call__(self, img):
    """
    Apply this normalization directly to one image through cde.Execute.

    Args:
        img (NumPy or PIL image): Image array to be normalized.

    Returns:
        img (NumPy), Normalized Image array.

    Raises:
        TypeError: If img is neither a numpy.ndarray nor a PIL image.
    """
    if isinstance(img, (np.ndarray, Image.Image)):
        # A fresh NormalizeOp is built from the mean/std stored at construction time.
        executor = cde.Execute(cde.NormalizeOp(*self.mean, *self.std))
        output = executor(cde.Tensor(np.asarray(img)))
        return output.as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
| class RandomAffine(cde.RandomAffineOp): | |||
| """ | |||
| @@ -676,13 +694,29 @@ class Resize(cde.ResizeOp): | |||
@check_resize_interpolation
def __init__(self, size, interpolation=Inter.LINEAR):
    """
    Store the resize parameters and initialise the underlying C++ op.

    Args:
        size (int or sequence): Target size. An int is expanded to the pair (size, 0)
            before it is stored and passed on.
        interpolation (Inter, optional): Interpolation mode, used as a key into
            DE_C_INTER_MODE (default=Inter.LINEAR).
    """
    # Normalise once, before storing, so that self.size can always be unpacked
    # with * by other methods; a second check after this point would be dead code.
    if isinstance(size, int):
        size = (size, 0)
    self.size = size
    self.interpolation = interpolation
    interpoltn = DE_C_INTER_MODE[interpolation]
    super().__init__(*size, interpoltn)
def __call__(self, img):
    """
    Apply this resize directly to one image through cde.Execute.

    Args:
        img (NumPy or PIL image): Image to be resized.

    Returns:
        img (NumPy), Resized image.

    Raises:
        TypeError: If img is neither a numpy.ndarray nor a PIL image.
    """
    if isinstance(img, (np.ndarray, Image.Image)):
        # The stored interpolation is mapped through DE_C_INTER_MODE at call time.
        executor = cde.Execute(cde.ResizeOp(*self.size, DE_C_INTER_MODE[self.interpolation]))
        output = executor(cde.Tensor(np.asarray(img)))
        return output.as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
| class ResizeWithBBox(cde.ResizeWithBBoxOp): | |||
| """ | |||
| @@ -995,6 +1029,22 @@ class Rescale(cde.RescaleOp): | |||
| self.shift = shift | |||
| super().__init__(rescale, shift) | |||
def __call__(self, img):
    """
    Apply this rescale directly to one image through cde.Execute.

    Args:
        img (NumPy or PIL image): Image to be rescaled.

    Returns:
        img (NumPy), Rescaled image.

    Raises:
        TypeError: If img is neither a numpy.ndarray nor a PIL image.
    """
    if isinstance(img, (np.ndarray, Image.Image)):
        # A fresh RescaleOp is built from the stored rescale factor and shift.
        executor = cde.Execute(cde.RescaleOp(self.rescale, self.shift))
        output = executor(cde.Tensor(np.asarray(img)))
        return output.as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
| class RandomResize(cde.RandomResizeOp): | |||
| """ | |||
| @@ -1067,6 +1117,22 @@ class HWC2CHW(cde.ChannelSwapOp): | |||
| >>> data1 = data1.map(operations=transforms_list, input_columns=["image"]) | |||
| """ | |||
def __call__(self, img):
    """
    Apply the channel swap directly to one image through cde.Execute.

    Args:
        img (NumPy or PIL image): Image array, of shape (H, W, C), to have channels swapped.

    Returns:
        img (NumPy), Image array, of shape (C, H, W), with channels swapped.

    Raises:
        TypeError: If img is neither a numpy.ndarray nor a PIL image.
    """
    if isinstance(img, (np.ndarray, Image.Image)):
        executor = cde.Execute(cde.ChannelSwapOp())
        output = executor(cde.Tensor(np.asarray(img)))
        return output.as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
| class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): | |||
| """ | |||
| @@ -1156,13 +1222,28 @@ class Pad(cde.PadOp): | |||
| padding = parse_padding(padding) | |||
| if isinstance(fill_value, int): | |||
| fill_value = tuple([fill_value] * 3) | |||
| padding_mode = DE_C_BORDER_TYPE[padding_mode] | |||
| self.padding = padding | |||
| self.fill_value = fill_value | |||
| self.padding_mode = padding_mode | |||
| padding_mode = DE_C_BORDER_TYPE[padding_mode] | |||
| super().__init__(*padding, padding_mode, *fill_value) | |||
def __call__(self, img):
    """
    Apply this padding directly to one image through cde.Execute.

    Args:
        img (NumPy or PIL image): Image to be padded.

    Returns:
        img (NumPy), Padded image.

    Raises:
        TypeError: If img is neither a numpy.ndarray nor a PIL image.
    """
    if isinstance(img, (np.ndarray, Image.Image)):
        # self.padding_mode is used as a key into DE_C_BORDER_TYPE here, so it
        # must hold the unmapped border value.
        border_type = DE_C_BORDER_TYPE[self.padding_mode]
        executor = cde.Execute(cde.PadOp(*self.padding, border_type, *self.fill_value))
        output = executor(cde.Tensor(np.asarray(img)))
        return output.as_array()
    raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
| class UniformAugment(cde.UniformAugOp): | |||
| """ | |||
| @@ -235,15 +235,15 @@ def test_py_transforms_with_c_vision(): | |||
| return res | |||
| with pytest.raises(ValueError) as error_info: | |||
| test_config(py_transforms.RandomApply([c_vision.Resize(200)])) | |||
| test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)])) | |||
| assert "transforms[0] is not callable." in str(error_info.value) | |||
| with pytest.raises(ValueError) as error_info: | |||
| test_config(py_transforms.RandomChoice([c_vision.Resize(200)])) | |||
| test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)])) | |||
| assert "transforms[0] is not callable." in str(error_info.value) | |||
| with pytest.raises(ValueError) as error_info: | |||
| test_config(py_transforms.RandomOrder([np.array, c_vision.Resize(200)])) | |||
| test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)])) | |||
| assert "transforms[1] is not callable." in str(error_info.value) | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| @@ -0,0 +1,102 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import cv2 | |||
| from PIL import Image | |||
| import mindspore.dataset.vision.c_transforms as C | |||
| from mindspore import log as logger | |||
def test_eager_resize():
    """Call C.Resize directly on an OpenCV-loaded image and check the output shape."""
    log_template = "Image.type: {}, Image.shape: {}"

    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info(log_template.format(type(source), source.shape))

    resized = C.Resize(size=(32, 32))(source)
    logger.info(log_template.format(type(resized), resized.shape))

    assert resized.shape == (32, 32, 3)
def test_eager_rescale():
    """Call C.Rescale directly on an image and compare the first value with factor * original."""
    log_template = "Image.type: {}, Image.shape: {}"

    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info(log_template.format(type(source), source.shape))

    # Capture the reference value before the op is applied.
    first_value = source[0][0][0]
    rescale_factor = 0.5

    scaled = C.Rescale(rescale=rescale_factor, shift=0)(source)
    logger.info(log_template.format(type(scaled), scaled.shape))

    assert first_value * rescale_factor == scaled[0][0][0]
def test_eager_normalize():
    """Call C.Normalize directly on a PIL image and check the first value against (x - mean) / std."""
    log_template = "Image.type: {}, Image.shape: {}"

    source = Image.open("../data/dataset/apple.jpg").convert("RGB")
    logger.info(log_template.format(type(source), source.size))

    # First channel of the pixel at (0, 0), read before normalization.
    first_value = source.getpixel((0, 0))[0]
    mean_vec = [100, 100, 100]
    std_vec = [2, 2, 2]

    normalized = C.Normalize(mean=mean_vec, std=std_vec)(source)
    logger.info(log_template.format(type(normalized), normalized.shape))

    expected = (first_value - mean_vec[0]) / std_vec[0]
    assert expected == normalized[0][0][0]
def test_eager_HWC2CHW():
    """Call C.HWC2CHW directly on an image and check that the output shape is the input shape reordered."""
    log_template = "Image.type: {}, Image.shape: {}"

    source = cv2.imread("../data/dataset/apple.jpg")
    logger.info(log_template.format(type(source), source.shape))
    hwc_shape = source.shape

    swapped = C.HWC2CHW()(source)
    logger.info(log_template.format(type(swapped), swapped.shape))
    chw_shape = swapped.shape

    # Reading output axes 1, 2, 0 must give back the original (H, W, C) shape.
    assert hwc_shape == tuple(chw_shape[axis] for axis in (1, 2, 0))
def test_eager_pad():
    """Call C.Resize then C.Pad directly on a PIL image and check that each spatial side grows by 2 * pad."""
    img = Image.open("../data/dataset/apple.jpg").convert("RGB")
    # PIL images expose their dimensions through .size.
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.size))

    img = C.Resize(size=(32, 32))(img)
    # From here on img is a NumPy array: log .shape, since ndarray.size is the element count.
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
    size = img.shape

    pad = 4
    img = C.Pad(padding=pad)(img)
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
    size_padded = img.shape

    assert size_padded == (size[0] + 2 * pad, size[1] + 2 * pad, size[2])
def test_eager_exceptions():
    """Check that invalid arguments and invalid inputs are rejected when ops are called directly."""
    # A negative size must be rejected with a ValueError.
    try:
        img = cv2.imread("../data/dataset/apple.jpg")
        img = C.Resize(size=(-32, 32))(img)
    except ValueError as err:
        assert "not within the required interval" in str(err)
    else:
        assert False

    # A file path string is neither a NumPy array nor a PIL image.
    try:
        img = "../data/dataset/apple.jpg"
        img = C.Pad(padding=4)(img)
    except TypeError as err:
        assert "Input should be NumPy or PIL image" in str(err)
    else:
        assert False
if __name__ == '__main__':
    # Run every eager-mode test, in declaration order, when executed as a script.
    for eager_test in (test_eager_resize,
                       test_eager_rescale,
                       test_eager_normalize,
                       test_eager_HWC2CHW,
                       test_eager_pad,
                       test_eager_exceptions):
        eager_test()