Browse Source

add aicpu ops of Reshape/Flatten/Squeeze/ExpandDims/IsFinite

tags/v0.3.0-alpha
yanzhenxiang2020 5 years ago
parent
commit
691337a6e1
15 changed files with 843 additions and 5 deletions
  1. +0
    -1
      mindspore/ccsrc/kernel/aicpu/aicpu_util.h
  2. +18
    -0
      mindspore/ccsrc/pre_activate/common/helper.cc
  3. +5
    -0
      mindspore/ops/_op_impl/aicpu/__init__.py
  4. +52
    -0
      mindspore/ops/_op_impl/aicpu/expand_dims.py
  5. +48
    -0
      mindspore/ops/_op_impl/aicpu/flatten.py
  6. +52
    -0
      mindspore/ops/_op_impl/aicpu/is_finite.py
  7. +52
    -0
      mindspore/ops/_op_impl/aicpu/reshape.py
  8. +52
    -0
      mindspore/ops/_op_impl/aicpu/squeeze.py
  9. +0
    -3
      mindspore/ops/_op_impl/tbe/__init__.py
  10. +10
    -1
      mindspore/ops/op_info_register.py
  11. +114
    -0
      tests/st/ops/davinci/test_aicpu_ops/test_expand_dims.py
  12. +99
    -0
      tests/st/ops/davinci/test_aicpu_ops/test_flatten.py
  13. +114
    -0
      tests/st/ops/davinci/test_aicpu_ops/test_is_finite.py
  14. +114
    -0
      tests/st/ops/davinci/test_aicpu_ops/test_reshape.py
  15. +113
    -0
      tests/st/ops/davinci/test_aicpu_ops/test_squeeze.py

+ 0
- 1
mindspore/ccsrc/kernel/aicpu/aicpu_util.h View File

@@ -27,7 +27,6 @@ namespace kernel {
constexpr auto kInitDataSetQueue = "InitDataSetQueue";
constexpr auto kInitData = "InitData";
constexpr auto kGetNext = "GetNext";
constexpr auto kDropoutGenMask = "DropoutGenMask";
constexpr auto kPrint = "Print";

constexpr auto kOutputTypes = "output_types";


+ 18
- 0
mindspore/ccsrc/pre_activate/common/helper.cc View File

@@ -340,8 +340,23 @@ bool IsNopNode(const AnfNodePtr &node) {
return true;
}

bool IsAllNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &cnode : execution_order) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsNopNode(cnode)) {
return false;
}
}
return true;
}

void HideNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
if (IsAllNopNode(graph) == true) {
return;
}
auto execution_order = graph->execution_order();
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
std::vector<CNodePtr> new_nodes;
@@ -357,6 +372,9 @@ void HideNopNode(session::KernelGraph *const graph) {

void RemoveNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
if (IsAllNopNode(graph) == true) {
return;
}
bool changed = true;
while (changed) {
changed = false;


+ 5
- 0
mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -17,3 +17,8 @@ from .init_data_set_queue import _init_data_set_queue_aicpu
from .dropout_genmask import _dropout_genmask_aicpu
from .get_next import _get_next_aicpu
from .print_tensor import _print_aicpu
from .is_finite import _is_finite_aicpu
from .reshape import _reshape_aicpu
from .flatten import _flatten_aicpu
from .squeeze import _squeeze_aicpu
from .expand_dims import _expand_dims_aicpu

+ 52
- 0
mindspore/ops/_op_impl/aicpu/expand_dims.py View File

@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

expand_dims_op_info = AiCPURegOp("ExpandDims") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()

@op_info_register(expand_dims_op_info)
def _expand_dims_aicpu():
"""ExpandDims AiCPU register"""
return

+ 48
- 0
mindspore/ops/_op_impl/aicpu/flatten.py View File

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Flatten op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

flatten_op_info = AiCPURegOp("Flatten") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.get_op_info()

@op_info_register(flatten_op_info)
def _flatten_aicpu():
"""Flatten AiCPU register"""
return

+ 52
- 0
mindspore/ops/_op_impl/aicpu/is_finite.py View File

@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""IsFinite op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

is_finite_op_info = AiCPURegOp("IsFinite") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.BOOL_NCHW) \
.get_op_info()

@op_info_register(is_finite_op_info)
def _is_finite_aicpu():
"""IsFinite AiCPU register"""
return

+ 52
- 0
mindspore/ops/_op_impl/aicpu/reshape.py View File

@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

reshape_op_info = AiCPURegOp("Reshape") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()

@op_info_register(reshape_op_info)
def _reshape_aicpu():
"""Rpeshape AiCPU register"""
return

+ 52
- 0
mindspore/ops/_op_impl/aicpu/squeeze.py View File

@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Squeeze op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

squeeze_op_info = AiCPURegOp("Squeeze") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
.dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
.dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
.dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
.dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
.dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
.dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
.dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
.dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
.dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()

@op_info_register(squeeze_op_info)
def _squeeze_aicpu():
"""Squeeze AiCPU register"""
return

+ 0
- 3
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -61,9 +61,6 @@ from .reduce_mean_d import _reduce_mean_d_tbe
from .scatter_nd import _scatter_nd_tbe
from .scatter_nd_d import _scatter_nd_d_tbe
from .reduce_mean import _reduce_mean_tbe
from .reshape import _reshape_tbe
from .expand_dims import _expand_dims_tbe
from .squeeze import _squeeze_tbe
from .tile import _tile_tbe
from .atomic_addr_clean import _atomic_addr_clean_tbe
from .gather_v2 import _gather_v2_tbe


+ 10
- 1
mindspore/ops/op_info_register.py View File

@@ -599,4 +599,13 @@ class DataType:
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")

F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
F64_5HD = ("float64", "NC1HWC0")
F64_FracZ = ("float64", "FracZ")
F64_FracNZ = ("float64", "FRACTAL_NZ")
F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")

+ 114
- 0
tests/st/ops/davinci/test_aicpu_ops/test_expand_dims.py View File

@@ -0,0 +1,114 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.expand_dims = P.ExpandDims()

def construct(self, tensor, dim):
return self.expand_dims(tensor, dim)


def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x), -1)
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.expand_dims(x, -1)))

+ 99
- 0
tests/st/ops/davinci/test_aicpu_ops/test_flatten.py View File

@@ -0,0 +1,99 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flatten = P.Flatten()

def construct(self, tensor):
return self.flatten(tensor)


def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.flatten()))

+ 114
- 0
tests/st/ops/davinci/test_aicpu_ops/test_is_finite.py View File

@@ -0,0 +1,114 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.isfinite = P.IsFinite()

def construct(self, tensor):
return self.isfinite(tensor)


def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.isfinite(x)))

+ 114
- 0
tests/st/ops/davinci/test_aicpu_ops/test_reshape.py View File

@@ -0,0 +1,114 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.reshape = P.Reshape()

def construct(self, tensor):
return self.reshape(tensor, (4,4))


def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == np.reshape(x, (4,4))))

+ 113
- 0
tests/st/ops/davinci/test_aicpu_ops/test_squeeze.py View File

@@ -0,0 +1,113 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.squeeze = P.Squeeze()

def construct(self, tensor):
return self.squeeze(tensor)


def test_net_bool():
x = np.random.randn(1, 16, 1, 1).astype(np.bool)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_int8():
x = np.random.randn(1, 16, 1, 1).astype(np.int8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))
def test_net_uint8():
x = np.random.randn(1, 16, 1, 1).astype(np.uint8)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_int16():
x = np.random.randn(1, 16, 1, 1).astype(np.int16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_uint16():
x = np.random.randn(1, 16, 1, 1).astype(np.uint16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_int32():
x = np.random.randn(1, 16, 1, 1).astype(np.int32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_uint32():
x = np.random.randn(1, 16, 1, 1).astype(np.uint32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_int64():
x = np.random.randn(1, 16, 1, 1).astype(np.int64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_uint64():
x = np.random.randn(1, 16, 1, 1).astype(np.uint64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_float16():
x = np.random.randn(1, 16, 1, 1).astype(np.float16)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_float32():
x = np.random.randn(1, 16, 1, 1).astype(np.float32)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

def test_net_float64():
x = np.random.randn(1, 16, 1, 1).astype(np.float64)
net = Net()
output = net(Tensor(x))
print(output.asnumpy())
assert(np.all(output.asnumpy() == x.squeeze()))

Loading…
Cancel
Save