| @@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).A | |||
| ArrayReduceGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| ArrayReduceGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArrayReduceGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(ReduceMin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| ArrayReduceGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -29,6 +29,7 @@ const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = { | |||
| {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, | |||
| {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, | |||
| {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, | |||
| {"ReduceMin", CUDNN_REDUCE_TENSOR_MIN}, | |||
| }; | |||
| template <typename T> | |||
| class ArrayReduceGpuKernel : public GpuKernel { | |||
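The two hunks above are all ReduceMin needs on GPU: the kernel registrations expose float32 and float16 variants of ArrayReduceGpuKernel, and the new kReduceTypeMap entry routes the op name to CUDNN_REDUCE_TENSOR_MIN. A minimal usage sketch for the half-precision path (the tests at the end of this patch cover float32), assuming a GPU build with this change applied:

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target='GPU')

x = np.random.rand(2, 3, 4, 4).astype(np.float16)
# Dispatches to the newly registered half-precision ArrayReduceGpuKernel.
out = P.ReduceMin(keep_dims=False)(Tensor(x), 1)
expect = np.min(x, axis=1)
assert out.shape == expect.shape
assert np.all(np.abs(out.asnumpy() - expect) < 1.0e-3)
```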
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cumsum_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void CumSumKernel(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, | |||
| size_t stride2) { | |||
| size_t num = dim0 * dim2; | |||
| size_t i, k, offset; | |||
| for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; | |||
| write_index += blockDim.x * gridDim.x) { | |||
| i = write_index / dim2 % dim0; | |||
| k = write_index % dim2; | |||
| offset = i * stride + k; | |||
| for (size_t j = 0; j < dim1; ++j) { | |||
| size_t read_index = j * stride2 + offset; | |||
| if (j == 0) { | |||
| output[read_index] = input[read_index]; | |||
| } else { | |||
| size_t read_index2 = (j - 1) * stride2 + offset; | |||
| output[read_index] = output[read_index2] + input[read_index]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, | |||
| cudaStream_t stream) { | |||
| int size = dim0 * dim2; | |||
| CumSumKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2); | |||
| return; | |||
| } | |||
| template void CumSum<float>(float *input, float *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, | |||
| size_t stride2, cudaStream_t stream); | |||
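CumSumKernel treats the input as a contiguous 3-D view (dim0, dim1, dim2) built around the scan axis: each thread owns one (i, k) pair of outer/inner indices and walks the dim1 axis sequentially, accumulating into output elements it has already written. A host-side numpy sketch of the same index arithmetic, for illustration only (the helper name is not part of the patch):

```python
import numpy as np

def cumsum_3d_ref(x, dim0, dim1, dim2, stride, stride2):
    """Mirror CumSumKernel: one (i, k) pair per 'thread', sequential scan over dim1."""
    inp = x.reshape(-1)
    out = np.empty_like(inp)
    for write_index in range(dim0 * dim2):
        i = write_index // dim2 % dim0
        k = write_index % dim2
        offset = i * stride + k
        for j in range(dim1):
            read_index = j * stride2 + offset
            if j == 0:
                out[read_index] = inp[read_index]
            else:
                # (j - 1) * stride2 + offset is the element written on the previous step.
                out[read_index] = out[read_index - stride2] + inp[read_index]
    return out.reshape(x.shape)

x = np.random.rand(2, 3, 4).astype(np.float32)
# Scan along axis 1: dim0=2, dim1=3, dim2=4, stride=dim1*dim2=12, stride2=dim2=4.
assert np.allclose(cumsum_3d_ref(x, 2, 3, 4, 12, 4), np.cumsum(x, axis=1))
```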
| @@ -0,0 +1,22 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ | |||
| template <typename T> | |||
| void CumSum(T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride, size_t stride2, | |||
| cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_ | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/math/cumsum_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| CumSumGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class CumSumGpuKernel : public GpuKernel { | |||
| public: | |||
| CumSumGpuKernel() : axis_(0), input_size_0_(0), stride_(0), stride2_(0) {} | |||
| ~CumSumGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| CumSum(input_addr, output_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but CumSumGpuKernel needs 1 input."; | |||
| } | |||
| input_size_0_ = sizeof(T); | |||
| shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| axis_ = GetAttr<int>(kernel_node, "axis"); | |||
| int input_dim_length = SizeToInt(shape_.size()); | |||
| if (axis_ >= input_dim_length || axis_ < -input_dim_length) { | |||
| MS_LOG(EXCEPTION) << "Axis " << axis_ << " is out of bounds for input of rank " << input_dim_length << "."; | |||
| } | |||
| if (axis_ < 0) { | |||
| axis_ += input_dim_length; | |||
| } | |||
| for (size_t i = 0; i < shape_.size(); i++) { | |||
| input_size_0_ *= shape_[i]; | |||
| } | |||
| Reshape(); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_0_); | |||
| output_size_list_.push_back(input_size_0_); | |||
| } | |||
| private: | |||
| void Reshape() { | |||
| dims_[0] = 1; | |||
| dims_[1] = shape_[IntToSize(axis_)]; | |||
| dims_[2] = 1; | |||
| for (size_t i = 0; i < IntToSize(axis_); i++) { | |||
| dims_[0] *= shape_[i]; | |||
| } | |||
| for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) { | |||
| dims_[2] *= shape_[i]; | |||
| } | |||
| stride_ = dims_[1] * dims_[2]; | |||
| stride2_ = dims_[2]; | |||
| return; | |||
| } | |||
| int axis_; | |||
| size_t input_size_0_; | |||
| size_t stride_; | |||
| size_t stride2_; | |||
| size_t dims_[3] = {}; | |||
| std::vector<size_t> shape_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_ | |||
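Reshape() is where the N-D problem is collapsed into the (dim0, dim1, dim2) view the kernel expects: dims_[0] is the product of the dimensions before the axis, dims_[1] is the axis itself, dims_[2] is the product of the dimensions after it, with stride_ = dims_[1] * dims_[2] and stride2_ = dims_[2]. A short numpy illustration of that collapse (the helper name is illustrative, not part of the patch):

```python
import numpy as np

def collapse(shape, axis):
    """Compute dims_[0..2], stride_, stride2_ the way Reshape() does."""
    axis %= len(shape)
    dim0 = int(np.prod(shape[:axis], dtype=np.int64))      # 1 if axis == 0
    dim1 = shape[axis]
    dim2 = int(np.prod(shape[axis + 1:], dtype=np.int64))  # 1 if axis is last
    return dim0, dim1, dim2, dim1 * dim2, dim2

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
dim0, dim1, dim2, stride, stride2 = collapse(x.shape, 2)   # (6, 4, 4, 16, 4)
# Cumsum over the middle axis of the collapsed view equals cumsum over axis 2.
assert np.allclose(np.cumsum(x.reshape(dim0, dim1, dim2), axis=1).reshape(x.shape),
                   np.cumsum(x, axis=2))
```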
| @@ -0,0 +1,132 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis0 = 3 | |||
| x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis1 = 3 | |||
| x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) | |||
| axis2 = 2 | |||
| x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) | |||
| axis3 = 2 | |||
| x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis4 = 1 | |||
| x5 = np.random.rand(2, 3).astype(np.float32) | |||
| axis5 = 1 | |||
| x6 = np.random.rand(1, 1, 1, 1).astype(np.float32) | |||
| axis6 = 0 | |||
| context.set_context(device_target='GPU') | |||
| class CumSum(nn.Cell): | |||
| def __init__(self): | |||
| super(CumSum, self).__init__() | |||
| self.x0 = Tensor(x0) | |||
| self.axis0 = axis0 | |||
| self.x1 = Tensor(x1) | |||
| self.axis1 = axis1 | |||
| self.x2 = Tensor(x2) | |||
| self.axis2 = axis2 | |||
| self.x3 = Tensor(x3) | |||
| self.axis3 = axis3 | |||
| self.x4 = Tensor(x4) | |||
| self.axis4 = axis4 | |||
| self.x5 = Tensor(x5) | |||
| self.axis5 = axis5 | |||
| self.x6 = Tensor(x6) | |||
| self.axis6 = axis6 | |||
| @ms_function | |||
| def construct(self): | |||
| return (P.CumSum()(self.x0, self.axis0), | |||
| P.CumSum()(self.x1, self.axis1), | |||
| P.CumSum()(self.x2, self.axis2), | |||
| P.CumSum()(self.x3, self.axis3), | |||
| P.CumSum()(self.x4, self.axis4), | |||
| P.CumSum()(self.x5, self.axis5), | |||
| P.CumSum()(self.x6, self.axis6)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_CumSum(): | |||
| cumsum = CumSum() | |||
| output = cumsum() | |||
| expect0 = np.cumsum(x0, axis=axis0) | |||
| diff0 = abs(output[0].asnumpy() - expect0) | |||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||
| assert np.all(diff0 < error0) | |||
| assert output[0].shape == expect0.shape | |||
| expect1 = np.cumsum(x1, axis=axis1) | |||
| diff1 = abs(output[1].asnumpy() - expect1) | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| assert np.all(diff1 < error1) | |||
| assert output[1].shape == expect1.shape | |||
| expect2 = np.cumsum(x2, axis=axis2) | |||
| diff2 = abs(output[2].asnumpy() - expect2) | |||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||
| assert np.all(diff2 < error2) | |||
| assert output[2].shape == expect2.shape | |||
| expect3 = np.cumsum(x3, axis=axis3) | |||
| diff3 = abs(output[3].asnumpy() - expect3) | |||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||
| assert np.all(diff3 < error3) | |||
| assert output[3].shape == expect3.shape | |||
| expect4 = np.cumsum(x4, axis=axis4) | |||
| diff4 = abs(output[4].asnumpy() - expect4) | |||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||
| assert np.all(diff4 < error4) | |||
| assert output[4].shape == expect4.shape | |||
| expect5 = np.cumsum(x5, axis=axis5) | |||
| diff5 = abs(output[5].asnumpy() - expect5) | |||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||
| assert np.all(diff5 < error5) | |||
| assert output[5].shape == expect5.shape | |||
| expect6 = np.cumsum(x6, axis=axis6) | |||
| diff6 = abs(output[6].asnumpy() - expect6) | |||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||
| assert np.all(diff6 < error6) | |||
| assert output[6].shape == expect6.shape | |||
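The tests above only exercise non-negative axes, while CumSumGpuKernel::Init() above also folds negative axes into range before computing the collapsed dims. A hedged sketch of that path, assuming the CumSum front end passes a negative axis through to the kernel unchanged:

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target='GPU')

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
# axis=-1 should be normalized to 3 inside the kernel's Init().
out = P.CumSum()(Tensor(x), -1)
expect = np.cumsum(x, axis=-1)
assert out.shape == expect.shape
assert np.all(np.abs(out.asnumpy() - expect) < 1.0e-5)
```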
| @@ -0,0 +1,177 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| x0 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis0 = 3 | |||
| keep_dims0 = True | |||
| x1 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis1 = 3 | |||
| keep_dims1 = False | |||
| x2 = np.random.rand(2, 3, 1, 4).astype(np.float32) | |||
| axis2 = 2 | |||
| keep_dims2 = True | |||
| x3 = np.random.rand(2, 3, 1, 4).astype(np.float32) | |||
| axis3 = 2 | |||
| keep_dims3 = False | |||
| x4 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis4 = () | |||
| np_axis4 = None | |||
| keep_dims4 = True | |||
| x5 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis5 = () | |||
| np_axis5 = None | |||
| keep_dims5 = False | |||
| x6 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis6 = -2 | |||
| keep_dims6 = False | |||
| x7 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis7 = (-2, -1) | |||
| keep_dims7 = True | |||
| x8 = np.random.rand(1, 1, 1, 1).astype(np.float32) | |||
| axis8 = () | |||
| np_axis8 = None | |||
| keep_dims8 = True | |||
| context.set_context(device_target='GPU') | |||
| class ReduceMin(nn.Cell): | |||
| def __init__(self): | |||
| super(ReduceMin, self).__init__() | |||
| self.x0 = Tensor(x0) | |||
| self.axis0 = axis0 | |||
| self.keep_dims0 = keep_dims0 | |||
| self.x1 = Tensor(x1) | |||
| self.axis1 = axis1 | |||
| self.keep_dims1 = keep_dims1 | |||
| self.x2 = Tensor(x2) | |||
| self.axis2 = axis2 | |||
| self.keep_dims2 = keep_dims2 | |||
| self.x3 = Tensor(x3) | |||
| self.axis3 = axis3 | |||
| self.keep_dims3 = keep_dims3 | |||
| self.x4 = Tensor(x4) | |||
| self.axis4 = axis4 | |||
| self.keep_dims4 = keep_dims4 | |||
| self.x5 = Tensor(x5) | |||
| self.axis5 = axis5 | |||
| self.keep_dims5 = keep_dims5 | |||
| self.x6 = Tensor(x6) | |||
| self.axis6 = axis6 | |||
| self.keep_dims6 = keep_dims6 | |||
| self.x7 = Tensor(x7) | |||
| self.axis7 = axis7 | |||
| self.keep_dims7 = keep_dims7 | |||
| self.x8 = Tensor(x8) | |||
| self.axis8 = axis8 | |||
| self.keep_dims8 = keep_dims8 | |||
| @ms_function | |||
| def construct(self): | |||
| return (P.ReduceMin(self.keep_dims0)(self.x0, self.axis0), | |||
| P.ReduceMin(self.keep_dims1)(self.x1, self.axis1), | |||
| P.ReduceMin(self.keep_dims2)(self.x2, self.axis2), | |||
| P.ReduceMin(self.keep_dims3)(self.x3, self.axis3), | |||
| P.ReduceMin(self.keep_dims4)(self.x4, self.axis4), | |||
| P.ReduceMin(self.keep_dims5)(self.x5, self.axis5), | |||
| P.ReduceMin(self.keep_dims6)(self.x6, self.axis6), | |||
| P.ReduceMin(self.keep_dims7)(self.x7, self.axis7), | |||
| P.ReduceMin(self.keep_dims8)(self.x8, self.axis8)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_ReduceMin(): | |||
| reduce_min = ReduceMin() | |||
| output = reduce_min() | |||
| expect0 = np.min(x0, axis=axis0, keepdims=keep_dims0) | |||
| diff0 = abs(output[0].asnumpy() - expect0) | |||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||
| assert np.all(diff0 < error0) | |||
| assert output[0].shape == expect0.shape | |||
| expect1 = np.min(x1, axis=axis1, keepdims=keep_dims1) | |||
| diff1 = abs(output[1].asnumpy() - expect1) | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| assert np.all(diff1 < error1) | |||
| assert output[1].shape == expect1.shape | |||
| expect2 = np.min(x2, axis=axis2, keepdims=keep_dims2) | |||
| diff2 = abs(output[2].asnumpy() - expect2) | |||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||
| assert np.all(diff2 < error2) | |||
| assert output[2].shape == expect2.shape | |||
| expect3 = np.min(x3, axis=axis3, keepdims=keep_dims3) | |||
| diff3 = abs(output[3].asnumpy() - expect3) | |||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||
| assert np.all(diff3 < error3) | |||
| assert output[3].shape == expect3.shape | |||
| expect4 = np.min(x4, axis=np_axis4, keepdims=keep_dims4) | |||
| diff4 = abs(output[4].asnumpy() - expect4) | |||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||
| assert np.all(diff4 < error4) | |||
| assert output[4].shape == expect4.shape | |||
| expect5 = np.min(x5, axis=np_axis5, keepdims=keep_dims5) | |||
| diff5 = abs(output[5].asnumpy() - expect5) | |||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||
| assert np.all(diff5 < error5) | |||
| assert output[5].shape == expect5.shape | |||
| expect6 = np.min(x6, axis=axis6, keepdims=keep_dims6) | |||
| diff6 = abs(output[6].asnumpy() - expect6) | |||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||
| assert np.all(diff6 < error6) | |||
| assert output[6].shape == expect6.shape | |||
| expect7 = np.min(x7, axis=axis7, keepdims=keep_dims7) | |||
| diff7 = abs(output[7].asnumpy() - expect7) | |||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | |||
| assert np.all(diff7 < error7) | |||
| assert output[7].shape == expect7.shape | |||
| expect8 = np.min(x8, axis=np_axis8, keepdims=keep_dims8) | |||
| diff8 = abs(output[8].asnumpy() - expect8) | |||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | |||
| assert np.all(diff8 < error8) | |||
| assert output[8].shape == expect8.shape | |||