@@ -15,7 +15,7 @@
 """grad impl."""
 from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
-    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
+    grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
 from .grad_base import get_bprop_fn
 __all__ = ['get_bprop_fn']
@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """array_ops""" | |||||
| from ..operations import _grad_ops as G | |||||
| from ..operations import _inner_ops as inner | |||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||||
| from .grad_base import bprop_getters | |||||
@bprop_getters.register(inner.StridedSliceAICPU)
def get_bprop_strided_slice_aicpu(self):
    """Generate bprop for StridedSliceAICPU"""
    input_grad = G.StridedSliceGradAICPU(self.begin_mask,
                                         self.end_mask,
                                         self.ellipsis_mask,
                                         self.new_axis_mask,
                                         self.shrink_axis_mask)

    def bprop(x, begin, end, strides, out, dout):
        dx = input_grad(dout, shape_op(x), begin, end, strides)
        return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)

    return bprop
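For intuition, the bprop above routes the incoming gradient `dout` back into a tensor shaped like the forward input `x`, leaving all non-selected positions at zero. A minimal NumPy sketch of that behaviour (illustration only; the helper name is made up, and it assumes non-negative `begin`/`end`, positive `strides` and default masks — it is not the AiCPU kernel):

    import numpy as np

    def strided_slice_grad_reference(dout, x_shape, begin, end, strides):
        # Scatter dout into the sliced positions of a zero tensor with the forward input's shape.
        dx = np.zeros(x_shape, dtype=dout.dtype)
        index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
        dx[index] = dout
        return dx

With dout = [[[6, 8], [9, 11]]], x_shape = (3, 2, 3), begin = (1, 0, 0), end = (2, 2, 3) and strides = (1, 1, 2), this sketch reproduces the expected array asserted in the AiCPU grad test at the end of this change.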
@@ -40,3 +40,5 @@ from .poisson import _poisson_aicpu
 from .uniform_int import _uniform_int_aicpu
 from .uniform_real import _uniform_real_aicpu
 from .laplace import _laplace_aicpu
+from .strided_slice import _strided_slice_aicpu
+from .strided_slice_grad import _strided_slice_grad_aicpu
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""StridedSlice op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
    .fusion_type("OPAQUE") \
    .input(0, "input", "required") \
    .input(1, "begin", "required") \
    .input(2, "end", "required") \
    .input(3, "stride", "required") \
    .output(0, "output", "required") \
    .attr("begin_mask", "int") \
    .attr("end_mask", "int") \
    .attr("ellipsis_mask", "int") \
    .attr("new_axis_mask", "int") \
    .attr("shrink_axis_mask", "int") \
    .dtype_format(DataType.F32_NCHW,
                  DataType.I32_NCHW,
                  DataType.I32_NCHW,
                  DataType.I32_NCHW,
                  DataType.F32_NCHW) \
    .get_op_info()


@op_info_register(strided_slice_op_info)
def _strided_slice_aicpu():
    """StridedSlice AiCPU register"""
    return
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""StridedSliceGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
    .fusion_type("OPAQUE") \
    .input(0, "dy", "required") \
    .input(1, "shape", "required") \
    .input(2, "begin", "required") \
    .input(3, "end", "required") \
    .input(4, "stride", "required") \
    .output(0, "output", "required") \
    .attr("begin_mask", "int") \
    .attr("end_mask", "int") \
    .attr("ellipsis_mask", "int") \
    .attr("new_axis_mask", "int") \
    .attr("shrink_axis_mask", "int") \
    .dtype_format(DataType.F32_NCHW,
                  DataType.I32_NCHW,
                  DataType.I32_NCHW,
                  DataType.I32_NCHW,
                  DataType.I32_NCHW,
                  DataType.F32_NCHW) \
    .get_op_info()


@op_info_register(strided_slice_grad_op_info)
def _strided_slice_grad_aicpu():
    """StridedSliceGrad AiCPU register"""
    return
@@ -1110,6 +1110,54 @@ class StridedSliceGrad(PrimitiveWithInfer):
                'value': None}
class StridedSliceGradAICPU(PrimitiveWithInfer):
    """
    Performs grad of StridedSlice operation.

    Args:
        begin_mask (int): Start indexing the slice. Default: 0.
        end_mask (int): End indexing the slice. Default: 0.
        ellipsis_mask (int): An int32 mask. Default: 0.
        new_axis_mask (int): An int32 mask. Default: 0.
        shrink_axis_mask (int): An int32 mask. Default: 0.

    Returns:
        Tensor, has the same shape as the forward input of StridedSlice.
    """

    @prim_attr_register
    def __init__(self,
                 begin_mask=0,
                 end_mask=0,
                 ellipsis_mask=0,
                 new_axis_mask=0,
                 shrink_axis_mask=0):
        """init StridedSliceGradAICPU"""
        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
        validator.check_value_type('end_mask', end_mask, [int], self.name)
        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])

    def __infer__(self, dy, shapex, begin, end, strides):
        args = {"dy": dy['dtype']}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)

        for idx, item in enumerate(shapex['value']):
            validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(begin['value']):
            validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(end['value']):
            validator.check_value_type("end[%d]" % idx, item, [int], self.name)
        for idx, item in enumerate(strides['value']):
            validator.check_value_type("strides[%d]" % idx, item, [int], self.name)

        return {'shape': shapex['value'],
                'dtype': dy['dtype'],
                'value': None}
class SoftplusGrad(PrimitiveWithInfer):
    """Computes gradient for the Log Softmax activation."""
@@ -21,6 +21,137 @@ from ...common import dtype as mstype
 from ..primitive import PrimitiveWithInfer, prim_attr_register
class StridedSliceAICPU(PrimitiveWithInfer):
    r"""
    Extracts a strided slice of a tensor.

    Starting from the position specified by `begin`, this operation extracts a fragment of size
    (end-begin)/stride from the given `input_x`. It keeps adding `strides` to the index and
    collecting elements until the index is no longer less than `end`.

    Note:
        The stride may be a negative value, which causes reverse slicing.
        The shape of `begin`, `end` and `strides` should be the same.

    Args:
        begin_mask (int): Starting index of the slice. Default: 0.
        end_mask (int): Ending index of the slice. Default: 0.
        ellipsis_mask (int): An int mask. Default: 0.
        new_axis_mask (int): An int mask. Default: 0.
        shrink_axis_mask (int): An int mask. Default: 0.
        Currently none of the masks are used; keep them at the default value 0.

    Inputs:
        - **input_x** (Tensor) - The input Tensor.
        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
          constant value is allowed.
        - **end** (tuple[int]) - A tuple which represents the maximum location where to stop.
          Only constant value is allowed.
        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
          to the index until it reaches the maximum location. Only constant value is allowed.

    Outputs:
        Tensor.
        The result of the example below is derived as follows.

        - In the 0th dim, begin is 1, end is 2 and the stride is 1;
          because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
          Thus, only the element with :math:`index = 1` in the 0th dim is kept,
          i.e., [[3, 3, 3], [4, 4, 4]].
        - In the 1st dim, similarly, the interval is :math:`[0,1)`.
          Based on the result of the 0th dim, only the element with :math:`index = 0` is kept,
          i.e., [3, 3, 3].
        - In the 2nd dim, the interval is :math:`[0,3)` with a stride of 2.
          Based on the result of the 1st dim, the elements with :math:`index = 0, 2` are kept,
          i.e., [3, 3].
        - Finally, the output is [[[3, 3]]] with shape (1, 1, 2).

    Examples:
        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
        >>>                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
        >>> slice = P.StridedSliceAICPU()
        >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 2))
        >>> output.shape
        (1, 1, 2)
        >>> output
        [[[3, 3]]]
    """

    @prim_attr_register
    def __init__(self,
                 begin_mask=0,
                 end_mask=0,
                 ellipsis_mask=0,
                 new_axis_mask=0,
                 shrink_axis_mask=0):
        """init StridedSliceAICPU"""
        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
        validator.check_value_type('end_mask', end_mask, [int], self.name)
        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)

    def __infer__(self, x, begin, end, strides):
        begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
        validator.check_value_type("begin", begin_v, [tuple], self.name)
        validator.check_value_type("end", end_v, [tuple], self.name)
        validator.check_value_type("strides", strides_v, [tuple], self.name)

        x_shape = x['shape']
        x_shp_len = len(x_shape)
        if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
            raise ValueError(f"For '{self.name}' the length of begin index {begin_v}, end index {end_v} and "
                             f"strides {strides_v} must be equal to the dims ({x_shp_len}) of input.")

        ret_shape = []
        append_dimensions = []
        shrink_pos = bin(self.shrink_axis_mask)[::-1]
        new_pos = bin(self.new_axis_mask)[::-1]
        for i in range(x_shp_len):
            # bin() prefixes the string with the flag chars '0b'; after the string is reversed,
            # these two chars sit at the end, hence the (len - 2) bound below.
            if i < (len(new_pos) - 2) and new_pos[i] == '1':
                ret_shape.append(1)
                append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
                continue
            if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
                validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
                validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
                continue

            begin_idx = begin_v[i]
            end_idx = end_v[i]
            strides_idx = strides_v[i]
            if self.begin_mask:
                begin_idx = 0
            if self.end_mask:
                end_idx = x_shape[i]
            validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
            validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
            validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
            if strides_idx > 0:
                # If sliced forward, end_idx >= begin_idx
                validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
                if begin_idx < 0 < end_idx:
                    # Turn a negative begin_idx into a positive index
                    begin_idx = x_shape[i] + begin_idx
                num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx
            else:
                # If sliced backwards, end_idx <= begin_idx
                validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE)
                if end_idx < 0 < begin_idx:
                    # Turn a negative end_idx into a positive index
                    end_idx = x_shape[i] + end_idx
                num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx
            ret_shape.append(num_elems)
        if append_dimensions:
            ret_shape += append_dimensions[::-1]
        return {'shape': ret_shape,
                'dtype': x['dtype'],
                'value': None}
class ExtractImagePatches(PrimitiveWithInfer):
    """
    Extract patches from images.
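As a sanity check on the shape inference in StridedSliceAICPU.__infer__ above: with positive strides and all masks left at 0, each output dimension reduces to ceil((end - begin) / stride), computed as (end - begin + stride - 1) // stride. A small sketch using the values from the docstring example (the helper name is made up for illustration):

    def infer_slice_dim(begin, end, stride):
        # Number of elements kept in one dimension for a positive stride and default masks.
        return (end - begin + stride - 1) // stride

    # begin = (1, 0, 0), end = (2, 1, 3), strides = (1, 1, 2) on an input of shape (3, 2, 3)
    assert [infer_slice_dim(1, 2, 1),
            infer_slice_dim(0, 1, 1),
            infer_slice_dim(0, 3, 2)] == [1, 1, 2]

This matches the output shape (1, 1, 2) shown in the docstring example.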
@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, begin, end, strides):
        super(Net, self).__init__()
        self.strided_slice = inner.StridedSliceAICPU()
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, input):
        return self.strided_slice(input, self.begin, self.end, self.strides)


input_x = np.array([[[0, 1, 2], [3, 4, 5]],
                    [[6, 7, 8], [9, 10, 11]],
                    [[12, 13, 14], [15, 16, 17]]
                    ]).astype(np.float32)
begin = (1, 0, 0)
end = (2, 2, 3)
strides = (1, 1, 2)


def test_net():
    net = Net(begin, end, strides)
    tinput = Tensor(input_x)
    output = net(tinput)
    print(output.asnumpy())
    assert np.all([[[6, 8], [9, 11]]] == output.asnumpy())
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, shape_x, begin, end, strides):
        super(Net, self).__init__()
        self.strided_slice_grad = G.StridedSliceGradAICPU()
        self.shape_x = shape_x
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, dy):
        return self.strided_slice_grad(dy, self.shape_x, self.begin, self.end, self.strides)


dy = np.array([[[6, 8], [9, 11]]]).astype(np.float32)
shape_x = (3, 2, 3)
begin = (1, 0, 0)
end = (2, 2, 3)
strides = (1, 1, 2)


def test_net():
    net = Net(shape_x, begin, end, strides)
    tdy = Tensor(dy)
    output = net(tdy)
    print(output.asnumpy())
    assert np.all([[[0, 0, 0], [0, 0, 0]],
                   [[6, 0, 8], [9, 0, 11]],
                   [[0, 0, 0], [0, 0, 0]]
                   ] == output.asnumpy())