
!12399 Add type support to Squeeze gpu op

From: @peilin-wang
Reviewed-by: @robingrosman, @tom__chen
Signed-off-by: @robingrosman
Tags: v1.2.0-rc1
Committed by mindspore-ci-bot (Gitee), 4 years ago
Commit: feb07198e7
2 changed files with 81 additions and 50 deletions:
  1. mindspore/ops/_op_impl/akg/gpu/squeeze.py (+7, -2)
  2. tests/st/ops/gpu/test_squeeze_op.py (+74, -48)

mindspore/ops/_op_impl/akg/gpu/squeeze.py (+7, -2)

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,9 +22,14 @@ squeeze_op_info = AkgGpuRegOp("Squeeze") \
     .attr("axis", "optional", "listInt") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
     .dtype_format(DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
     .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
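
For context, a minimal usage sketch (not part of this change) exercising the GPU Squeeze primitive with int64, one of the dtypes registered above; the class name, input shape, and dtype choice are illustrative only:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class SqueezeNet(nn.Cell):
    def __init__(self):
        super(SqueezeNet, self).__init__()
        self.squeeze = P.Squeeze()  # no axis argument: drop every size-1 dimension

    def construct(self, tensor):
        return self.squeeze(tensor)

# int64 is one of the dtypes this change registers; the shape is arbitrary.
x = np.random.randint(0, 10, (1, 16, 1, 1)).astype(np.int64)
output = SqueezeNet()(Tensor(x))
assert output.asnumpy().shape == (16,)  # all size-1 axes removed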




tests/st/ops/gpu/test_squeeze_op.py (+74, -48)

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,67 +13,93 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest
 
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
 
-
-class Net(nn.Cell):
+class SqueezeNet(nn.Cell):
     def __init__(self):
-        super(Net, self).__init__()
+        super(SqueezeNet, self).__init__()
         self.squeeze = P.Squeeze()
 
     def construct(self, tensor):
         return self.squeeze(tensor)
 
 
-def test_net_bool():
-    x = np.random.randn(1, 16, 1, 1).astype(np.bool)
-    net = Net()
-    output = net(Tensor(x))
-    print(output.asnumpy())
-    assert np.all(output.asnumpy() == x.squeeze())
-
-
-def test_net_uint8():
-    x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
-    net = Net()
-    output = net(Tensor(x))
-    print(output.asnumpy())
-    assert np.all(output.asnumpy() == x.squeeze())
-
-
-def test_net_int16():
-    x = np.random.randn(1, 16, 1, 1).astype(np.int16)
-    net = Net()
-    output = net(Tensor(x))
-    print(output.asnumpy())
-    assert np.all(output.asnumpy() == x.squeeze())
-
-
-def test_net_int32():
-    x = np.random.randn(1, 16, 1, 1).astype(np.int32)
-    net = Net()
-    output = net(Tensor(x))
-    print(output.asnumpy())
-    assert np.all(output.asnumpy() == x.squeeze())
-
+def squeeze(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
-def test_net_float16():
-    x = np.random.randn(1, 16, 1, 1).astype(np.float16)
-    net = Net()
+    np.random.seed(0)
+    x = np.random.randn(1, 16, 1, 1).astype(nptype)
+    net = SqueezeNet()
     output = net(Tensor(x))
-    print(output.asnumpy())
     assert np.all(output.asnumpy() == x.squeeze())
 
 
-def test_net_float32():
-    x = np.random.randn(1, 16, 1, 1).astype(np.float32)
-    net = Net()
-    output = net(Tensor(x))
-    print(output.asnumpy())
-    assert np.all(output.asnumpy() == x.squeeze())
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_bool():
+    squeeze(np.bool)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_uint8():
+    squeeze(np.uint8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_uint16():
+    squeeze(np.uint16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_uint32():
+    squeeze(np.uint32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_int8():
+    squeeze(np.int8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_int16():
+    squeeze(np.int16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_int32():
+    squeeze(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_int64():
+    squeeze(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_float16():
+    squeeze(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_float32():
+    squeeze(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_squeeze_float64():
+    squeeze(np.float64)
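
The eleven per-dtype test functions are written out explicitly so each carries its own CI markers; for comparison only, a parametrized equivalent (an illustrative alternative, not part of this PR; it reuses the squeeze() helper introduced above) would look like:

import numpy as np
import pytest

# Illustrative alternative only: one parametrized case instead of eleven
# near-identical functions; squeeze() is the helper added in the diff above.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("nptype", [np.bool_, np.uint8, np.uint16, np.uint32,
                                    np.int8, np.int16, np.int32, np.int64,
                                    np.float16, np.float32, np.float64])
def test_squeeze_all_dtypes(nptype):
    squeeze(nptype)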
