@@ -0,0 +1,248 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <thread>
#include "backend/kernel_compiler/cpu/cumsum_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
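// The kernel flattens the input into a (dims_[0], dims_[1], dims_[2]) view, where dims_[1] is the
// length of the cumulated axis; every scan routine below walks one "column" of dims_[1] elements.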
void CumSumCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  CheckParam(kernel_node);
  shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  axis_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
  dst_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
  exclusive_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "exclusive");
  reverse_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reverse");
  int input_dim_length = SizeToInt(shape_.size());
  if (axis_ < -input_dim_length || axis_ >= input_dim_length) {
    MS_LOG(EXCEPTION) << "Axis " << axis_ << " is out of range [" << -input_dim_length << ", " << input_dim_length
                      << ").";
  }
  if (axis_ < 0) {
    axis_ += input_dim_length;
  }
}
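// The workspace mirrors the input buffer byte for byte; it holds the shifted copy that the
// exclusive variants scan over.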
template <typename T>
void CumSumCPUKernel::InitWorkspaceSize() {
  input_size_0_ = sizeof(T);
  for (size_t i = 0; i < shape_.size(); i++) {
    input_size_0_ *= shape_[i];
  }
  workspace_size_list_.emplace_back(input_size_0_);
}
void CumSumCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
  CPUKernel::InitInputOutputSize(kernel_node);
  if (dtype_ == kNumberTypeFloat32) {
    InitWorkspaceSize<float>();
  } else if (dtype_ == kNumberTypeFloat16) {
    InitWorkspaceSize<float16>();
  } else if (dtype_ == kNumberTypeInt32) {
    InitWorkspaceSize<int32_t>();
  } else if (dtype_ == kNumberTypeInt8) {
    InitWorkspaceSize<int8_t>();
  } else if (dtype_ == kNumberTypeUInt8) {
    InitWorkspaceSize<uint8_t>();
  } else {
    MS_LOG(EXCEPTION) << "CumSum on CPU does not support this input data type.";
  }
}
bool CumSumCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                             const std::vector<kernel::AddressPtr> &workspace,
                             const std::vector<kernel::AddressPtr> &outputs) {
  Reshape();
  if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeInt32) {
    LaunchKernel<int32_t>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeInt8) {
    LaunchKernel<int8_t>(inputs, workspace, outputs);
  } else if (dtype_ == kNumberTypeUInt8) {
    LaunchKernel<uint8_t>(inputs, workspace, outputs);
  } else {
    MS_LOG(EXCEPTION) << "CumSum on CPU does not support this input data type.";
  }
  return true;
}
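// Reshape() collapses the input shape into three virtual dimensions around the cumulated axis.
// For example, shape (2, 3, 4) with axis = 1 gives dims_ = {2, 3, 4}, stride_ = 12 (elements per
// outer slice) and stride2_ = 4 (distance between consecutive elements along the axis).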
void CumSumCPUKernel::Reshape() {
  dims_[0] = 1;
  dims_[1] = shape_[IntToSize(axis_)];
  dims_[2] = 1;
  for (size_t i = 0; i < IntToSize(axis_); i++) {
    dims_[0] *= shape_[i];
  }
  for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) {
    dims_[2] *= shape_[i];
  }
  stride_ = dims_[1] * dims_[2];
  stride2_ = dims_[2];
}
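// LeftMove shifts every column one step toward higher indices along the axis and writes a zero at
// position 0; this turns the subsequent inclusive scan into an exclusive one.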
template <typename T>
void CumSumCPUKernel::LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                               size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      if (j == 0) {
        output[read_index] = static_cast<T>(0);
      } else {
        size_t read_index2 = (j - 1) * stride2 + offset;
        output[read_index] = input[read_index2];
      }
    }
  }
}
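// RightMove is the mirror image: it shifts each column one step toward lower indices and zeroes
// the last position, preparing the exclusive reverse scan.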
template <typename T>
void CumSumCPUKernel::RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                                size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (int j = SizeToInt(dim1 - 1); j >= 0; --j) {
      size_t read_index = j * stride2 + offset;
      if (j == SizeToInt(dim1 - 1)) {
        output[read_index] = static_cast<T>(0);
      } else {
        size_t read_index2 = (j + 1) * stride2 + offset;
        output[read_index] = input[read_index2];
      }
    }
  }
}
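// Copy moves the shifted values from `output` back into the workspace buffer; note that despite
// the parameter names it reads from the second argument and writes to the first, which lets the
// caller pass (workspace, output) and then overwrite `output` with the scan result.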
template <typename T>
void CumSumCPUKernel::Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                           size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      input[read_index] = output[read_index];
    }
  }
}
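// CumSumKernelReverse computes a suffix sum: it walks each column from the last element toward the
// first, so output[j] holds the sum of input[j..dim1-1].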
template <typename T>
void CumSumCPUKernel::CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2,
                                          size_t stride, size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (int j = SizeToInt(dim1 - 1); j >= 0; --j) {
      size_t read_index = j * stride2 + offset;
      if (j == SizeToInt(dim1 - 1)) {
        output[read_index] = input[read_index];
      } else {
        size_t read_index2 = (j + 1) * stride2 + offset;
        output[read_index] = output[read_index2] + input[read_index];
      }
    }
  }
}
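// CumSumKernel is the ordinary inclusive prefix sum along the axis:
// output[j] = output[j - 1] + input[j].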
template <typename T>
void CumSumCPUKernel::CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                                   size_t stride2, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    size_t k1 = i / dim2 % dim0;
    size_t k2 = i % dim2;
    size_t offset = k1 * stride + k2;
    for (size_t j = 0; j < dim1; ++j) {
      size_t read_index = j * stride2 + offset;
      if (j == 0) {
        output[read_index] = input[read_index];
      } else {
        size_t read_index2 = (j - 1) * stride2 + offset;
        output[read_index] = output[read_index2] + input[read_index];
      }
    }
  }
}
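// LaunchCumSum converts the [start, end) element range into a range of whole columns (dividing by
// dims_[1]) and then dispatches: the exclusive variants shift the data, park the shifted copy in
// the workspace, and run the inclusive scan on it; the non-exclusive variants scan the input directly.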
template <typename T>
void CumSumCPUKernel::LaunchCumSum(const T *input, T *output, T *workspace, size_t start, size_t end) {
  start = start / dims_[1];
  end = end / dims_[1];
  if (exclusive_) {
    if (reverse_) {
      RightMove(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      Copy(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      CumSumKernelReverse(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    } else {
      LeftMove(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      Copy(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
      CumSumKernel(workspace, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    }
  } else {
    if (reverse_) {
      CumSumKernelReverse(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    } else {
      CumSumKernel(input, output, dims_[0], dims_[1], dims_[2], stride_, stride2_, start, end);
    }
  }
}
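// LaunchKernel splits the flat element range across threads at roughly 128 elements apiece; each
// thread receives a contiguous [start, end) slice, which LaunchCumSum rounds to whole columns.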
template <typename T>
void CumSumCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                   const std::vector<kernel::AddressPtr> &workspace,
                                   const std::vector<kernel::AddressPtr> &outputs) {
  auto input = reinterpret_cast<T *>(inputs[0]->addr);
  auto ws = reinterpret_cast<T *>(workspace[0]->addr);
  auto output = reinterpret_cast<T *>(outputs[0]->addr);
  // Multithreading: clamp the thread count to at least 1 so the chunk-size division below is safe
  // even when std::thread::hardware_concurrency() reports 0.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
  size_t max_thread_num = std::thread::hardware_concurrency();
  size_t thread_num = lens < 128 * max_thread_num ? static_cast<size_t>(std::ceil(lens / 128.0)) : max_thread_num;
  if (thread_num < 1) {
    thread_num = 1;
  }
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t start = 0;
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    threads.emplace_back(&CumSumCPUKernel::LaunchCumSum<T>, this, input, output, ws, start, end);
    start += once_compute_size;
  }
  for (auto &thread : threads) {
    thread.join();
  }
}
void CumSumCPUKernel::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 1) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but CumSumCPUKernel needs 1 input.";
  }
}
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
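// CPU implementation of the CumSum operator: an inclusive/exclusive, forward/reverse running sum
// along a single axis, parallelized across the non-cumulated dimensions.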
class CumSumCPUKernel : public CPUKernel {
 public:
  CumSumCPUKernel() = default;
  ~CumSumCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  void InitInputOutputSize(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void InitWorkspaceSize();
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  template <typename T>
  void LaunchCumSum(const T *input_addr, T *output_addr, T *ws_addr, size_t start, size_t end);

 private:
  void CheckParam(const CNodePtr &kernel_node);
  void Reshape();
  template <typename T>
  void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                size_t start, size_t end);
  template <typename T>
  void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                 size_t start, size_t end);
  template <typename T>
  void Copy(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, size_t start,
            size_t end);
  template <typename T>
  void CumSumKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
                           size_t stride2, size_t start, size_t end);
  template <typename T>
  void CumSumKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2,
                    size_t start, size_t end);

  std::vector<size_t> shape_;
  std::vector<size_t> dst_shape_;
  size_t input_size_0_{0};
  size_t stride_{0};
  size_t stride2_{0};
  size_t dims_[3] = {};
  bool exclusive_{false};
  bool reverse_{false};
  int axis_{0};
  TypeId dtype_{kTypeUnknown};
};
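// Register one kernel entry per supported dtype; input and output types must match.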
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                  CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                  CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CumSumCPUKernel);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CumSumCPUKernel);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_
@@ -844,7 +844,7 @@ class CumSum(PrimitiveWithInfer):
         Tensor, the shape of the output tensor is consistent with the input tensor's.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input = Tensor(np.array([[3, 4, 6, 10],[1, 6, 7, 9],[4, 3, 8, 7],[1, 3, 7, 9]]).astype(np.float32))
| @@ -0,0 +1,271 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
axis0 = 0
axis1 = 1
axis2 = 2
axis3 = 3
axis4 = 4
axis5 = -1
axis6 = -2

x0 = np.random.rand(3, 3, 4, 5, 3).astype(np.float32)
x1 = np.random.rand(2, 3, 4, 5, 3).astype(np.float16)
x2 = np.random.randint(-10000, 10000, size=(2, 3, 4, 5, 3)).astype(np.int32)
x3 = np.random.randint(-5, 5, size=(2, 3, 4, 5, 3)).astype(np.int8)
x4 = np.random.randint(0, 10, size=(2, 3, 4, 5, 3)).astype(np.uint8)
x5 = np.random.rand(3).astype(np.float32)

list1 = [x0, x1, x2, x3, x4]
list2 = [axis0, axis1, axis2, axis3, axis4, axis5, axis6]
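
# The Cell below enumerates every (input, axis) pair: five 5-D inputs (x0-x4) times seven axes
# (including the negative ones) plus the 1-D input x5, 36 cases in total, matching the 36 outputs
# returned by construct().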
class CumSum(nn.Cell):
    def __init__(self, exclusive=False, reverse=False):
        super(CumSum, self).__init__()
        self.cumsum_op = P.CumSum(exclusive, reverse)

        self.x0 = Tensor(x0)
        self.axis0 = axis0
        self.x1 = Tensor(x0)
        self.axis1 = axis1
        self.x2 = Tensor(x0)
        self.axis2 = axis2
        self.x3 = Tensor(x0)
        self.axis3 = axis3
        self.x4 = Tensor(x0)
        self.axis4 = axis4
        self.x5 = Tensor(x0)
        self.axis5 = axis5
        self.x6 = Tensor(x0)
        self.axis6 = axis6

        self.x7 = Tensor(x1)
        self.axis7 = axis0
        self.x8 = Tensor(x1)
        self.axis8 = axis1
        self.x9 = Tensor(x1)
        self.axis9 = axis2
        self.x10 = Tensor(x1)
        self.axis10 = axis3
        self.x11 = Tensor(x1)
        self.axis11 = axis4
        self.x12 = Tensor(x1)
        self.axis12 = axis5
        self.x13 = Tensor(x1)
        self.axis13 = axis6

        self.x14 = Tensor(x2)
        self.axis14 = axis0
        self.x15 = Tensor(x2)
        self.axis15 = axis1
        self.x16 = Tensor(x2)
        self.axis16 = axis2
        self.x17 = Tensor(x2)
        self.axis17 = axis3
        self.x18 = Tensor(x2)
        self.axis18 = axis4
        self.x19 = Tensor(x2)
        self.axis19 = axis5
        self.x20 = Tensor(x2)
        self.axis20 = axis6

        self.x21 = Tensor(x3)
        self.axis21 = axis0
        self.x22 = Tensor(x3)
        self.axis22 = axis1
        self.x23 = Tensor(x3)
        self.axis23 = axis2
        self.x24 = Tensor(x3)
        self.axis24 = axis3
        self.x25 = Tensor(x3)
        self.axis25 = axis4
        self.x26 = Tensor(x3)
        self.axis26 = axis5
        self.x27 = Tensor(x3)
        self.axis27 = axis6

        self.x28 = Tensor(x4)
        self.axis28 = axis0
        self.x29 = Tensor(x4)
        self.axis29 = axis1
        self.x30 = Tensor(x4)
        self.axis30 = axis2
        self.x31 = Tensor(x4)
        self.axis31 = axis3
        self.x32 = Tensor(x4)
        self.axis32 = axis4
        self.x33 = Tensor(x4)
        self.axis33 = axis5
        self.x34 = Tensor(x4)
        self.axis34 = axis6

        self.x35 = Tensor(x5)
        self.axis35 = axis0
    def construct(self):
        return (self.cumsum_op(self.x0, self.axis0),
                self.cumsum_op(self.x1, self.axis1),
                self.cumsum_op(self.x2, self.axis2),
                self.cumsum_op(self.x3, self.axis3),
                self.cumsum_op(self.x4, self.axis4),
                self.cumsum_op(self.x5, self.axis5),
                self.cumsum_op(self.x6, self.axis6),
                self.cumsum_op(self.x7, self.axis7),
                self.cumsum_op(self.x8, self.axis8),
                self.cumsum_op(self.x9, self.axis9),
                self.cumsum_op(self.x10, self.axis10),
                self.cumsum_op(self.x11, self.axis11),
                self.cumsum_op(self.x12, self.axis12),
                self.cumsum_op(self.x13, self.axis13),
                self.cumsum_op(self.x14, self.axis14),
                self.cumsum_op(self.x15, self.axis15),
                self.cumsum_op(self.x16, self.axis16),
                self.cumsum_op(self.x17, self.axis17),
                self.cumsum_op(self.x18, self.axis18),
                self.cumsum_op(self.x19, self.axis19),
                self.cumsum_op(self.x20, self.axis20),
                self.cumsum_op(self.x21, self.axis21),
                self.cumsum_op(self.x22, self.axis22),
                self.cumsum_op(self.x23, self.axis23),
                self.cumsum_op(self.x24, self.axis24),
                self.cumsum_op(self.x25, self.axis25),
                self.cumsum_op(self.x26, self.axis26),
                self.cumsum_op(self.x27, self.axis27),
                self.cumsum_op(self.x28, self.axis28),
                self.cumsum_op(self.x29, self.axis29),
                self.cumsum_op(self.x30, self.axis30),
                self.cumsum_op(self.x31, self.axis31),
                self.cumsum_op(self.x32, self.axis32),
                self.cumsum_op(self.x33, self.axis33),
                self.cumsum_op(self.x34, self.axis34),
                self.cumsum_op(self.x35, self.axis35))
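
# Each test below compares against a NumPy reference: plain np.cumsum for the default case, and
# flip/insert/delete compositions that reproduce the reverse and exclusive variants.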
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum():
    cumsum = CumSum()
    output = cumsum()
    k = 0
    for i in list1:
        for j in list2:
            expect = np.cumsum(i, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    expect = np.cumsum(x5, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum2():
    cumsum = CumSum(exclusive=False, reverse=True)
    output = cumsum()
    k = 0
    for i in list1:
        for j in list2:
            result1 = np.flip(i, axis=j)
            result2 = np.cumsum(result1, axis=j)
            expect = np.flip(result2, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.flip(x5, axis=axis0)
    result2 = np.cumsum(result1, axis=axis0)
    expect = np.flip(result2, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum3():
    cumsum = CumSum(exclusive=True, reverse=False)
    output = cumsum()
    k = 0
    for i in list1:
        for j in list2:
            result1 = np.insert(i, 0, [0], axis=j)
            result2 = np.delete(result1, -1, axis=j)
            expect = np.cumsum(result2, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.insert(x5, 0, [0], axis=axis0)
    result2 = np.delete(result1, -1, axis=axis0)
    expect = np.cumsum(result2, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cumsum4():
    cumsum = CumSum(exclusive=True, reverse=True)
    output = cumsum()
    k = 0
    for i in list1:
        for j in list2:
            result1 = np.flip(i, axis=j)
            result2 = np.insert(result1, 0, [0], axis=j)
            result3 = np.delete(result2, -1, axis=j)
            result4 = np.cumsum(result3, axis=j)
            expect = np.flip(result4, axis=j)
            diff = abs(output[k].asnumpy() - expect)
            error = np.ones(shape=expect.shape) * 1.0e-5
            assert np.all(diff < error)
            assert output[k].shape == expect.shape
            k += 1

    result1 = np.flip(x5, axis=axis0)
    result2 = np.insert(result1, 0, [0], axis=axis0)
    result3 = np.delete(result2, -1, axis=axis0)
    result4 = np.cumsum(result3, axis=axis0)
    expect = np.flip(result4, axis=axis0)
    diff = abs(output[k].asnumpy() - expect)
    error = np.ones(shape=expect.shape) * 1.0e-5
    assert np.all(diff < error)
    assert output[k].shape == expect.shape