| @@ -43,7 +43,7 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| inputA_descriptor_(nullptr), | |||
| outputC_descriptor_(nullptr), | |||
| keep_dims_(false), | |||
| is_reduce_dim_one_(true), | |||
| all_match_(false), | |||
| is_null_input_(false), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| @@ -65,7 +65,9 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| if (is_reduce_dim_one_) { | |||
| if (all_match_) { | |||
| MS_LOG(WARNING) | |||
| << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel."; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."); | |||
| @@ -178,6 +180,7 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) { | |||
| std::vector<size_t> inputA_shape = input_shape; | |||
| std::vector<size_t> outputC_shape = output_shape; | |||
| std::vector<int> real_input_shape; | |||
| int shapeA_n, shapeA_c, shapeA_h, shapeA_w; | |||
| shapeA_n = inputA_shape.size() < 4 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 4]); | |||
| shapeA_c = inputA_shape.size() < 3 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 3]); | |||
| @@ -196,7 +199,9 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, | |||
| shapeC_n, shapeC_c, shapeC_h, shapeC_w), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| is_reduce_dim_one_ = false; | |||
| if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) { | |||
| all_match_ = true; | |||
| } | |||
| return; | |||
| } | |||
| @@ -205,21 +210,16 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| (void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); | |||
| } | |||
| } | |||
| for (auto i : axis_) { | |||
| if (inputA_shape[IntToSize(i)] != 1) { | |||
| // To avoid cudnnReduceTensor bug when the dimension which needs to be | |||
| // reduced is already 1. | |||
| is_reduce_dim_one_ = false; | |||
| } | |||
| } | |||
| shapeC_n = outputC_shape.size() < 4 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 4]); | |||
| shapeC_c = outputC_shape.size() < 3 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 3]); | |||
| shapeC_h = outputC_shape.size() < 2 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 2]); | |||
| shapeC_w = SizeToInt(outputC_shape[outputC_shape.size() - 1]); | |||
| shapeC_w = outputC_shape.size() == 0 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 1]); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, shapeC_n, | |||
| shapeC_c, shapeC_h, shapeC_w), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) { | |||
| all_match_ = true; | |||
| } | |||
| return; | |||
| } | |||
| @@ -234,7 +234,7 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||
| std::vector<int> axis_; | |||
| bool keep_dims_; | |||
| bool is_reduce_dim_one_; | |||
| bool all_match_; | |||
| bool is_null_input_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -114,6 +114,14 @@ class BinaryOpGpuKernel : public GpuKernel { | |||
| InferBinaryType(kernel_node); | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (input_shape != output_shape && input_shapeB != output_shape) { | |||
| MS_LOG(ERROR) << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n" | |||
| "InputA must match the corresponding dimension of the destination tensor outC, and each " | |||
| "dimension of the inputB " | |||
| "must match the corresponding dimension of outC or must be equal to 1."; | |||
| return false; | |||
| } | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "BinaryOpGpuKernel input is null"; | |||
| @@ -90,7 +90,7 @@ class TensorAddGpuFwdKernel : public GpuKernel { | |||
| if (input_shape != output_shape && input_shapeB != output_shape) { | |||
| MS_LOG(ERROR) << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n" | |||
| "InputA must match the corresponding dimension of the destination tensor outC, and each " | |||
| "dimension of the inputB" | |||
| "dimension of the inputB " | |||
| "must match the corresponding dimension of outC or must be equal to 1."; | |||
| return false; | |||
| } | |||
| @@ -50,7 +50,11 @@ template <typename T> | |||
| class UnaryOpGpuKernel : public GpuKernel { | |||
| public: | |||
| UnaryOpGpuKernel() | |||
| : unary_op_type_(UNARY_OP_INVALID_TYPE), input_size_(sizeof(T)), output_size_(sizeof(T)), workspace_size_(0) {} | |||
| : unary_op_type_(UNARY_OP_INVALID_TYPE), | |||
| input_size_(sizeof(T)), | |||
| output_size_(sizeof(T)), | |||
| workspace_size_(0), | |||
| is_null_input_(false) {} | |||
| ~UnaryOpGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| @@ -121,6 +125,12 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "UnaryOpGpuKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| @@ -140,6 +150,7 @@ class UnaryOpGpuKernel : public GpuKernel { | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| bool is_null_input_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| @@ -55,6 +55,11 @@ x7 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis7 = (-2, -1) | |||
| keep_dims7 = True | |||
| x8 = np.random.rand(1, 1, 1, 1).astype(np.float32) | |||
| axis8 = () | |||
| np_axis8 = None | |||
| keep_dims8 = True | |||
| context.set_context(device_target='GPU') | |||
| @@ -94,6 +99,10 @@ class ReduceMax(nn.Cell): | |||
| self.axis7 = axis7 | |||
| self.keep_dims7 = keep_dims7 | |||
| self.x8 = Tensor(x8) | |||
| self.axis8 = axis8 | |||
| self.keep_dims8 = keep_dims8 | |||
| @ms_function | |||
| def construct(self): | |||
| return (P.ReduceMax(self.keep_dims0)(self.x0, self.axis0), | |||
| @@ -103,7 +112,8 @@ class ReduceMax(nn.Cell): | |||
| P.ReduceMax(self.keep_dims4)(self.x4, self.axis4), | |||
| P.ReduceMax(self.keep_dims5)(self.x5, self.axis5), | |||
| P.ReduceMax(self.keep_dims6)(self.x6, self.axis6), | |||
| P.ReduceMax(self.keep_dims7)(self.x7, self.axis7)) | |||
| P.ReduceMax(self.keep_dims7)(self.x7, self.axis7), | |||
| P.ReduceMax(self.keep_dims8)(self.x8, self.axis8)) | |||
| @pytest.mark.level0 | |||
| @@ -114,48 +124,53 @@ def test_ReduceMax(): | |||
| output = reduce_max() | |||
| expect0 = np.max(x0, axis=axis0, keepdims=keep_dims0) | |||
| diff0 = output[0].asnumpy() - expect0 | |||
| diff0 = abs(output[0].asnumpy() - expect0) | |||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||
| assert np.all(diff0 < error0) | |||
| assert (output[0].shape() == expect0.shape) | |||
| expect1 = np.max(x1, axis=axis1, keepdims=keep_dims1) | |||
| diff1 = output[1].asnumpy() - expect1 | |||
| diff1 = abs(output[1].asnumpy() - expect1) | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| assert np.all(diff1 < error1) | |||
| assert (output[1].shape() == expect1.shape) | |||
| expect2 = np.max(x2, axis=axis2, keepdims=keep_dims2) | |||
| diff2 = output[2].asnumpy() - expect2 | |||
| diff2 = abs(output[2].asnumpy() - expect2) | |||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||
| assert np.all(diff2 < error2) | |||
| assert (output[2].shape() == expect2.shape) | |||
| expect3 = np.max(x3, axis=axis3, keepdims=keep_dims3) | |||
| diff3 = output[3].asnumpy() - expect3 | |||
| diff3 = abs(output[3].asnumpy() - expect3) | |||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||
| assert np.all(diff3 < error3) | |||
| assert (output[3].shape() == expect3.shape) | |||
| expect4 = np.max(x4, axis=np_axis4, keepdims=keep_dims4) | |||
| diff4 = output[4].asnumpy() - expect4 | |||
| diff4 = abs(output[4].asnumpy() - expect4) | |||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||
| assert np.all(diff4 < error4) | |||
| assert (output[4].shape() == expect4.shape) | |||
| expect5 = np.max(x5, axis=np_axis5, keepdims=keep_dims5) | |||
| diff5 = output[5].asnumpy() - expect5 | |||
| diff5 = abs(output[5].asnumpy() - expect5) | |||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||
| assert np.all(diff5 < error5) | |||
| assert (output[5].shape() == expect5.shape) | |||
| expect6 = np.max(x6, axis=axis6, keepdims=keep_dims6) | |||
| diff6 = output[6].asnumpy() - expect6 | |||
| diff6 = abs(output[6].asnumpy() - expect6) | |||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||
| assert np.all(diff6 < error6) | |||
| assert (output[6].shape() == expect6.shape) | |||
| expect7 = np.max(x7, axis=axis7, keepdims=keep_dims7) | |||
| diff7 = output[7].asnumpy() - expect7 | |||
| diff7 = abs(output[7].asnumpy() - expect7) | |||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | |||
| assert np.all(diff7 < error7) | |||
| expect8 = np.max(x8, axis=np_axis8, keepdims=keep_dims8) | |||
| diff8 = abs(output[8].asnumpy() - expect8) | |||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | |||
| assert np.all(diff8 < error8) | |||
| @@ -77,6 +77,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis13 = (-2, -1) | |||
| keep_dims13 = True | |||
| x14 = np.random.rand(1, 1, 1, 1).astype(np.float32) | |||
| axis14 = () | |||
| np_axis14 = None | |||
| keep_dims14 = True | |||
| context.set_context(device_target='GPU') | |||
| @@ -140,6 +145,10 @@ class ReduceMean(nn.Cell): | |||
| self.axis13 = axis13 | |||
| self.keep_dims13 = keep_dims13 | |||
| self.x14 = Tensor(x14) | |||
| self.axis14 = axis14 | |||
| self.keep_dims14 = keep_dims14 | |||
| @ms_function | |||
| def construct(self): | |||
| return (P.ReduceMean(self.keep_dims0)(self.x0, self.axis0), | |||
| @@ -155,7 +164,8 @@ class ReduceMean(nn.Cell): | |||
| P.ReduceMean(self.keep_dims10)(self.x10, self.axis10), | |||
| P.ReduceMean(self.keep_dims11)(self.x11, self.axis11), | |||
| P.ReduceMean(self.keep_dims12)(self.x12, self.axis12), | |||
| P.ReduceMean(self.keep_dims13)(self.x13, self.axis13)) | |||
| P.ReduceMean(self.keep_dims13)(self.x13, self.axis13), | |||
| P.ReduceMean(self.keep_dims14)(self.x14, self.axis14)) | |||
| @pytest.mark.level0 | |||
| @@ -166,85 +176,91 @@ def test_ReduceMean(): | |||
| output = reduce_mean() | |||
| expect0 = np.mean(x0, axis=axis0, keepdims=keep_dims0) | |||
| diff0 = output[0].asnumpy() - expect0 | |||
| diff0 = abs(output[0].asnumpy() - expect0) | |||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||
| assert np.all(diff0 < error0) | |||
| assert (output[0].shape() == expect0.shape) | |||
| expect1 = np.mean(x1, axis=axis1, keepdims=keep_dims1) | |||
| diff1 = output[1].asnumpy() - expect1 | |||
| diff1 = abs(output[1].asnumpy() - expect1) | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| assert np.all(diff1 < error1) | |||
| assert (output[1].shape() == expect1.shape) | |||
| expect2 = np.mean(x2, axis=axis2, keepdims=keep_dims2) | |||
| diff2 = output[2].asnumpy() - expect2 | |||
| diff2 = abs(output[2].asnumpy() - expect2) | |||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||
| assert np.all(diff2 < error2) | |||
| assert (output[2].shape() == expect2.shape) | |||
| expect3 = np.mean(x3, axis=axis3, keepdims=keep_dims3) | |||
| diff3 = output[3].asnumpy() - expect3 | |||
| diff3 = abs(output[3].asnumpy() - expect3) | |||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||
| assert np.all(diff3 < error3) | |||
| assert (output[3].shape() == expect3.shape) | |||
| expect4 = np.mean(x4, axis=axis4, keepdims=keep_dims4) | |||
| diff4 = output[4].asnumpy() - expect4 | |||
| diff4 = abs(output[4].asnumpy() - expect4) | |||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||
| assert np.all(diff4 < error4) | |||
| assert (output[4].shape() == expect4.shape) | |||
| expect5 = np.mean(x5, axis=axis5, keepdims=keep_dims5) | |||
| diff5 = output[5].asnumpy() - expect5 | |||
| diff5 = abs(output[5].asnumpy() - expect5) | |||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||
| assert np.all(diff5 < error5) | |||
| assert (output[5].shape() == expect5.shape) | |||
| expect6 = np.mean(x6, axis=axis6, keepdims=keep_dims6) | |||
| diff6 = output[6].asnumpy() - expect6 | |||
| diff6 = abs(output[6].asnumpy() - expect6) | |||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||
| assert np.all(diff6 < error6) | |||
| assert (output[6].shape() == expect6.shape) | |||
| expect7 = np.mean(x7, axis=axis7, keepdims=keep_dims7) | |||
| diff7 = output[7].asnumpy() - expect7 | |||
| diff7 = abs(output[7].asnumpy() - expect7) | |||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | |||
| assert np.all(diff7 < error7) | |||
| assert (output[7].shape() == expect7.shape) | |||
| expect8 = np.mean(x8, axis=axis8, keepdims=keep_dims8) | |||
| diff8 = output[8].asnumpy() - expect8 | |||
| diff8 = abs(output[8].asnumpy() - expect8) | |||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | |||
| assert np.all(diff8 < error8) | |||
| assert (output[8].shape() == expect8.shape) | |||
| expect9 = np.mean(x9, axis=axis9, keepdims=keep_dims9) | |||
| diff9 = output[9].asnumpy() - expect9 | |||
| diff9 = abs(output[9].asnumpy() - expect9) | |||
| error9 = np.ones(shape=expect9.shape) * 1.0e-5 | |||
| assert np.all(diff9 < error9) | |||
| assert (output[9].shape() == expect9.shape) | |||
| expect10 = np.mean(x10, axis=axis10, keepdims=keep_dims10) | |||
| diff10 = output[10].asnumpy() - expect10 | |||
| diff10 = abs(output[10].asnumpy() - expect10) | |||
| error10 = np.ones(shape=expect10.shape) * 1.0e-5 | |||
| assert np.all(diff10 < error10) | |||
| assert (output[10].shape() == expect10.shape) | |||
| expect11 = np.mean(x11, axis=axis11, keepdims=keep_dims11) | |||
| diff11 = output[11].asnumpy() - expect11 | |||
| diff11 = abs(output[11].asnumpy() - expect11) | |||
| error11 = np.ones(shape=expect11.shape) * 1.0e-5 | |||
| assert np.all(diff11 < error11) | |||
| assert (output[11].shape() == expect11.shape) | |||
| expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12) | |||
| diff12 = output[12].asnumpy() - expect12 | |||
| expect12 = np.mean(x12, axis=axis12, keepdims=keep_dims12) | |||
| diff12 = abs(output[12].asnumpy() - expect12) | |||
| error12 = np.ones(shape=expect12.shape) * 1.0e-5 | |||
| assert np.all(diff12 < error12) | |||
| assert (output[12].shape() == expect12.shape) | |||
| expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13) | |||
| diff13 = output[13].asnumpy() - expect13 | |||
| expect13 = np.mean(x13, axis=axis13, keepdims=keep_dims13) | |||
| diff13 = abs(output[13].asnumpy() - expect13) | |||
| error13 = np.ones(shape=expect13.shape) * 1.0e-5 | |||
| assert np.all(diff13 < error13) | |||
| assert (output[13].shape() == expect13.shape) | |||
| expect14 = np.mean(x14, axis=np_axis14, keepdims=keep_dims14) | |||
| diff14 = abs(output[14].asnumpy() - expect14) | |||
| error14 = np.ones(shape=expect14.shape) * 1.0e-5 | |||
| assert np.all(diff14 < error14) | |||
| assert (output[14].shape() == expect14.shape) | |||
| @@ -79,6 +79,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| axis13 = (-2, -1) | |||
| keep_dims13 = True | |||
| x14 = np.random.rand(1, 1, 1, 1).astype(np.float32) | |||
| axis14 = () | |||
| np_axis14 = None | |||
| keep_dims14 = True | |||
| context.set_context(device_target='GPU') | |||
| @@ -142,6 +147,10 @@ class ReduceSum(nn.Cell): | |||
| self.axis13 = axis13 | |||
| self.keep_dims13 = keep_dims13 | |||
| self.x14 = Tensor(x14) | |||
| self.axis14 = axis14 | |||
| self.keep_dims14 = keep_dims14 | |||
| @ms_function | |||
| def construct(self): | |||
| return (P.ReduceSum(self.keep_dims0)(self.x0, self.axis0), | |||
| @@ -157,7 +166,8 @@ class ReduceSum(nn.Cell): | |||
| P.ReduceSum(self.keep_dims10)(self.x10, self.axis10), | |||
| P.ReduceSum(self.keep_dims11)(self.x11, self.axis11), | |||
| P.ReduceSum(self.keep_dims12)(self.x12, self.axis12), | |||
| P.ReduceSum(self.keep_dims13)(self.x13, self.axis13)) | |||
| P.ReduceSum(self.keep_dims13)(self.x13, self.axis13), | |||
| P.ReduceSum(self.keep_dims14)(self.x14, self.axis14)) | |||
| @pytest.mark.level0 | |||
| @@ -168,85 +178,91 @@ def test_ReduceSum(): | |||
| output = reduce_sum() | |||
| expect0 = np.sum(x0, axis=axis0, keepdims=keep_dims0) | |||
| diff0 = output[0].asnumpy() - expect0 | |||
| diff0 = abs(output[0].asnumpy() - expect0) | |||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||
| assert np.all(diff0 < error0) | |||
| assert (output[0].shape() == expect0.shape) | |||
| expect1 = np.sum(x1, axis=axis1, keepdims=keep_dims1) | |||
| diff1 = output[1].asnumpy() - expect1 | |||
| diff1 = abs(output[1].asnumpy() - expect1) | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| assert np.all(diff1 < error1) | |||
| assert (output[1].shape() == expect1.shape) | |||
| expect2 = np.sum(x2, axis=axis2, keepdims=keep_dims2) | |||
| diff2 = output[2].asnumpy() - expect2 | |||
| diff2 = abs(output[2].asnumpy() - expect2) | |||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||
| assert np.all(diff2 < error2) | |||
| assert (output[2].shape() == expect2.shape) | |||
| expect3 = np.sum(x3, axis=axis3, keepdims=keep_dims3) | |||
| diff3 = output[3].asnumpy() - expect3 | |||
| diff3 = abs(output[3].asnumpy() - expect3) | |||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||
| assert np.all(diff3 < error3) | |||
| assert (output[3].shape() == expect3.shape) | |||
| expect4 = np.sum(x4, axis=np_axis4, keepdims=keep_dims4) | |||
| diff4 = output[4].asnumpy() - expect4 | |||
| diff4 = abs(output[4].asnumpy() - expect4) | |||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||
| assert np.all(diff4 < error4) | |||
| assert (output[4].shape() == expect4.shape) | |||
| expect5 = np.sum(x5, axis=np_axis5, keepdims=keep_dims5) | |||
| diff5 = output[5].asnumpy() - expect5 | |||
| diff5 = abs(output[5].asnumpy() - expect5) | |||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||
| assert np.all(diff5 < error5) | |||
| assert (output[5].shape() == expect5.shape) | |||
| expect6 = np.sum(x6, axis=axis6, keepdims=keep_dims6) | |||
| diff6 = output[6].asnumpy() - expect6 | |||
| diff6 = abs(output[6].asnumpy() - expect6) | |||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||
| assert np.all(diff6 < error6) | |||
| assert (output[6].shape() == expect6.shape) | |||
| expect7 = np.sum(x7, axis=axis7, keepdims=keep_dims7) | |||
| diff7 = output[7].asnumpy() - expect7 | |||
| diff7 = abs(output[7].asnumpy() - expect7) | |||
| error7 = np.ones(shape=expect7.shape) * 1.0e-5 | |||
| assert np.all(diff7 < error7) | |||
| assert (output[7].shape() == expect7.shape) | |||
| expect8 = np.sum(x8, axis=axis8, keepdims=keep_dims8) | |||
| diff8 = output[8].asnumpy() - expect8 | |||
| diff8 = abs(output[8].asnumpy() - expect8) | |||
| error8 = np.ones(shape=expect8.shape) * 1.0e-5 | |||
| assert np.all(diff8 < error8) | |||
| assert (output[8].shape() == expect8.shape) | |||
| expect9 = np.sum(x9, axis=axis9, keepdims=keep_dims9) | |||
| diff9 = output[9].asnumpy() - expect9 | |||
| diff9 = abs(output[9].asnumpy() - expect9) | |||
| error9 = np.ones(shape=expect9.shape) * 1.0e-5 | |||
| assert np.all(diff9 < error9) | |||
| assert (output[9].shape() == expect9.shape) | |||
| expect10 = np.sum(x10, axis=axis10, keepdims=keep_dims10) | |||
| diff10 = output[10].asnumpy() - expect10 | |||
| diff10 = abs(output[10].asnumpy() - expect10) | |||
| error10 = np.ones(shape=expect10.shape) * 1.0e-5 | |||
| assert np.all(diff10 < error10) | |||
| assert (output[10].shape() == expect10.shape) | |||
| expect11 = np.sum(x11, axis=axis11, keepdims=keep_dims11) | |||
| diff11 = output[11].asnumpy() - expect11 | |||
| diff11 = abs(output[11].asnumpy() - expect11) | |||
| error11 = np.ones(shape=expect11.shape) * 1.0e-5 | |||
| assert np.all(diff11 < error11) | |||
| assert (output[11].shape() == expect11.shape) | |||
| expect12 = np.sum(x12, axis=axis12, keepdims=keep_dims12) | |||
| diff12 = output[12].asnumpy() - expect12 | |||
| diff12 = abs(output[12].asnumpy() - expect12) | |||
| error12 = np.ones(shape=expect12.shape) * 1.0e-5 | |||
| assert np.all(diff12 < error12) | |||
| assert (output[12].shape() == expect12.shape) | |||
| expect13 = np.sum(x13, axis=axis13, keepdims=keep_dims13) | |||
| diff13 = output[13].asnumpy() - expect13 | |||
| diff13 = abs(output[13].asnumpy() - expect13) | |||
| error13 = np.ones(shape=expect13.shape) * 1.0e-5 | |||
| assert np.all(diff13 < error13) | |||
| assert (output[13].shape() == expect13.shape) | |||
| expect14 = np.sum(x14, axis=np_axis14, keepdims=keep_dims14) | |||
| diff14 = abs(output[14].asnumpy() - expect14) | |||
| error14 = np.ones(shape=expect14.shape) * 1.0e-5 | |||
| assert np.all(diff14 < error14) | |||
| assert (output[14].shape() == expect14.shape) | |||