Browse Source

!12399 Add type support to Squeeze gpu op

From: @peilin-wang
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
feb07198e7
2 changed files with 81 additions and 50 deletions
  1. +7
    -2
      mindspore/ops/_op_impl/akg/gpu/squeeze.py
  2. +74
    -48
      tests/st/ops/gpu/test_squeeze_op.py

+ 7
- 2
mindspore/ops/_op_impl/akg/gpu/squeeze.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,14 @@ squeeze_op_info = AkgGpuRegOp("Squeeze") \
.attr("axis", "optional", "listInt") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info()



+ 74
- 48
tests/st/ops/gpu/test_squeeze_op.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,67 +13,93 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class Net(nn.Cell):
class SqueezeNet(nn.Cell):
def __init__(self):
super(Net, self).__init__()
super(SqueezeNet, self).__init__()
self.squeeze = P.Squeeze()

def construct(self, tensor):
return self.squeeze(tensor)


def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())


def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())


def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())


def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())

def squeeze(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
np.random.seed(0)
x = np.random.randn(1, 16, 1, 1).astype(nptype)
net = SqueezeNet()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())


def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_bool():
squeeze(np.bool)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint8():
squeeze(np.uint8)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint16():
squeeze(np.uint16)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint32():
squeeze(np.uint32)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int8():
squeeze(np.int8)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int16():
squeeze(np.int16)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int32():
squeeze(np.int32)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int64():
squeeze(np.int64)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float16():
squeeze(np.float16)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float32():
squeeze(np.float32)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float64():
squeeze(np.float64)

Loading…
Cancel
Save