GitOrigin-RevId: 0e688ebd59
tags/v1.3.0
| @@ -27,7 +27,6 @@ __all__ = [ | |||||
| "replace_vars", | "replace_vars", | ||||
| "replace_oprs", | "replace_oprs", | ||||
| "set_priority_to_id", | "set_priority_to_id", | ||||
| "load_and_inference", | |||||
| "GraphInference", | "GraphInference", | ||||
| ] | ] | ||||
| @@ -274,21 +273,6 @@ def replace_oprs( | |||||
| return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec) | ||||
| def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]: | |||||
| """ | |||||
| Loads a serialized computing graph and run inference with input data. | |||||
| :param file: path or handle of the input file. | |||||
| :param inp_data_list: list of input data. | |||||
| :return: list of inference results. | |||||
| """ | |||||
| graph = GraphInference(file) | |||||
| result = graph.run(*inp_data_list) | |||||
| out_data_list = list(result.values()) | |||||
| return out_data_list | |||||
| class GraphInference: | class GraphInference: | ||||
| """ | """ | ||||
| Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. | Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph. | ||||
| @@ -201,5 +201,6 @@ def test_quantize_batchmatmul_activation(): | |||||
| file = io.BytesIO() | file = io.BytesIO() | ||||
| f.dump(file, enable_nchw4=True) | f.dump(file, enable_nchw4=True) | ||||
| file.seek(0) | file.seek(0) | ||||
| dumped_outputs = cgtools.load_and_inference(file, [inputs])[0] | |||||
| infer_cg = cgtools.GraphInference(file)[0] | |||||
| dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0] | |||||
| np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) | np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) | ||||
| @@ -141,7 +141,8 @@ def test_dump(): | |||||
| np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | ||||
| np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | ||||
| file.seek(0) | file.seek(0) | ||||
| result = cgtools.load_and_inference(file, [a, b]) | |||||
| infer_cg = cgtools.GraphInference(file) | |||||
| result = list((infer_cg.run(a, b)).values())[0] | |||||
| np.testing.assert_equal(result[0], y) | np.testing.assert_equal(result[0], y) | ||||
| @@ -161,7 +162,8 @@ def test_capture_dump(): | |||||
| file = io.BytesIO() | file = io.BytesIO() | ||||
| f.dump(file) | f.dump(file) | ||||
| file.seek(0) | file.seek(0) | ||||
| result = cgtools.load_and_inference(file, [x]) | |||||
| infer_cg = cgtools.GraphInference(file) | |||||
| result = list((infer_cg.run(x)).values())[0] | |||||
| np.testing.assert_equal(result[0], y) | np.testing.assert_equal(result[0], y) | ||||