diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h index 162c40d2c3..f4afe298f5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -62,6 +62,14 @@ MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOut ReshapeCPUKernel); MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); + +MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Squeeze, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), ReshapeCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/cpu/test_squeeze_op.py b/tests/st/ops/cpu/test_squeeze_op.py new file mode 100644 index 0000000000..883579e4a0 --- /dev/null +++ b/tests/st/ops/cpu/test_squeeze_op.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell +import mindspore.ops as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class SqueezeNet(Cell): + def __init__(self): + super(SqueezeNet, self).__init__() + self.squeeze = P.Squeeze() + + def construct(self, x): + return self.squeeze(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_squeeze_shape_float32(): + x = np.ones(shape=[1, 2, 1, 1, 8, 3, 1]).astype(np.float32) + expect = np.ones(shape=[2, 8, 3]).astype(np.float32) + net = SqueezeNet() + result = net(Tensor(x)) + assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_squeeze_shape_int32(): + x = np.array([[7], [11]]).astype(np.int32) + expect = np.array([7, 11]).astype(np.int32) + net = SqueezeNet() + result = net(Tensor(x)) + assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_squeeze_shape_bool(): + x = np.array([[True], [False]]).astype(np.bool_) + expect = np.array([True, False]).astype(np.bool_) + net = SqueezeNet() + result = net(Tensor(x)) + assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8, equal_nan=True)