@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,16 +18,18 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      SliceGpuFwdKernel, double)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       SliceGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      SliceGpuFwdKernel, int)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       SliceGpuFwdKernel, half)
-MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
-                      SliceGpuFwdKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                       SliceGpuFwdKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      SliceGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      SliceGpuFwdKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                       SliceGpuFwdKernel, uchar)
 MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                       SliceGpuFwdKernel, bool)
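
With the double registration above, Slice can run on the GPU backend for float64 inputs. A minimal sketch of how the new path is exercised from Python, assuming a MindSpore build that includes this change; the tensor shape and begin/size values are illustrative, not taken from this PR:

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float64))
# Slice takes a begin offset and a size per dimension.
out = P.Slice()(x, (0, 1, 0), (2, 2, 4))
assert out.shape == (2, 2, 4)
assert out.dtype == mindspore.float64  # dtype is preserved end to end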
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
 #include <vector>
 #include <utility>
@@ -134,4 +134,4 @@ class SliceGpuFwdKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,17 +18,21 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SliceGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  SliceGradGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   SliceGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   SliceGradGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-  SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  SliceGradGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   SliceGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   SliceGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  SliceGradGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
   SliceGradGpuKernel, int16_t)
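
SliceGrad is normally reached through autodiff rather than called directly. A hedged sketch of driving the new float64 backward path, mirroring the Grad-wrapper pattern used in the GPU tests below; the shapes and slice window here are illustrative:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class SliceNet(nn.Cell):
    def __init__(self):
        super(SliceNet, self).__init__()
        self.slice = P.Slice()

    def construct(self, x):
        return self.slice(x, (0, 1, 0), (2, 1, 3))

class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, x, dy):
        return self.grad(self.network)(x, dy)

x = Tensor(np.ones((3, 2, 3)).astype(np.float64))
dy = Tensor(np.ones((2, 1, 3)).astype(np.float64))
dx = Grad(SliceNet())(x, dy)
# dx[0] has x's shape; entries outside the sliced window are zero.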
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
 #include <vector>
 #include <algorithm>
@@ -143,4 +143,4 @@ class SliceGradGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SLICE_GRAD_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
@@ -18,6 +18,8 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      StridedSliceGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       StridedSliceGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       StridedSliceGpuKernel, half)
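
As with Slice, the double registration above makes float64 StridedSlice available on GPU. A quick sketch under the same build assumption; the indices are illustrative:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.arange(60).reshape(3, 4, 5).astype(np.float64))
# begin, end, strides -- same semantics as x[1:3, 0:4:2, 0:5:2]
y = P.StridedSlice()(x, (1, 0, 0), (3, 4, 5), (1, 2, 2))
assert y.shape == (2, 2, 3)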
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_
 #include <vector>
 #include <bitset>
@@ -210,4 +210,4 @@ class StridedSliceGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GPU_KERNEL_H_
@@ -18,6 +18,8 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      StridedSliceGradGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       StridedSliceGradGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       StridedSliceGradGpuKernel, half)
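
StridedSliceGrad, like SliceGrad, is produced by autodiff. A hedged sketch under the same assumptions as the earlier examples:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

class StridedSliceNet(nn.Cell):
    def __init__(self):
        super(StridedSliceNet, self).__init__()
        self.strided_slice = P.StridedSlice()

    def construct(self, x):
        return self.strided_slice(x, (1, 0, 0), (3, 4, 5), (1, 2, 2))

grad = C.GradOperation(get_all=True, sens_param=True)
x = Tensor(np.arange(60).reshape(3, 4, 5).astype(np.float64))
dy = Tensor(np.ones((2, 2, 3)).astype(np.float64))
dx = grad(StridedSliceNet())(x, dy)
# dx[0] scatters dy back into x's shape; positions not visited by the
# stride pattern receive zero.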
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_
 #include <vector>
 #include <bitset>
@@ -211,4 +211,4 @@ class StridedSliceGradGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_STRIDED_SLICE_GRAD_GPU_KERNEL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -159,57 +159,57 @@ void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int
                       dy, dx);
 }
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const double *input, double *output, cudaStream_t stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const float *input, float *output, cudaStream_t stream);
-template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
-                                  const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
-                                  cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const half *input, half *output, cudaStream_t stream);
-template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
-                                 const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
-                                 cudaStream_t cuda_stream);
+template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
+                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
+                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
+                            cudaStream_t stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const int *input, int *output, cudaStream_t stream);
-template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
-                                const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
-                                cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const short *input, short *output,  // NOLINT
                             cudaStream_t stream);
-template void CalSliceGrad<short>(const size_t input_size, const short *dy,  // NOLINT
-                                  const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
-                                  const std::vector<int64_t> size,
-                                  short *output,  // NOLINT
-                                  cudaStream_t cuda_stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const unsigned char *input, unsigned char *output,
                             cudaStream_t stream);
-template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
-                                          const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
-                                          const std::vector<int64_t> size, unsigned char *output,
-                                          cudaStream_t cuda_stream);
-template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
-                            const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
-                            const size_t d3, const size_t d4, const int64_t *input, int64_t *output,
-                            cudaStream_t stream);
 template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1,
                             const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2,
                             const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream);
+template void CalSliceGrad<double>(const size_t input_size, const double *dy, const std::vector<size_t> in_shape,
+                                   const std::vector<int64_t> begin, const std::vector<int64_t> size, double *output,
+                                   cudaStream_t cuda_stream);
+template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<size_t> in_shape,
+                                  const std::vector<int64_t> begin, const std::vector<int64_t> size, float *output,
+                                  cudaStream_t cuda_stream);
+template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<size_t> in_shape,
+                                 const std::vector<int64_t> begin, const std::vector<int64_t> size, half *output,
+                                 cudaStream_t cuda_stream);
+template void CalSliceGrad<int64_t>(const size_t input_size, const int64_t *dy, const std::vector<size_t> in_shape,
+                                    const std::vector<int64_t> begin, const std::vector<int64_t> size, int64_t *output,
+                                    cudaStream_t cuda_stream);
+template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<size_t> in_shape,
+                                const std::vector<int64_t> begin, const std::vector<int64_t> size, int *output,
+                                cudaStream_t cuda_stream);
+template void CalSliceGrad<short>(const size_t input_size, const short *dy,  // NOLINT
+                                  const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
+                                  const std::vector<int64_t> size, short *output,  // NOLINT
+                                  cudaStream_t cuda_stream);
+template void CalSliceGrad<unsigned char>(const size_t input_size, const unsigned char *dy,
+                                          const std::vector<size_t> in_shape, const std::vector<int64_t> begin,
+                                          const std::vector<int64_t> size, unsigned char *output,
+                                          cudaStream_t cuda_stream);
 template void CalSliceGrad<bool>(const size_t input_size, const bool *dy, const std::vector<size_t> in_shape,
                                  const std::vector<int64_t> begin, const std::vector<int64_t> size, bool *output,
                                  cudaStream_t cuda_stream);
@@ -232,10 +232,15 @@ template void FillDeviceArray<unsigned char>(const size_t input_size, unsigned c
                                              cudaStream_t cuda_stream);
 template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
 template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
+template void FillDeviceArray<double>(const size_t input_size, double *addr, const float value,
+                                      cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
                            const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
                            const bool *input, bool *output, cudaStream_t cuda_stream);
+template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
+                           const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
+                           const double *input, double *output, cudaStream_t cuda_stream);
 template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
                            const std::vector<int64_t> &strides, const std::vector<size_t> &output_shape,
                            const float *input, float *output, cudaStream_t cuda_stream);
@@ -270,6 +275,9 @@ template void StridedSlice(const std::vector<size_t> &input_shape, const std::ve
 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
                                const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape, const bool *dy,
                                bool *dx, cudaStream_t cuda_stream);
+template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
+                               const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
+                               const double *dy, double *dx, cudaStream_t cuda_stream);
 template void StridedSliceGrad(const std::vector<size_t> &dy_shape, const std::vector<int64_t> &begin,
                                const std::vector<int64_t> &strides, const std::vector<size_t> &dx_shape,
                                const float *dy, float *dx, cudaStream_t cuda_stream);
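
The explicit instantiations above must cover every type the kernels are registered for; because these templates are defined in a separately compiled .cu file, a missing instantiation (say, the double line) would surface as a link error rather than at registration time. A hedged end-to-end sweep over the registered dtypes, again assuming a build with this change:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
for nptype in (np.float64, np.float32, np.float16, np.int64,
               np.int32, np.int16, np.uint8, np.bool_):
    x = Tensor(np.ones((2, 3, 4)).astype(nptype))
    out = P.Slice()(x, (0, 0, 0), (1, 2, 2))  # one registration per dtype
    assert out.shape == (1, 2, 2)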
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1859,7 +1859,7 @@ class StridedSliceGrad(PrimitiveWithInfer):
                  ellipsis_mask=0,
                  new_axis_mask=0,
                  shrink_axis_mask=0):
-        """Initialize StrideSliceGrad"""
+        """Initialize StridedSliceGrad"""
         validator.check_value_type('begin_mask', begin_mask, [int], self.name)
         validator.check_value_type('end_mask', end_mask, [int], self.name)
         validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
@@ -2792,7 +2792,7 @@ class StridedSlice(PrimitiveWithInfer):
                  ellipsis_mask=0,
                  new_axis_mask=0,
                  shrink_axis_mask=0):
-        """Initialize StrideSlice"""
+        """Initialize StridedSlice"""
         self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
         validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
         validator.check_non_negative_int(end_mask, 'end_mask', self.name)
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -55,6 +55,9 @@ class SliceNet(nn.Cell):
         return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224))
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_slice_4d():
     x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
     output_np = x_np[:, 11:18, :, :]
@@ -64,3 +67,18 @@ def test_slice_4d():
     output_ms = net(x_ms)
 
     assert (output_ms.asnumpy() == output_np).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_slice_float64():
+    x = Tensor(
+        np.array([[[1, -1, 1], [2, -2, 2]], [[3, -3, 3], [4, -4, 4]], [[5, -5, 5], [6, -6, 6]]]).astype(np.float64))
+    expect = np.array([[[2., -2., 2.]],
+                       [[4., -4., 4.]]]).astype(np.float64)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    slice_op = Slice()
+    output = slice_op(x)
+    assert (output.asnumpy() == expect).all()
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,5 +50,21 @@ def test_slice():
                [4., 1., 4.]],
               [[0., 0., 0.],
                [0., 0., 0.]]]
     print(output)
     assert (output.asnumpy() == expect).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_slice_float64():
+    x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]).astype(np.float64))
+    dy = Tensor(np.array([[[3., 1., 2.]], [[4., 1., 4.]]]).astype(np.float64))
+    slicegrad = SliceGrad()
+    output = slicegrad(dy, x)
+    expect = np.array([[[0., 0., 0.],
+                        [3., 1., 2.]],
+                       [[0., 0., 0.],
+                        [4., 1., 4.]],
+                       [[0., 0., 0.],
+                        [0., 0., 0.]]]).astype(np.float64)
+    assert (output.asnumpy() == expect).all()
@@ -239,6 +239,12 @@ def strided_slice_grad(nptype):
                          [0., 0., 0., 0., 0.]]]]]]]).astype(nptype)
     assert np.allclose(dx[0].asnumpy(), expect)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_grad_float64():
+    strided_slice_grad(np.float64)
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -102,6 +102,12 @@ def strided_slice(nptype):
                  [[[2122.]]]]]]]).astype(nptype)
     assert np.allclose(y.asnumpy(), expect)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_strided_slice_float64():
+    strided_slice(np.float64)
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard