From: @wangyanling10 Reviewed-by: Signed-off-by:pull/13693/MERGE
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2021Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -102,8 +102,10 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| ret = LaunchKernel<float>(inputs, outputs); | ret = LaunchKernel<float>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeBool) { | } else if (dtype_ == kNumberTypeBool) { | ||||
| ret = LaunchKernel<bool>(inputs, outputs); | ret = LaunchKernel<bool>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat64) { | |||||
| ret = LaunchKernel<double>(inputs, outputs); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Slice op only support input_x int32 and float32"; | |||||
| MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -55,9 +55,14 @@ class SliceCPUKernel : public CPUKernel { | |||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceCPUKernel); | SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel); | MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceCPUKernel); | SliceCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| @@ -86,8 +86,10 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| ret = LaunchKernel<float>(inputs, outputs); | ret = LaunchKernel<float>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeBool) { | } else if (dtype_ == kNumberTypeBool) { | ||||
| ret = LaunchKernel<bool>(inputs, outputs); | ret = LaunchKernel<bool>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat64) { | |||||
| ret = LaunchKernel<double>(inputs, outputs); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Slice op only support input_x int32 and float32"; | |||||
| MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -60,10 +60,23 @@ MS_REG_CPU_KERNEL( | |||||
| SliceGrad, | SliceGrad, | ||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceGradCPUKernel); | SliceGradCPUKernel); | ||||
| MS_REG_CPU_KERNEL( | |||||
| SliceGrad, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceGradCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| SliceGradCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| SliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||||
| SliceGradCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SliceGradCPUKernel); | SliceGradCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| SliceGradCPUKernel); | SliceGradCPUKernel); | ||||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SliceGradCPUKernel); | |||||
| MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | ||||
| SliceGradCPUKernel); | SliceGradCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -296,7 +296,7 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): | |||||
| TypeError: If `quant_delay` is not greater than or equal to 0. | TypeError: If `quant_delay` is not greater than or equal to 0. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| ``Ascend`` ``GPU`` | |||||
| Examples: | Examples: | ||||
| >>> fake_quant = nn.FakeQuantWithMinMaxObserver() | >>> fake_quant = nn.FakeQuantWithMinMaxObserver() | ||||
| @@ -451,7 +451,7 @@ class Conv2dBnFoldQuantOneConv(Cell): | |||||
| ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'. | ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| ``Ascend`` ``GPU`` | |||||
| Examples: | Examples: | ||||
| >>> qconfig = compression.quant.create_quant_config() | >>> qconfig = compression.quant.create_quant_config() | ||||
| @@ -4633,7 +4633,7 @@ class BroadcastTo(PrimitiveWithInfer): | |||||
| target shape is in an invalid location. | target shape is in an invalid location. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> shape = (2, 3) | >>> shape = (2, 3) | ||||
| @@ -78,6 +78,21 @@ def test_slice_grad2(): | |||||
| [[0., 0.], [8., 9.], [10., 11.]]] | [[0., 0.], [8., 9.], [10., 11.]]] | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| def test_slice_grad3(): | |||||
| x = Tensor(np.array([[[1.0, 3.5, 5.8], [2.5, 4, 1]], [[3.5, 15.3, 3.1], [2.2, 4.0, 1.1]], | |||||
| [[43.4, 1.1, 12.1], [2.4, 6.5, 6.3]]]), mstype.float64) | |||||
| dy = Tensor(np.array([[[3.1, 1.1, 2.2]], [[4.4, 1.2, 4.2]]]), mstype.float64) | |||||
| slicegrad = SliceGrad() | |||||
| output = slicegrad(dy, x) | |||||
| expect = [[[0., 0., 0.], | |||||
| [3.1, 1.1, 2.2]], | |||||
| [[0., 0., 0.], | |||||
| [4.4, 1.2, 4.2]], | |||||
| [[0., 0., 0.], | |||||
| [0., 0., 0.]]] | |||||
| print("output:\n", output) | |||||
| assert (output.asnumpy() == expect).all() | |||||
| class StridedSliceGrad(nn.Cell): | class StridedSliceGrad(nn.Cell): | ||||
| def __init__(self, x, begin, end, stride): | def __init__(self, x, begin, end, stride): | ||||
| super(StridedSliceGrad, self).__init__() | super(StridedSliceGrad, self).__init__() | ||||
| @@ -69,6 +69,14 @@ def test_slice2(): | |||||
| output = slice_op(x) | output = slice_op(x) | ||||
| assert (output.asnumpy() == expect).all() | assert (output.asnumpy() == expect).all() | ||||
| def test_slice_float64(): | |||||
| data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], | |||||
| [[3, 3, 3], [4, 4, 4]], | |||||
| [[5, 5, 5], [6, 6, 6]]]).astype(np.float64)) | |||||
| slice_op = P.Slice() | |||||
| output = slice_op(data, (1, 0, 0), (1, 1, 3)) | |||||
| expect = [[[3.0, 3.0, 3.0]]] | |||||
| assert (output.asnumpy() == expect).all() | |||||
| class Slice3(nn.Cell): | class Slice3(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||