| @@ -2,12 +2,14 @@ import collections | |||
| import contextlib | |||
| import functools | |||
| import itertools | |||
| import json | |||
| import typing | |||
| import warnings | |||
| import weakref | |||
| import numpy as np | |||
| from ..core._imperative_rt import GraphProfiler | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| @@ -85,11 +87,14 @@ class trace: | |||
| symbolic=False, | |||
| capture_as_const=False, | |||
| sublinear_memory_config: SublinearMemoryConfig = None, | |||
| profiling: bool = False, | |||
| ): | |||
| self.__wrapped__ = function | |||
| self._symbolic = symbolic | |||
| self._capture_as_const = capture_as_const | |||
| self._sublinear_memory_config = sublinear_memory_config | |||
| self._profiling = profiling | |||
| self._profiler = None | |||
| self._untraced = True | |||
| self._tinfo = [] # handle -> TensorInfo | |||
| @@ -308,6 +313,8 @@ class trace: | |||
| ) | |||
| sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try | |||
| sublinear_config.num_worker = self._sublinear_memory_config.num_worker | |||
| if self._profiling: | |||
| self._profiler = GraphProfiler(graph) | |||
| def _compile(self): | |||
| graph = self._graph = G.Graph() | |||
| @@ -581,6 +588,16 @@ class trace: | |||
| % (output_names and output_names[i] or i) | |||
| ) | |||
| def get_profile(self): | |||
| """ | |||
| Get profiling result for compiled trace. | |||
| :return: a json compatible object. | |||
| """ | |||
| if not self._profiler: | |||
| raise RuntimeError("trace is not set with profiling=True") | |||
| return json.loads(self._profiler.get()) | |||
| class CompiledTensorProxy(RawTensor): | |||
| """ | |||
| @@ -11,18 +11,38 @@ | |||
| #include "./graph_rt.h" | |||
| #include "megbrain/graph/cg.h" | |||
| #include "megbrain/serialization/serializer.h" | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/imperative.h" | |||
| #include "./helper.h" | |||
| #include "megbrain/plugin/profiler.h" | |||
| namespace py = pybind11; | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| namespace { | |||
| class _CompGraphProfilerImpl { | |||
| std::shared_ptr<ComputingGraph> m_comp_graph; | |||
| GraphProfiler m_profiler; | |||
| public: | |||
| _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg): | |||
| m_comp_graph{cg}, | |||
| m_profiler{m_comp_graph.get()} | |||
| { | |||
| } | |||
| std::string _get_result() { | |||
| auto json = m_profiler.to_json_full( | |||
| m_comp_graph->current_comp_seq()); | |||
| return json->to_string(); | |||
| } | |||
| }; | |||
| } | |||
| #define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name) | |||
| template<typename T> | |||
| @@ -102,6 +122,12 @@ void init_graph_rt(py::module m) { | |||
| }) | |||
| .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options)); | |||
| py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(m, "GraphProfiler") | |||
| .def(py::init([](std::shared_ptr<ComputingGraph> graph) { | |||
| return std::make_shared<_CompGraphProfilerImpl>(graph); | |||
| })) | |||
| .def("get", [](_CompGraphProfilerImpl& profiler) { return profiler._get_result(); }); | |||
| m.def("dump_graph", [](const std::vector<VarNode*>& dest_vars) { | |||
| using namespace mgb::serialization; | |||
| std::vector<uint8_t> buf; | |||
| @@ -82,3 +82,22 @@ def test_dump(): | |||
| file = io.BytesIO() | |||
| f.dump(file) | |||
| def test_trace_profiler(): | |||
| for symbolic in [False, True]: | |||
| @trace(symbolic=symbolic, profiling=True) | |||
| def f(x): | |||
| op = ops.Elemwise(mode="negate") | |||
| (y,) = apply(op, x) | |||
| return y | |||
| x = as_raw_tensor([1]).numpy() | |||
| y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
| f(as_raw_tensor(x)) | |||
| f(as_raw_tensor(x)) # XXX: has to run twice | |||
| out = f.get_profile() | |||
| assert out.get("profiler") | |||