From 28241d02937733452a85334d1eb12b5338af5bc4 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Mon, 11 May 2020 16:29:11 +0800 Subject: [PATCH] lstm ops --- .../ccsrc/kernel/cpu/concat_cpu_kernel.cc | 105 +++++++++++++++++ .../ccsrc/kernel/cpu/concat_cpu_kernel.h | 51 ++++++++ mindspore/ccsrc/kernel/cpu/cpu_kernel.cc | 26 ++++ mindspore/ccsrc/kernel/cpu/cpu_kernel.h | 9 ++ .../ccsrc/kernel/cpu/gather_cpu_kernel.cc | 111 ++++++++++++++++++ .../ccsrc/kernel/cpu/gather_cpu_kernel.h | 52 ++++++++ .../ccsrc/kernel/cpu/slice_cpu_kernel.cc | 91 ++++++++++++++ mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h | 48 ++++++++ .../ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc | 95 +++++++++++++++ .../ccsrc/kernel/cpu/slice_grad_cpu_kernel.h | 51 ++++++++ tests/st/ops/cpu/test_concat_op.py | 102 ++++++++++++++++ tests/st/ops/cpu/test_gather_op.py | 108 +++++++++++++++++ tests/st/ops/cpu/test_slice_grad_op.py | 58 +++++++++ tests/st/ops/cpu/test_slice_op.py | 50 ++++++++ 14 files changed, 957 insertions(+) create mode 100644 mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h create mode 100644 mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h create mode 100644 mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h create mode 100644 mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_concat_op.py create mode 100644 tests/st/ops/cpu/test_gather_op.py create mode 100644 tests/st/ops/cpu/test_slice_grad_op.py create mode 100644 tests/st/ops/cpu/test_slice_op.py diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc new file mode 100644 index 0000000000..4ddb855eb6 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc @@ -0,0 +1,105 
@@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/cpu/concat_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_1_shape.size()); + } + axis_ += 4 - input_1_shape.size(); + + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; i++) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + CPUKernelUtils::ExpandDimsTo4(&input_shape); + input_shape_list_.push_back(input_shape); + } + + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool ConcatCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t dim0 = output_shape_[0]; + size_t dim1 = output_shape_[1]; + size_t dim2 = output_shape_[2]; + + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + CopyDataToOutput(inputs, i, j, k, &output_addr); + } + } + } 
+ } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + CopyDataToOutput(inputs, i, j, 0, &output_addr); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + CopyDataToOutput(inputs, i, 0, 0, &output_addr); + } + } else if (axis_ == 0) { + CopyDataToOutput(inputs, 0, 0, 0, &output_addr); + } + return true; +} + +void ConcatCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr) { + for (size_t i = 0; i < input_shape_list_.size(); ++i) { + auto input_i_shape = input_shape_list_[i]; + auto input_i_addr = reinterpret_cast(inputs[i]->addr); + + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_); + num *= input_i_shape[axis_]; + auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0); + auto ret = memcpy_s(*output_addr, num * sizeof(float), input_i_addr + pos, num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed."; + } + *output_addr += num; + } +} + +void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel only support 4d or lower."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h new file mode 100644 index 0000000000..2d1fe06372 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except
in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ConcatCPUKernel : public CPUKernel { + public: + ConcatCPUKernel() = default; + ~ConcatCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr); + int axis_; + std::vector> input_shape_list_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL( + Concat, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConcatCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc index 7150c06eb5..c9d3770c6e 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc @@ -40,5 +40,31 @@ void CPUKernel::Init(const CNodePtr &kernel_node) { InitInputOutputSize(kernel_node); InitKernel(kernel_node); } + +void CPUKernelUtils::ExpandDimsTo4(std::vector *shape) { + auto len = shape->size(); 
+ if (len < 4) { + for (size_t i = 0; i < 4 - len; ++i) { + shape->insert(shape->begin(), 1); + } + } +} + +size_t CPUKernelUtils::CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, + size_t dim3) { + size_t offset = dim0 * shape[1] * shape[2] * shape[3] + dim1 * shape[2] * shape[3] + dim2 * shape[3] + dim3; + return offset; +} + +size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector &shape, int axis) { + if (axis < 0) { + axis = axis + SizeToInt(shape.size()); + } + size_t result = 1; + for (int j = 3; j > axis; --j) { + result *= shape[j]; + } + return result; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h index f9121cb175..378f36ac5b 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h @@ -46,6 +46,8 @@ const char IS_GRAD[] = "is_grad"; const char TRANSPOSE_NO = 'N'; const char TRANSPOSE_YES = 'T'; const char AXIS[] = "axis"; +const char BEGIN[] = "begin"; +const char SIZE[] = "size"; class CPUKernel : public kernel::KernelMod { public: @@ -69,6 +71,13 @@ class CPUKernel : public kernel::KernelMod { std::vector output_size_list_; std::vector workspace_size_list_; }; + +class CPUKernelUtils { + public: + static void ExpandDimsTo4(std::vector *shape); + static size_t CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3); + static size_t GetElementNumOnAxis(const std::vector &shape, int axis); +}; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc new file mode 100644 index 0000000000..044f276d76 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel/cpu/gather_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shape_.size()); + } + axis_ += 4 - input_shape_.size(); + + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool GatherV2CPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + + size_t dim0 = input_shape_[0]; + size_t dim1 = input_shape_[1]; + size_t dim2 = input_shape_[2]; + + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + CopyDataToOutput(inputs, i, j, k, &output_addr); + } + } + } + } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + CopyDataToOutput(inputs, i, j, 0, &output_addr); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + CopyDataToOutput(inputs, i, 0, 0, &output_addr); + } + } else if (axis_ == 0) { + CopyDataToOutput(inputs, 
0, 0, 0, &output_addr); + } + + return true; +} + +void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto indices_addr = reinterpret_cast(inputs[1]->addr); + + for (size_t i = 0; i < output_shape_[axis_]; ++i) { + size_t index = IntToSize(indices_addr[i]); + size_t pos = 0; + if (axis_ == 3) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); + } else if (axis_ == 2) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); + } else if (axis_ == 1) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); + } else if (axis_ == 0) { + pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); + } + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); + auto ret = memcpy_s(*output_addr, num * sizeof(float), input_addr + pos, num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed."; + } + *output_addr += num; + } +} + +void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel only support 4d or lower."; + } + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h new file mode 100644 index 0000000000..08201ff165 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class GatherV2CPUKernel : public CPUKernel { + public: + GatherV2CPUKernel() = default; + ~GatherV2CPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr); + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector indices_shape_; + std::vector output_shape_; + int axis_; +}; + +MS_REG_CPU_KERNEL( + GatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherV2CPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc new file mode 100644 index 0000000000..9e27ddf3b1 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel/cpu/slice_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + size_ = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape_.size() < 4) { + for (size_t i = 0; i < 4 - input_shape_.size(); ++i) { + input_shape_.insert(input_shape_.begin(), 1); + begin_.insert(begin_.begin(), 0); + size_.insert(size_.begin(), 1); + } + } + + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); + + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_shape_[i]; + } + } + + for (size_t i = 0; i < size_.size(); i++) { + if (size_[i] < 0) { + size_[i] = (size_[i] + input_shape_[i]) > 0 ? 
(size_[i] + input_shape_[i]) : 0; + } + } +} + +bool SliceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + for (int i = begin_[0]; i < begin_[0] + size_[0]; ++i) { + for (int j = begin_[1]; j < begin_[1] + size_[1]; ++j) { + for (int k = begin_[2]; k < begin_[2] + size_[2]; ++k) { + for (int m = begin_[3]; m < begin_[3] + size_[3]; ++m) { + auto offset = CPUKernelUtils::CalcOffset(input_shape_, i, j, k, m); + *output_addr++ = input_addr[offset]; + } + } + } + } + + return true; +} + +void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel only support 4d or lower."; + } + if (input_shape.size() == 0) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h new file mode 100644 index 0000000000..d8a71a5335 --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SliceCPUKernel : public CPUKernel { + public: + SliceCPUKernel() = default; + ~SliceCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector begin_; + std::vector size_; + std::vector input_shape_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc new file mode 100644 index 0000000000..1176bef12c --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "kernel/cpu/slice_grad_cpu_kernel.h" +#include "device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + size_ = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + + input_dy_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_dy_shape_.size() < 4) { + for (size_t i = 0; i < 4 - input_dy_shape_.size(); ++i) { + input_dy_shape_.insert(input_dy_shape_.begin(), 1); + begin_.insert(begin_.begin(), 0); + size_.insert(size_.begin(), 1); + } + } + + input_x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_dx_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CPUKernelUtils::ExpandDimsTo4(&input_x_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_dx_shape_); + + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_x_shape_[i]; + } + } + + for (size_t i = 0; i < size_.size(); i++) { + if (size_[i] < 0) { + size_[i] = (size_[i] + input_x_shape_[i]) > 0 ? 
(size_[i] + input_x_shape_[i]) : 0; + } + } +} + +bool SliceGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_dy_addr = reinterpret_cast(inputs[0]->addr); + auto output_dx_addr = reinterpret_cast(outputs[0]->addr); + + auto out_size = sizeof(float) * output_dx_shape_[0] * output_dx_shape_[1] * output_dx_shape_[2] * output_dx_shape_[3]; + auto ret = memset_s(output_dx_addr, out_size, 0, out_size); + if (ret != EOK) { + MS_LOG(ERROR) << "output buff memset fail."; + return false; + } + + for (int i = begin_[0]; i < begin_[0] + size_[0]; ++i) { + for (int j = begin_[1]; j < begin_[1] + size_[1]; ++j) { + for (int k = begin_[2]; k < begin_[2] + size_[2]; ++k) { + for (int m = begin_[3]; m < begin_[3] + size_[3]; ++m) { + auto offset = CPUKernelUtils::CalcOffset(output_dx_shape_, i, j, k, m); + output_dx_addr[offset] = *input_dy_addr++; + } + } + } + } + return true; +} + +void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradCPUKernel needs 1 output."; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradCPUKernel only support 4d or lower."; + } + if (input_shape.size() == 0) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h new file mode 100644 index 0000000000..5508b1b0ba --- /dev/null +++ b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#include +#include +#include "kernel/cpu/cpu_kernel.h" +#include "kernel/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SliceGradCPUKernel : public CPUKernel { + public: + SliceGradCPUKernel() = default; + ~SliceGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector begin_; + std::vector size_; + std::vector input_dy_shape_; + std::vector input_x_shape_; + std::vector output_dx_shape_; +}; + +MS_REG_CPU_KERNEL( + SliceGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_concat_op.py b/tests/st/ops/cpu/test_concat_op.py new file mode 100644 index 0000000000..42ac8e0843 --- /dev/null +++ b/tests/st/ops/cpu/test_concat_op.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + +class Concat_Axis0(nn.Cell): + def __init__(self): + super(Concat_Axis0, self).__init__() + self.cat = P.Concat(axis=0) + + def construct(self, x1, x2): + return self.cat((x1, x2)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_in2_axis0(): + x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32) + x2 = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32) + cat = Concat_Axis0() + output_ms = cat(x1, x2) + print("output:\n", output_ms) + output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=0) + + error = np.ones(shape=output_np.shape) * 10e-6 + diff = output_ms.asnumpy() - output_np + assert np.all(diff < error) + assert np.all(-diff < error) + +class Concat_Axis1(nn.Cell): + def __init__(self): + super(Concat_Axis1, self).__init__() + self.cat = P.Concat(axis=1) + + def construct(self, x1, x2): + return self.cat((x1, x2)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_in2_axis1(): + x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32) + x2 = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32) + cat = Concat_Axis1() + output_ms = cat(x1, x2) + 
print("output:\n", output_ms) + output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=1) + + error = np.ones(shape=output_np.shape) * 10e-6 + diff = output_ms.asnumpy() - output_np + assert np.all(diff < error) + assert np.all(-diff < error) + +class Concat_Axis2(nn.Cell): + def __init__(self): + super(Concat_Axis2, self).__init__() + self.cat = P.Concat(axis=-1) + + def construct(self, x1, x2): + return self.cat((x1, x2)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_in3_axis2(): + x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32) + x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32) + x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32) + cat = Concat_Axis2() + output_ms = cat(x1, x2) + print("output:\n", output_ms) + output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1) + + error = np.ones(shape=output_np.shape) * 10e-6 + diff = output_ms.asnumpy() - output_np + assert np.all(diff < error) + assert np.all(-diff < error) + +if __name__ == '__main__': + test_in2_axis0() + test_in2_axis1() + test_in3_axis2() diff --git a/tests/st/ops/cpu/test_gather_op.py b/tests/st/ops/cpu/test_gather_op.py new file mode 100644 index 0000000000..50fb2096dd --- /dev/null +++ b/tests/st/ops/cpu/test_gather_op.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + +class NetGatherV2_axis0(nn.Cell): + def __init__(self): + super(NetGatherV2_axis0, self).__init__() + self.gatherv2 = P.GatherV2() + + def construct(self, params, indices): + return self.gatherv2(params, indices, 0) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherv2_axis0(): + x = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32) + indices = Tensor(np.array([1, 2]), mstype.int32) + gatherv2 = NetGatherV2_axis0() + ms_output = gatherv2(x, indices) + print("output:\n", ms_output) + expect = np.array([[[4., 5.], + [6., 7.]], + [[8., 9.], + [10., 11.]]]) + error = np.ones(shape=ms_output.asnumpy().shape) * 1.0e-6 + diff = ms_output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +class NetGatherV2_axis1(nn.Cell): + def __init__(self): + super(NetGatherV2_axis1, self).__init__() + self.gatherv2 = P.GatherV2() + + def construct(self, params, indices): + return self.gatherv2(params, indices, 1) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherv2_axis1(): + x = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32) + indices = Tensor(np.array([1, 2]), mstype.int32) + gatherv2 = NetGatherV2_axis1() + ms_output = gatherv2(x, indices) + print("output:\n", ms_output) + expect = np.array([[[2., 3.], + [4., 5.]], + [[8., 9.], + [10., 11.]]]) + error = np.ones(shape=ms_output.asnumpy().shape) * 1.0e-6 + diff = ms_output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +class 
NetGatherV2_axisN1(nn.Cell): + def __init__(self): + super(NetGatherV2_axisN1, self).__init__() + self.gatherv2 = P.GatherV2() + + def construct(self, params, indices): + return self.gatherv2(params, indices, -1) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherv2_axisN1(): + x = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32) + indices = Tensor(np.array([1, 2]), mstype.int32) + gatherv2 = NetGatherV2_axisN1() + ms_output = gatherv2(x, indices) + print("output:\n", ms_output) + expect = np.array([[[1., 2.], + [4., 5.]], + [[7., 8.], + [10.,11.]]]) + error = np.ones(shape=ms_output.asnumpy().shape) * 1.0e-6 + diff = ms_output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + +if __name__ == '__main__': + test_gatherv2_axis0() + test_gatherv2_axis1() + test_gatherv2_axisN1() diff --git a/tests/st/ops/cpu/test_slice_grad_op.py b/tests/st/ops/cpu/test_slice_grad_op.py new file mode 100644 index 0000000000..5561d3cf34 --- /dev/null +++ b/tests/st/ops/cpu/test_slice_grad_op.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class SliceGrad(nn.Cell): + def __init__(self): + super(SliceGrad, self).__init__() + + self.slicegrad = G.SliceGrad() + + @ms_function + def construct(self, dy, x): + return self.slicegrad(dy, x, (0, 1, 0), (2, 1, 3)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_slice(): + x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]), mstype.float32) + dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]), mstype.float32) + slicegrad = SliceGrad() + output = slicegrad(dy, x) + expect = [[[0., 0., 0.], + [3., 1., 2.]], + [[0., 0., 0.], + [4., 1., 4.]], + [[0., 0., 0.], + [0., 0., 0.]]] + print("output:\n", output) + assert (output.asnumpy() == expect).all() + +if __name__ == '__main__': + test_slice() diff --git a/tests/st/ops/cpu/test_slice_op.py b/tests/st/ops/cpu/test_slice_op.py new file mode 100644 index 0000000000..d8ee86be6a --- /dev/null +++ b/tests/st/ops/cpu/test_slice_op.py @@ -0,0 +1,50 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

class Slice(nn.Cell):
    """Wraps P.Slice taking the (2, 1, 3)-sized slice at begin=(0, 1, 0)."""
    def __init__(self):
        super(Slice, self).__init__()
        self.slice = P.Slice()

    def construct(self, x):
        return self.slice(x, (0, 1, 0), (2, 1, 3))

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice():
    # Slice row 1 of the first two batches of a (3, 2, 3) tensor and compare
    # for exact equality (float32 values are integers, so == is safe).
    x = Tensor(
        np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]), mstype.float32)
    expect = [[[2., -2., 2.]],
              [[4., -4., 4.]]]

    # Renamed from `slice`: the original local shadowed the Python builtin
    # `slice`, which is an easy source of confusion in later edits.
    slice_op = Slice()
    output = slice_op(x)
    print("output:\n", output)
    assert (output.asnumpy() == expect).all()

if __name__ == '__main__':
    test_slice()