| @@ -315,9 +315,9 @@ class GraphInference: | |||||
| inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | ||||
| self._inp_dict = OrderedDict() | self._inp_dict = OrderedDict() | ||||
| replace_dict = {} | replace_dict = {} | ||||
| for i in inputs: | |||||
| for idx, i in enumerate(inputs): | |||||
| inp_node = G.InputNode( | inp_node = G.InputNode( | ||||
| device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph | |||||
| device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph | |||||
| ) | ) | ||||
| self._inp_dict[i.name] = inp_node | self._inp_dict[i.name] = inp_node | ||||
| replace_dict[i] = inp_node.outputs[0] | replace_dict[i] = inp_node.outputs[0] | ||||
| @@ -1,13 +1,22 @@ | |||||
| import io | |||||
| import numpy as np | import numpy as np | ||||
| import megengine.utils.comp_graph_tools as cgtools | |||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.jit import trace | |||||
| def _default_compare_fn(x, y): | def _default_compare_fn(x, y): | ||||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||||
| if isinstance(x, np.ndarray): | |||||
| np.testing.assert_allclose(x, y, rtol=1e-6) | |||||
| else: | |||||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||||
| def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): | |||||
| def opr_test( | |||||
| cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||||
| ): | |||||
| """ | """ | ||||
| :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | ||||
| and the dict should have input, | and the dict should have input, | ||||
| @@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||||
| if not isinstance(results, (tuple, list)): | if not isinstance(results, (tuple, list)): | ||||
| results = (results,) | results = (results,) | ||||
| for r, e in zip(results, expected): | for r, e in zip(results, expected): | ||||
| if not isinstance(r, tensor): | |||||
| r = tensor(r) | |||||
| compare_fn(r, e) | compare_fn(r, e) | ||||
| def get_param(cases, idx): | def get_param(cases, idx): | ||||
| @@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||||
| inp, outp = get_param(cases, 0) | inp, outp = get_param(cases, 0) | ||||
| inp_tensor = [tensor(inpi) for inpi in inp] | inp_tensor = [tensor(inpi) for inpi in inp] | ||||
| if test_trace: | |||||
| copied_inp = inp_tensor.copy() | |||||
| for symbolic in [False, True]: | |||||
| traced_func = trace(symbolic=symbolic)(func) | |||||
| for _ in range(3): | |||||
| traced_results = traced_func(*copied_inp, **kwargs) | |||||
| check_results(traced_results, outp) | |||||
| dumped_func = trace(symbolic=True, capture_as_const=True)(func) | |||||
| dumped_results = dumped_func(*copied_inp, **kwargs) | |||||
| check_results(dumped_results, outp) | |||||
| file = io.BytesIO() | |||||
| dump_info = dumped_func.dump(file) | |||||
| file.seek(0) | |||||
| # arg_name has pattern arg_xxx, xxx is int value | |||||
| def take_number(arg_name): | |||||
| return int(arg_name.split("_")[-1]) | |||||
| input_names = dump_info[4] | |||||
| inps_np = [i.numpy() for i in copied_inp] | |||||
| input_names.sort(key=take_number) | |||||
| inp_dict = dict(zip(input_names, inps_np)) | |||||
| infer_cg = cgtools.GraphInference(file) | |||||
| # assume #outputs == 1 | |||||
| loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] | |||||
| check_results(loaded_results, outp) | |||||
| results = func(*inp_tensor, **kwargs) | results = func(*inp_tensor, **kwargs) | ||||
| check_results(results, outp) | check_results(results, outp) | ||||
| @@ -36,7 +36,7 @@ def test_where(): | |||||
| {"input": [maskv0, xv0, yv0]}, | {"input": [maskv0, xv0, yv0]}, | ||||
| {"input": [maskv1, xv1, yv1]}, | {"input": [maskv1, xv1, yv1]}, | ||||
| ] | ] | ||||
| opr_test(cases, F.where, ref_fn=np.where) | |||||
| opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||||
| maskv2 = np.array([1, 1, 1], dtype=np.bool_) | maskv2 = np.array([1, 1, 1], dtype=np.bool_) | ||||
| xv2 = np.array([1, 3, 2], dtype=np.float32) | xv2 = np.array([1, 3, 2], dtype=np.float32) | ||||
| @@ -50,7 +50,7 @@ def test_where(): | |||||
| {"input": [maskv2, xv2, yv2]}, | {"input": [maskv2, xv2, yv2]}, | ||||
| {"input": [maskv3, xv3, yv3]}, | {"input": [maskv3, xv3, yv3]}, | ||||
| ] | ] | ||||
| opr_test(cases, F.where, ref_fn=np.where) | |||||
| opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||||
| def test_dropout(): | def test_dropout(): | ||||
| @@ -115,14 +115,17 @@ def test_matmul(): | |||||
| {"input": [data4, data5]}, | {"input": [data4, data5]}, | ||||
| ] | ] | ||||
| for _ in range(0, batch_size): | for _ in range(0, batch_size): | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, F.matmul, ref_fn=np.matmul, | |||||
| cases, F.matmul, test_trace=False, ref_fn=np.matmul, | |||||
| ) | ) | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| [{"input": [data1, data4]}], | [{"input": [data1, data4]}], | ||||
| F.matmul, | F.matmul, | ||||
| ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | ||||
| test_trace=False, | |||||
| transpose_b=True, | transpose_b=True, | ||||
| ) | ) | ||||
| @@ -162,20 +162,24 @@ def test_linspace(): | |||||
| {"input": [1, 9, 9]}, | {"input": [1, 9, 9]}, | ||||
| {"input": [3, 10, 8]}, | {"input": [3, 10, 8]}, | ||||
| ] | ] | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.linspace, | F.linspace, | ||||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
| test_trace=False, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| {"input": [9, 1, 9]}, | {"input": [9, 1, 9]}, | ||||
| {"input": [10, 3, 8]}, | {"input": [10, 3, 8]}, | ||||
| ] | ] | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.linspace, | F.linspace, | ||||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
| test_trace=False, | |||||
| ) | ) | ||||
| @@ -184,30 +188,36 @@ def test_arange(): | |||||
| {"input": [1, 9, 1]}, | {"input": [1, 9, 1]}, | ||||
| {"input": [2, 10, 2]}, | {"input": [2, 10, 2]}, | ||||
| ] | ] | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| test_trace=False, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| {"input": [9, 1, -1]}, | {"input": [9, 1, -1]}, | ||||
| {"input": [10, 2, -2]}, | {"input": [10, 2, -2]}, | ||||
| ] | ] | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| test_trace=False, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| {"input": [9.3, 1.2, -0.5]}, | {"input": [9.3, 1.2, -0.5]}, | ||||
| {"input": [10.3, 2.1, -1.7]}, | {"input": [10.3, 2.1, -1.7]}, | ||||
| ] | ] | ||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| test_trace=False, | |||||
| ) | ) | ||||
| @@ -279,7 +289,8 @@ def test_broadcast(): | |||||
| {"input": [data1, output1_shape], "output": output1_shape}, | {"input": [data1, output1_shape], "output": output1_shape}, | ||||
| {"input": [data2, output2_shape], "output": output2_shape}, | {"input": [data2, output2_shape], "output": output2_shape}, | ||||
| ] | ] | ||||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||||
| # FIXME: remove test_trace=False in the future | |||||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) | |||||
| x = F.ones((2, 1, 3)) | x = F.ones((2, 1, 3)) | ||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||