diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc index 3ad2a6a264..29c9357f32 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,16 +18,18 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + SliceGpuFwdKernel, double) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SliceGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGpuFwdKernel, int) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), SliceGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - SliceGpuFwdKernel, int16_t) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SliceGpuFwdKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + SliceGpuFwdKernel, int16_t) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SliceGpuFwdKernel, uchar) MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h index 002a623de8..63f39986db 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_ #include #include @@ -134,4 +134,4 @@ class SliceGpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc index 5b81b02bc2..68b60123b7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,17 +18,21 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE( + SliceGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + SliceGradGpuKernel, double) MS_REG_GPU_KERNEL_ONE( SliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) MS_REG_GPU_KERNEL_ONE( SliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), SliceGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGradGpuKernel, int) MS_REG_GPU_KERNEL_ONE( SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SliceGradGpuKernel, int16_t) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h index 3d9fd28a41..6e2e94a1cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_ #include #include @@ -143,4 +143,4 @@ class SliceGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc index 2b6848fd36..f7a93f23b2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.cc @@ -18,6 +18,8 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + StridedSliceGpuKernel, double) MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), StridedSliceGpuKernel, float) MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h index 3aa3f4df06..7899a4af99 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_ #include #include @@ -210,4 +210,4 @@ class StridedSliceGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc index ceaccaa26f..b135d5ddfa 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.cc @@ -18,6 +18,8 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + StridedSliceGradGpuKernel, double) MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), StridedSliceGradGpuKernel, float) MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h index 5c1e6ef425..b98149b70a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_ #include #include @@ -211,4 +211,4 @@ class StridedSliceGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu index 9b733ebf86..2fbf2c0a60 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -159,57 +159,57 @@ void StridedSliceGrad(const std::vector &dy_shape, const std::vector(const size_t input_size, const float *dy, const std::vector in_shape, - const std::vector begin, const std::vector size, float *output, - cudaStream_t cuda_stream); - template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, - const std::vector begin, const std::vector size, half *output, - cudaStream_t cuda_stream); - +template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const int64_t *input, int64_t *output, + cudaStream_t stream); template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const int *dy, const std::vector in_shape, - const std::vector begin, const std::vector size, int *output, - cudaStream_t cuda_stream); - template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const short *input, short *output, // NOLINT cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const short *dy, // NOLINT - const std::vector in_shape, const std::vector begin, - const std::vector size, - short *output, // NOLINT - cudaStream_t cuda_stream); - template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const unsigned char *dy, - const std::vector in_shape, const std::vector begin, - const std::vector size, unsigned char *output, - cudaStream_t cuda_stream); - template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, - const size_t d3, const size_t d4, const int64_t *input, int64_t *output, - cudaStream_t stream); + const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream); + +template void CalSliceGrad(const size_t input_size, const double *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, double *output, + cudaStream_t cuda_stream); +template void CalSliceGrad(const size_t input_size, const float *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, float *output, + cudaStream_t cuda_stream); +template void CalSliceGrad(const size_t input_size, const half *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, half *output, + cudaStream_t cuda_stream); template void CalSliceGrad(const size_t input_size, const int64_t *dy, const std::vector in_shape, const std::vector begin, const std::vector size, int64_t *output, cudaStream_t cuda_stream); - -template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, - const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, - const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const int *dy, const std::vector in_shape, + const std::vector begin, const std::vector size, int *output, + cudaStream_t cuda_stream); +template void CalSliceGrad(const size_t input_size, const short *dy, // NOLINT + const std::vector in_shape, const std::vector begin, + const std::vector size, short *output, // NOLINT + cudaStream_t cuda_stream); +template void CalSliceGrad(const size_t input_size, const unsigned char *dy, + const std::vector in_shape, const std::vector begin, + const std::vector size, unsigned char *output, + cudaStream_t cuda_stream); template void CalSliceGrad(const size_t input_size, const bool *dy, const std::vector in_shape, const std::vector begin, const std::vector size, bool *output, cudaStream_t cuda_stream); @@ -232,10 +232,15 @@ template void FillDeviceArray(const size_t input_size, unsigned c cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream); +template void FillDeviceArray(const size_t input_size, double *addr, const float value, + cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, const std::vector &strides, const std::vector &output_shape, const bool *input, bool *output, cudaStream_t cuda_stream); +template void StridedSlice(const std::vector &input_shape, const std::vector &begin, + const std::vector &strides, const std::vector &output_shape, + const double *input, double *output, cudaStream_t cuda_stream); template void StridedSlice(const std::vector &input_shape, const std::vector &begin, const std::vector &strides, const std::vector &output_shape, const float *input, float *output, cudaStream_t cuda_stream); @@ -270,6 +275,9 @@ template void StridedSlice(const std::vector &input_shape, const std::ve template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, const std::vector &dx_shape, const bool *dy, bool *dx, cudaStream_t cuda_stream); +template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, + const std::vector &strides, const std::vector &dx_shape, + const double *dy, double *dx, cudaStream_t cuda_stream); template void StridedSliceGrad(const std::vector &dy_shape, const std::vector &begin, const std::vector &strides, const std::vector &dx_shape, const float *dy, float *dx, cudaStream_t cuda_stream); diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 1bc88762ec..e25af8e2bb 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1859,7 +1859,7 @@ class StridedSliceGrad(PrimitiveWithInfer): ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0): - """Initialize StrideSliceGrad""" + """Initialize StridedSliceGrad""" validator.check_value_type('begin_mask', begin_mask, [int], self.name) validator.check_value_type('end_mask', end_mask, [int], self.name) validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b14a4d8977..896a188b57 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2792,7 +2792,7 @@ class StridedSlice(PrimitiveWithInfer): ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0): - """Initialize StrideSlice""" + """Initialize StridedSlice""" self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) validator.check_non_negative_int(begin_mask, 'begin_mask', self.name) validator.check_non_negative_int(end_mask, 'end_mask', self.name) diff --git a/tests/st/ops/gpu/test_slice.py b/tests/st/ops/gpu/test_slice.py index 4f63bb83a6..aef14084b5 100644 --- a/tests/st/ops/gpu/test_slice.py +++ b/tests/st/ops/gpu/test_slice.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -55,6 +55,9 @@ class SliceNet(nn.Cell): return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224)) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard def test_slice_4d(): x_np = np.random.randn(32, 24, 224, 224).astype(np.float32) output_np = x_np[:, 11:18, :, :] @@ -64,3 +67,18 @@ def test_slice_4d(): output_ms = net(x_ms) assert (output_ms.asnumpy() == output_np).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_slice_float64(): + x = Tensor( + np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]).astype(np.float64)) + expect = np.array([[[2., -2., 2.]], + [[4., -4., 4.]]]).astype(np.float64) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + slice_op = Slice() + output = slice_op(x) + assert (output.asnumpy() == expect).all() diff --git a/tests/st/ops/gpu/test_slice_grad.py b/tests/st/ops/gpu/test_slice_grad.py index d6cc84d6ef..8b1cd982ef 100644 --- a/tests/st/ops/gpu/test_slice_grad.py +++ b/tests/st/ops/gpu/test_slice_grad.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,5 +50,21 @@ def test_slice(): [4., 1., 4.]], [[0., 0., 0.], [0., 0., 0.]]] - print(output) + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_slice_float64(): + x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]).astype(np.float64)) + dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]).astype(np.float64)) + slicegrad = SliceGrad() + output = slicegrad(dy, x) + expect = np.array([[[0., 0., 0.], + [3., 1., 2.]], + [[0., 0., 0.], + [4., 1., 4.]], + [[0., 0., 0.], + [0., 0., 0.]]]).astype(np.float64) assert (output.asnumpy() == expect).all() diff --git a/tests/st/ops/gpu/test_stridedslice_grad_op.py b/tests/st/ops/gpu/test_stridedslice_grad_op.py index 093333440b..91760aba18 100644 --- a/tests/st/ops/gpu/test_stridedslice_grad_op.py +++ b/tests/st/ops/gpu/test_stridedslice_grad_op.py @@ -239,6 +239,12 @@ def strided_slice_grad(nptype): [0., 0., 0., 0., 0.]]]]]]]).astype(nptype) assert np.allclose(dx[0].asnumpy(), expect) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_grad_float64(): + strided_slice_grad(np.float64) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard diff --git a/tests/st/ops/gpu/test_stridedslice_op.py b/tests/st/ops/gpu/test_stridedslice_op.py index ecbb495a89..dc31fbc94d 100644 --- a/tests/st/ops/gpu/test_stridedslice_op.py +++ b/tests/st/ops/gpu/test_stridedslice_op.py @@ -102,6 +102,12 @@ def strided_slice(nptype): [[[2122.]]]]]]]).astype(nptype) assert np.allclose(y.asnumpy(), expect) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_strided_slice_float64(): + strided_slice(np.float64) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard