| @@ -315,9 +315,9 @@ class GraphInference: | |||
| inputs = get_dep_vars(output_nodes, "Host2DeviceCopy") | |||
| self._inp_dict = OrderedDict() | |||
| replace_dict = {} | |||
| for i in inputs: | |||
| for idx, i in enumerate(inputs): | |||
| inp_node = G.InputNode( | |||
| device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph | |||
| device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph | |||
| ) | |||
| self._inp_dict[i.name] = inp_node | |||
| replace_dict[i] = inp_node.outputs[0] | |||
| @@ -1,13 +1,22 @@ | |||
| import io | |||
| import numpy as np | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import tensor | |||
| from megengine.jit import trace | |||
| def _default_compare_fn(x, y): | |||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
| if isinstance(x, np.ndarray): | |||
| np.testing.assert_allclose(x, y, rtol=1e-6) | |||
| else: | |||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
| def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs): | |||
| def opr_test( | |||
| cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||
| ): | |||
| """ | |||
| :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
| and the dict should have input, | |||
| @@ -35,6 +44,8 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||
| if not isinstance(results, (tuple, list)): | |||
| results = (results,) | |||
| for r, e in zip(results, expected): | |||
| if not isinstance(r, tensor): | |||
| r = tensor(r) | |||
| compare_fn(r, e) | |||
| def get_param(cases, idx): | |||
| @@ -63,5 +74,36 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||
| inp, outp = get_param(cases, 0) | |||
| inp_tensor = [tensor(inpi) for inpi in inp] | |||
| if test_trace: | |||
| copied_inp = inp_tensor.copy() | |||
| for symbolic in [False, True]: | |||
| traced_func = trace(symbolic=symbolic)(func) | |||
| for _ in range(3): | |||
| traced_results = traced_func(*copied_inp, **kwargs) | |||
| check_results(traced_results, outp) | |||
| dumped_func = trace(symbolic=True, capture_as_const=True)(func) | |||
| dumped_results = dumped_func(*copied_inp, **kwargs) | |||
| check_results(dumped_results, outp) | |||
| file = io.BytesIO() | |||
| dump_info = dumped_func.dump(file) | |||
| file.seek(0) | |||
| # arg_name has pattern arg_xxx, xxx is int value | |||
| def take_number(arg_name): | |||
| return int(arg_name.split("_")[-1]) | |||
| input_names = dump_info[4] | |||
| inps_np = [i.numpy() for i in copied_inp] | |||
| input_names.sort(key=take_number) | |||
| inp_dict = dict(zip(input_names, inps_np)) | |||
| infer_cg = cgtools.GraphInference(file) | |||
| # assume #outputs == 1 | |||
| loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0] | |||
| check_results(loaded_results, outp) | |||
| results = func(*inp_tensor, **kwargs) | |||
| check_results(results, outp) | |||
| @@ -36,7 +36,7 @@ def test_where(): | |||
| {"input": [maskv0, xv0, yv0]}, | |||
| {"input": [maskv1, xv1, yv1]}, | |||
| ] | |||
| opr_test(cases, F.where, ref_fn=np.where) | |||
| opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||
| maskv2 = np.array([1, 1, 1], dtype=np.bool_) | |||
| xv2 = np.array([1, 3, 2], dtype=np.float32) | |||
| @@ -50,7 +50,7 @@ def test_where(): | |||
| {"input": [maskv2, xv2, yv2]}, | |||
| {"input": [maskv3, xv3, yv3]}, | |||
| ] | |||
| opr_test(cases, F.where, ref_fn=np.where) | |||
| opr_test(cases, F.where, ref_fn=np.where, test_trace=False) | |||
| def test_dropout(): | |||
| @@ -115,14 +115,17 @@ def test_matmul(): | |||
| {"input": [data4, data5]}, | |||
| ] | |||
| for _ in range(0, batch_size): | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, F.matmul, ref_fn=np.matmul, | |||
| cases, F.matmul, test_trace=False, ref_fn=np.matmul, | |||
| ) | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| [{"input": [data1, data4]}], | |||
| F.matmul, | |||
| ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), | |||
| test_trace=False, | |||
| transpose_b=True, | |||
| ) | |||
| @@ -162,20 +162,24 @@ def test_linspace(): | |||
| {"input": [1, 9, 9]}, | |||
| {"input": [3, 10, 8]}, | |||
| ] | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| test_trace=False, | |||
| ) | |||
| cases = [ | |||
| {"input": [9, 1, 9]}, | |||
| {"input": [10, 3, 8]}, | |||
| ] | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| test_trace=False, | |||
| ) | |||
| @@ -184,30 +188,36 @@ def test_arange(): | |||
| {"input": [1, 9, 1]}, | |||
| {"input": [2, 10, 2]}, | |||
| ] | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| test_trace=False, | |||
| ) | |||
| cases = [ | |||
| {"input": [9, 1, -1]}, | |||
| {"input": [10, 2, -2]}, | |||
| ] | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| test_trace=False, | |||
| ) | |||
| cases = [ | |||
| {"input": [9.3, 1.2, -0.5]}, | |||
| {"input": [10.3, 2.1, -1.7]}, | |||
| ] | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test( | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| test_trace=False, | |||
| ) | |||
| @@ -279,7 +289,8 @@ def test_broadcast(): | |||
| {"input": [data1, output1_shape], "output": output1_shape}, | |||
| {"input": [data2, output2_shape], "output": output2_shape}, | |||
| ] | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
| # FIXME: remove test_trace=False in the future | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) | |||
| x = F.ones((2, 1, 3)) | |||
| with pytest.raises(RuntimeError): | |||