| @@ -47,6 +47,7 @@ const char TRANSPOSE_NO = 'N'; | |||||
| const char TRANSPOSE_YES = 'T'; | const char TRANSPOSE_YES = 'T'; | ||||
| const char AXIS[] = "axis"; | const char AXIS[] = "axis"; | ||||
| const char BEGIN[] = "begin"; | const char BEGIN[] = "begin"; | ||||
| const char END[] = "end"; | |||||
| const char SIZE[] = "size"; | const char SIZE[] = "size"; | ||||
| class CPUKernel : public kernel::KernelMod { | class CPUKernel : public kernel::KernelMod { | ||||
| @@ -21,31 +21,53 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); | |||||
| size_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); | |||||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| if (input_shape_.size() < 4) { | |||||
| for (size_t i = 0; i < 4 - input_shape_.size(); ++i) { | |||||
| input_shape_.insert(input_shape_.begin(), 1); | |||||
| begin_.insert(begin_.begin(), 0); | |||||
| size_.insert(size_.begin(), 1); | |||||
| } | |||||
| } | |||||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| CPUKernelUtils::ExpandDimsTo4(&output_shape_); | CPUKernelUtils::ExpandDimsTo4(&output_shape_); | ||||
| begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); | |||||
| for (size_t i = 0; i < begin_.size(); i++) { | for (size_t i = 0; i < begin_.size(); i++) { | ||||
| if (begin_[i] < 0) { | if (begin_[i] < 0) { | ||||
| begin_[i] = begin_[i] + input_shape_[i]; | begin_[i] = begin_[i] + input_shape_[i]; | ||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < size_.size(); i++) { | |||||
| if (size_[i] < 0) { | |||||
| size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto strides = prim->GetAttr(STRIDES); | |||||
| if (strides != nullptr) { | |||||
| strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); | |||||
| end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END); | |||||
| if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; | |||||
| } | |||||
| for (size_t i = 0; i < strides_.size(); ++i) { | |||||
| if (strides_[i] < 0) { | |||||
| strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0; | |||||
| } | |||||
| if (end_[i] < 0) { | |||||
| end_[i] = (end_[i] + input_shape_[i]) > 0 ? (end_[i] + input_shape_[i]) : 0; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); | |||||
| if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; | |||||
| } | |||||
| for (size_t i = 0; i < sizes.size(); ++i) { | |||||
| if (sizes[i] < 0) { | |||||
| sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0; | |||||
| } | |||||
| strides_.emplace_back(1); | |||||
| end_.emplace_back(begin_[i] + sizes[i]); | |||||
| } | |||||
| } | |||||
| auto input_len = input_shape_.size(); | |||||
| if (input_len < 4) { | |||||
| for (size_t i = 0; i < 4 - input_len; ++i) { | |||||
| input_shape_.insert(input_shape_.begin(), 1); | |||||
| begin_.insert(begin_.begin(), 0); | |||||
| strides_.insert(strides_.begin(), 1); | |||||
| end_.insert(end_.begin(), 1); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,10 +78,10 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | ||||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | ||||
| for (int i = begin_[0]; i < begin_[0] + size_[0]; ++i) { | |||||
| for (int j = begin_[1]; j < begin_[1] + size_[1]; ++j) { | |||||
| for (int k = begin_[2]; k < begin_[2] + size_[2]; ++k) { | |||||
| for (int m = begin_[3]; m < begin_[3] + size_[3]; ++m) { | |||||
| for (int i = begin_[0]; i < end_[0]; i += strides_[0]) { | |||||
| for (int j = begin_[1]; j < end_[1]; j += strides_[1]) { | |||||
| for (int k = begin_[2]; k < end_[2]; k += strides_[2]) { | |||||
| for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { | |||||
| auto offset = CPUKernelUtils::CalcOffset(input_shape_, i, j, k, m); | auto offset = CPUKernelUtils::CalcOffset(input_shape_, i, j, k, m); | ||||
| *output_addr++ = input_addr[offset]; | *output_addr++ = input_addr[offset]; | ||||
| } | } | ||||
| @@ -35,13 +35,16 @@ class SliceCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| void CheckParam(const CNodePtr &kernel_node); | void CheckParam(const CNodePtr &kernel_node); | ||||
| std::vector<int> begin_; | std::vector<int> begin_; | ||||
| std::vector<int> size_; | |||||
| std::vector<int> end_; | |||||
| std::vector<int> strides_; | |||||
| std::vector<size_t> input_shape_; | std::vector<size_t> input_shape_; | ||||
| std::vector<size_t> output_shape_; | std::vector<size_t> output_shape_; | ||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceCPUKernel); | SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SliceCPUKernel); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,33 +21,54 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); | |||||
| size_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); | |||||
| input_dy_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| if (input_dy_shape_.size() < 4) { | |||||
| for (size_t i = 0; i < 4 - input_dy_shape_.size(); ++i) { | |||||
| input_dy_shape_.insert(input_dy_shape_.begin(), 1); | |||||
| begin_.insert(begin_.begin(), 0); | |||||
| size_.insert(size_.begin(), 1); | |||||
| } | |||||
| } | |||||
| input_x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| output_dx_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | output_dx_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| CPUKernelUtils::ExpandDimsTo4(&input_x_shape_); | |||||
| CPUKernelUtils::ExpandDimsTo4(&output_dx_shape_); | |||||
| input_dy_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); | |||||
| for (size_t i = 0; i < begin_.size(); i++) { | for (size_t i = 0; i < begin_.size(); i++) { | ||||
| if (begin_[i] < 0) { | if (begin_[i] < 0) { | ||||
| begin_[i] = begin_[i] + input_x_shape_[i]; | |||||
| begin_[i] = begin_[i] + output_dx_shape_[i]; | |||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < size_.size(); i++) { | |||||
| if (size_[i] < 0) { | |||||
| size_[i] = (size_[i] + input_x_shape_[i]) > 0 ? (size_[i] + input_x_shape_[i]) : 0; | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto strides = prim->GetAttr(STRIDES); | |||||
| if (strides != nullptr) { | |||||
| strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); | |||||
| end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END); | |||||
| if (strides_.size() != end_.size() || strides_.size() != output_dx_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; | |||||
| } | |||||
| for (size_t i = 0; i < strides_.size(); ++i) { | |||||
| if (strides_[i] < 0) { | |||||
| strides_[i] = (strides_[i] + output_dx_shape_[i]) > 0 ? (strides_[i] + output_dx_shape_[i]) : 0; | |||||
| } | |||||
| if (end_[i] < 0) { | |||||
| end_[i] = (end_[i] + output_dx_shape_[i]) > 0 ? (end_[i] + output_dx_shape_[i]) : 0; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); | |||||
| if (sizes.size() != output_dx_shape_.size() || begin_.size() != output_dx_shape_.size()) { | |||||
| MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; | |||||
| } | |||||
| for (size_t i = 0; i < sizes.size(); ++i) { | |||||
| if (sizes[i] < 0) { | |||||
| sizes[i] = (sizes[i] + output_dx_shape_[i]) > 0 ? (sizes[i] + output_dx_shape_[i]) : 0; | |||||
| } | |||||
| strides_.emplace_back(1); | |||||
| end_.emplace_back(begin_[i] + sizes[i]); | |||||
| } | |||||
| } | |||||
| CPUKernelUtils::ExpandDimsTo4(&output_dx_shape_); | |||||
| auto input_len = input_dy_shape_.size(); | |||||
| if (input_len < 4) { | |||||
| for (size_t i = 0; i < 4 - input_len; ++i) { | |||||
| input_dy_shape_.insert(input_dy_shape_.begin(), 1); | |||||
| begin_.insert(begin_.begin(), 0); | |||||
| strides_.insert(strides_.begin(), 1); | |||||
| end_.insert(end_.begin(), 1); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -65,10 +86,10 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| return false; | return false; | ||||
| } | } | ||||
| for (int i = begin_[0]; i < begin_[0] + size_[0]; ++i) { | |||||
| for (int j = begin_[1]; j < begin_[1] + size_[1]; ++j) { | |||||
| for (int k = begin_[2]; k < begin_[2] + size_[2]; ++k) { | |||||
| for (int m = begin_[3]; m < begin_[3] + size_[3]; ++m) { | |||||
| for (int i = begin_[0]; i < end_[0]; i += strides_[0]) { | |||||
| for (int j = begin_[1]; j < end_[1]; j += strides_[1]) { | |||||
| for (int k = begin_[2]; k < end_[2]; k += strides_[2]) { | |||||
| for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { | |||||
| auto offset = CPUKernelUtils::CalcOffset(output_dx_shape_, i, j, k, m); | auto offset = CPUKernelUtils::CalcOffset(output_dx_shape_, i, j, k, m); | ||||
| output_dx_addr[offset] = *input_dy_addr++; | output_dx_addr[offset] = *input_dy_addr++; | ||||
| } | } | ||||
| @@ -35,9 +35,9 @@ class SliceGradCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| void CheckParam(const CNodePtr &kernel_node); | void CheckParam(const CNodePtr &kernel_node); | ||||
| std::vector<int> begin_; | std::vector<int> begin_; | ||||
| std::vector<int> size_; | |||||
| std::vector<int> end_; | |||||
| std::vector<int> strides_; | |||||
| std::vector<size_t> input_dy_shape_; | std::vector<size_t> input_dy_shape_; | ||||
| std::vector<size_t> input_x_shape_; | |||||
| std::vector<size_t> output_dx_shape_; | std::vector<size_t> output_dx_shape_; | ||||
| }; | }; | ||||
| @@ -45,6 +45,8 @@ MS_REG_CPU_KERNEL( | |||||
| SliceGrad, | SliceGrad, | ||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceGradCPUKernel); | SliceGradCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SliceGradCPUKernel); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||||
| class StridedSliceGrad(nn.Cell): | |||||
| def __init__(self): | |||||
| super(StridedSliceGrad, self).__init__() | |||||
| self.ssg = G.StridedSliceGrad() | |||||
| self.shape = P.Shape() | |||||
| @ms_function | |||||
| def construct(self, dy, x): | |||||
| return self.ssg(dy, self.shape(x), (2, 0, 0), (3, 2, 3), (1, 1, 1)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_slice(): | |||||
| x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) | |||||
| dy = Tensor(np.array([[[5., 1., 5.], [6., 1., 8.]]]).astype(np.float32)) | |||||
| ssg = StridedSliceGrad() | |||||
| output = ssg(dy, x) | |||||
| expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]] | |||||
| assert (output.asnumpy() == expect).all() | |||||
| @@ -0,0 +1,45 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||||
| class StridedSlice(nn.Cell): | |||||
| def __init__(self): | |||||
| super(StridedSlice, self).__init__() | |||||
| self.stridedslice = P.StridedSlice() | |||||
| def construct(self, x): | |||||
| return self.stridedslice(x, (2, 0, 0), (3, 2, 3), (1, 1, 1)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_slice(): | |||||
| x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) | |||||
| stridedslice = StridedSlice() | |||||
| output = stridedslice(x) | |||||
| expect = [[[5., 5., 5.], | |||||
| [6., 7., 8.]]] | |||||
| assert (output.asnumpy() == expect).all() | |||||