|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- @File : test_parse_method.py
- @Author:
- @Date : 2019-06-27
- @Desc : test parse the object's method
- """
- import logging
- import numpy as np
- import mindspore.nn as nn
- from dataclasses import dataclass
- from mindspore.common.api import ms_function
- from mindspore.common.tensor import Tensor
- from mindspore.ops.composite import core
- from mindspore._extends.parse.standard_method import ms_len
- from ..ut_filter import non_graph_engine
- import pytest
-
- from mindspore import context
-
-
- def setup_module(module):
- context.set_context(mode=context.PYNATIVE_MODE)
-
-
- log = logging.getLogger("test")
- log.setLevel(level=logging.ERROR)
-
-
- @ms_function
- def default_parameter_f(x, y=3):
- """ default_parameter_f """
- z = x + y
- return z
-
-
- # Test case: test parse fn that use default parameter
- def test_parse_defalut_parameter_case1():
- """ Test default parameter function call """
- log.debug("begin test_parse_defalut_parameter_case1")
- ret = default_parameter_f(2)
- log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
-
-
- def get_val_fn(x):
- """ get_val_fn """
- ret = x + 3
- return ret
-
-
- # Test case: test bool not
- @ms_function
- def bool_exp(x, y):
- """ bool_exp """
- return not x > y
-
-
- def test_bool_exp():
- """ test_bool_exp """
- bool_exp(1, 2)
-
-
- # Test case: use the variable parameter for @mindspore
- @ms_function
- def var_parameter_f(x, *args):
- """ var_parameter_f """
- z = x + args[0] + args[1] + args[2]
- return z
-
-
- def test_var_parameter_case1():
- """ test_var_parameter_case1 """
- log.debug("start test_var_parameter_case1")
- var_parameter_f(1, 2, 3, 4, 5)
- log.debug("end test_var_parameter_case1")
-
-
- class Net(nn.Cell):
- """ Net definition """
- def __init__(self, value1):
- super(Net, self).__init__()
- self.relu = nn.ReLU()
- self.softmax = nn.Softmax(0)
- self.axis = 0
- self.TC = ClassTest("test_class", 1.2)
- self.value = value1
-
- @ms_function
- def construct(self, x):
- x = self.get_test_value(x)
- return x
-
- def get_test_value(self, x):
- ret = x + self.value
- return ret
-
-
- class ClassTest:
- """ ClassTest definition """
- def __init__(self, name, value1):
- self.name = name
- self.value = value1
-
- def get_name(self):
- return self.name
-
- def get_value(self, inc):
- ret = self.value + inc
- return ret
-
- def __call__(self, *args, **kwargs):
- pass
-
-
- # Test: call method on parse graph code
- @non_graph_engine
- def test_call_method_on_construct():
- """ test_call_method_on_construct """
- log.debug("begin test_call_method_on_construct")
-
- x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
- y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
- z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
-
- net = Net(y)
- output = net.construct(x)
- result = output.asnumpy()
- print(result)
- assert np.all(result == z)
-
- log.debug("finished test_call_method_on_construct")
-
-
- # Test: call method on parse graph code
- class Net1(nn.Cell):
- """ Net1 definition """
- def __init__(self, v1, v2):
- super(Net1, self).__init__()
- self.relu = nn.ReLU()
- self.softmax = nn.Softmax(0)
- self.axis = 0
- self.TC = ClassTest("test_class", v1)
- self.value = v2
-
- @ms_function
- def construct(self, x):
- x = x + self.TC.get_value(self.value)
- return x
-
-
- @non_graph_engine
- def test_call_other_object_method():
- """ test_call_other_object_method """
- log.debug("begin test_call_other_object_method")
-
- x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
- y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
- y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
- z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
-
- net = Net1(y, y1)
- with pytest.raises(TypeError):
- output = net.construct(x)
- result = output.asnumpy()
- print(result)
- assert np.all(result == z)
-
- log.debug("finished test_call_other_object_method")
-
-
- # Test: call global object method(not self) on parse graph code
- value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
- TC = ClassTest("test_class", value)
-
-
- class Net2(nn.Cell):
- """ Net2 definition """
- def __init__(self, value1):
- super(Net2, self).__init__()
- self.value = value1
-
- @ms_function
- def construct(self, x):
- x = x + TC.get_value(self.value)
- return x
-
- @ms_function
- def construct1(self, x):
- x = x + TC.value
- x = x + self.value
- return x
-
-
- @non_graph_engine
- def test_call_no_self_other_object_method():
- """ test_call_no_self_other_object_method """
- log.debug("begin test_call_other_object_method")
- x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
- y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
- z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
-
- net = Net2(y)
- with pytest.raises(TypeError):
- output = net.construct(x)
- result = output.asnumpy()
- print(result)
- assert np.all(result == z)
-
- log.debug("finished test_call_other_object_method")
-
-
- def test_call_no_self_other_object_attr_value():
- """ test_call_no_self_other_object_attr_value """
- # do not support tensor as init input.
- return
-
-
- # Test case: use the * to unlock the varargs for @mindspore
- def vararg1(x, y):
- """ vararg1 """
- z = x + y
- return z
-
-
- def varargs_main(fn):
- """ varargs_main """
- @ms_function
- def t1(*args):
- return fn(*args)
- return t1
-
-
- def test_var_parameter_case3():
- """ test_var_parameter_case3 """
- log.debug("start test_var_parameter_case3")
- ret = varargs_main(vararg1)(1, 2)
- log.debug("ret = %r", ret)
- log.debug("end test_var_parameter_case3")
-
-
- # Test case: test the flag set
- @core(tg=True)
- def set_flag(x):
- """ set_flag """
- return x + 1
-
-
- @ms_function
- def set_test_flag_main(x, y):
- """ set_test_flag_main """
- z = set_flag(x)
- z = z + y
- return z
-
-
- def test_set_flag():
- """ Test default parameter function call """
- log.debug("begin test_set_flag")
- ret = set_test_flag_main(2, 3)
- log.debug("finished test_set_flag, ret = %r", ret)
-
-
- @dataclass
- class Access:
- a: int
- b: int
-
- def max(self):
- if self.a > self.b:
- return self.a
- return self.b
-
- @ms_function
- def invoke_dataclass(x, y):
- """ invoke_dataclass """
- acs = Access(x, y)
- return acs.max()
-
-
- def test_access():
- """ test_access """
- invoke_dataclass(1, 2)
-
-
- def myfunc(x):
- """ myfunc """
- return x * x
-
-
- @ms_function
- def ms_infer_for():
- """ ms_infer_for """
- a = 0.0
- for x in [1.1, 2.3, 3.3]:
- a = a + x
- return a
-
-
- def test_infer_for():
- """ test_infer_for """
- ms_infer_for()
-
-
- @ms_function
- def ms_infer_for_func(y):
- """ ms_infer_for_func """
- for x in [1.0, 2.0, 3.0]:
- y = myfunc(x) + y
- return y
-
-
- def test_ms_infer_for_func():
- """ test_ms_infer_for_func """
- ms_infer_for_func(1.0)
-
-
- @ms_function
- def add(x, y):
- """ add """
- return x + y
-
-
- def test_add():
- """ test_add """
- res = add(1, 2.0)
- return res
-
-
- @ms_function
- def add_list():
- """ add_list """
- a = [1, 2, 3]
- b = a[1] + a[2]
- return b
-
-
- def test_list():
- """ test_list """
- return add_list()
-
-
- @ms_function
- def compare_list_len():
- """ compare_list_len """
- a = [1, 2, 3]
- return ms_len(a)
-
-
- def test_list_len():
- """ test_list_len """
- compare_list_len()
-
-
- @ms_function
- def add_tuple():
- """ add_tuple """
- a = (1, 2, 3)
- b = a[1] + a[2]
- return b
-
-
- def test_tuple():
- """ test_tuple """
- return add_tuple()
-
-
- def invoke_func(x):
- """ invoke_func """
- return x * x
-
-
- @ms_function
- def tuple_of_node(x, y):
- """ tuple_of_node """
- a = invoke_func(x)
- b = invoke_func(y)
- c = (a, b)
- d = c[1] * x
- return d
-
-
- def test_tuple_node():
- """ test_tuple_node """
- res = tuple_of_node(1, 2)
- return res
-
-
- @ms_function
- def range_spec(x, y):
- """ range_spec """
- for _ in range(1, 10, 3):
- x = x + 1
- return x + y
-
-
- def test_range():
- """ test_range """
- res = range_spec(10, 10)
- return res
|